in src/nanotron/data/processing.py [0:0]
def _get_dataset_mix(dataset_dict: dict, splits: List[str] = None, seed=42) -> "DatasetDict":
    """
    Helper function to load dataset mix from dict configuration.

    Every (dataset, split) pair is loaded with ``load_dataset(name, split=split)``.
    Splits whose name contains "train" are subsampled to the first
    ``int(frac * len(dataset))`` rows, concatenated and shuffled. Splits whose
    name contains "test" are concatenated and shuffled without subsampling.

    Args:
        dataset_dict: Dictionary containing the dataset names and their training proportions.
            By default, all test proportions are 1. Only negative proportions are
            rejected here; values above 1 are passed through to ``Dataset.select``.
        splits: Section of the dataset to load, defaults to ["train", "test"]
        seed: Random seed for shuffling datasets

    Returns:
        DatasetDict containing the mixed datasets, with a "train" and/or "test" key
        depending on which kinds of split were requested.

    Raises:
        ValueError: If a proportion is negative, if a split name contains neither
            "train" nor "test", or if nothing was loaded at all.
    """
    # The signature keeps ``None`` as the default (no shared mutable default);
    # materialise the documented default here instead of iterating over None.
    if splits is None:
        splits = ["train", "test"]

    # Validate the whole configuration up front so a bad entry fails before any
    # dataset has been loaded.
    for ds, frac in dataset_dict.items():
        if frac < 0:
            raise ValueError(f"Dataset fraction for dataset {ds} is negative. (= {frac})")
    for split in splits:
        if "train" not in split and "test" not in split:
            raise ValueError(f"Split type {split} not recognized as one of test or train.")

    # Keep each training dataset together with its own fraction. A separate
    # per-dataset fraction list goes out of step with the loaded datasets as
    # soon as more than one train-like split is requested.
    train_parts = []
    test_parts = []
    for ds, frac in dataset_dict.items():
        for split in splits:
            loaded = load_dataset(
                ds,
                split=split,
            )
            if "train" in split:
                train_parts.append((loaded, frac))
            else:
                test_parts.append(loaded)

    mixed = {}
    if train_parts:
        train_subsets = [
            dataset.select(range(int(frac * len(dataset)))) for dataset, frac in train_parts
        ]
        mixed["train"] = concatenate_datasets(train_subsets).shuffle(seed=seed)
    # No subsampling for test datasets to enable fair comparison across models
    if test_parts:
        mixed["test"] = concatenate_datasets(test_parts).shuffle(seed=seed)

    if not mixed:
        raise ValueError(
            f"Dataset {dataset_dict} not recognized with split {splits}. Check the dataset has been correctly formatted."
        )
    return DatasetDict(mixed)