Skip to content

Commit

Permalink
Update loader.py
Browse files: browse the repository at this point in the history
  • Loading branch information
hiyouga authored May 29, 2024
1 parent 51dd454 commit b55fb61
Showing 1 changed file with 8 additions and 15 deletions.
23 changes: 8 additions & 15 deletions src/llamafactory/data/loader.py
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,9 @@
import inspect
import os
import numpy as np
from numpy.random import RandomState
import sys
from typing import TYPE_CHECKING, Literal, Optional, Union

import numpy as np
from datasets import load_dataset, load_from_disk

from ..extras.constants import FILEEXT2TYPE
Expand Down Expand Up @@ -108,20 +107,14 @@ def load_single_dataset(
if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter

if dataset_attr.num_samples is not None and not data_args.streaming:
indexes = np.random.permutation(len(dataset))[: dataset_attr.num_samples]
dataset = dataset.select(indexes)
logger.info("Sampled {} examples from dataset {}.".format(dataset_attr.num_samples, dataset_attr))

if data_args.max_samples is not None: # truncate dataset
num_samples = min(data_args.max_samples, len(dataset))
dataset = dataset.select(range(num_samples))

if dataset_attr.sample_num:
dataset_sample_num = dataset_attr.sample_num
logger.info(f"从 {dataset_attr.dataset_name} 采样 {dataset_sample_num} 条训练样本")
random_state = RandomState(42)
idx = random_state.permutation(len(dataset))[:dataset_sample_num]
dataset_sample_num -= len(idx)
if dataset_sample_num > 0:
idx2 = random_state.choice(len(dataset), dataset_sample_num)
idx = np.concatenate([idx, idx2], axis=0)
dataset = dataset.select(idx)
indexes = np.random.permutation(len(dataset))[: data_args.max_samples]
dataset = dataset.select(indexes)

return align_dataset(dataset, dataset_attr, data_args)

Expand Down

0 comments on commit b55fb61

Please sign in to comment.