in modules/SwissArmyTransformer/sat/data_utils/configure_data.py
def make_dataset_full(path, split, args, create_dataset_function,
                      dataset_weights=None, random_mapping=True, is_train_data=False, batch_from_same_dataset=False, **kwargs):
    """function to create datasets+tokenizers for common options"""
    print_all('make dataset ' + str(path), level='DEBUG')
    assert isinstance(path, list)
    if (is_train_data and args.iterable_dataset) or (not is_train_data and args.iterable_dataset_eval):  # cannot be indexed
        # Random mapping is flexible and efficient, but sometimes there are practical issues.
        # For instance, someone may just give you an iterable dataset, e.g. a webdataset.
        from .webds import ConfiguredResampledShards, DataPipeline
        valid_types = (ConfiguredResampledShards, DataPipeline)
        assert split[0] == 1, 'Iterable dataset cannot auto split.'
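        # Iterable sources expose no length or random access, so percentage-based
        # splitting and RandomMappingDataset are not applicable here; each path is
        # consumed as a full stream.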
        ds = []
        for p in path:
            d = create_dataset_function(p, args)
            assert isinstance(d, valid_types)
            ds.append(d)
        # ds = ChainDataset(ds)  # if chaining is needed, merge the shards into one url instead
        if batch_from_same_dataset:
            assert args.num_workers <= 1, 'We cannot control the relative speed of different workers, so batches may mix different iterable parts.'
        ds = AlterDataset(ds, weights=dataset_weights, seed=args.seed, batch_from_same_dataset=batch_from_same_dataset, batch_size=args.batch_size)
        return ds
    if split is None:
        split = [1.]
    if not should_split(split):
        ds = []
        for p in path:
            d = create_dataset_function(p, args)
            ds.append(d)
        ds = ConcatDataset(ds, weights=dataset_weights)
        if random_mapping:
            if args.epochs is not None:  # do not auto-scale; use the given number of epochs.
                ds = RandomDataset(ds, scale=args.epochs, seed=args.seed)
            else:
                world_size = torch.distributed.get_world_size(
                    group=mpu.get_data_parallel_group())
                if is_train_data:
                    # Only the train dataset sets this to True, so enlarge the mapping
                    # to make sure the data is sufficient for the whole run.
                    scale = max(200, 1 + (args.train_iters * args.batch_size * args.gradient_accumulation_steps * world_size) // len(ds))
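                    # Worked example with hypothetical numbers: train_iters=10000,
                    # batch_size=4, gradient_accumulation_steps=2, world_size=8 and
                    # len(ds)=1000 give scale = max(200, 1 + 640000 // 1000) = 641,
                    # i.e. the mapping is long enough to cover the whole training run.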
                else:
                    scale = max(200, 1 + ((1 + args.train_iters // args.eval_interval) * args.eval_iters * args.eval_batch_size * args.gradient_accumulation_steps * world_size) // len(ds))
                ds = RandomMappingDataset(ds, scale=scale)
        return ds
    else:
        # Split each dataset first, then reweight/concat, and apply random mapping last;
        # this order avoids overlap between the splits.
        train_ds, valid_ds, test_ds = [], [], []
        for p in path:
            d = create_dataset_function(p, args)
            if should_split(split):
                dtrain, dvalid, dtest = split_ds(d, split, block_size=args.block_size, seed=args.seed)
                train_ds.append(dtrain)
                valid_ds.append(dvalid)
                test_ds.append(dtest)
        train_ds = ConcatDataset(train_ds, weights=dataset_weights)
        valid_ds = ConcatDataset(valid_ds, weights=dataset_weights)
        test_ds = ConcatDataset(test_ds, weights=dataset_weights)
        if random_mapping:
            world_size = torch.distributed.get_world_size(
                group=mpu.get_data_parallel_group())
            scale = max(200, 1 + (args.train_iters * args.batch_size * world_size) // len(train_ds))
            train_ds = RandomMappingDataset(train_ds, scale=scale)
            scale = max(200, 1 + ((1 + args.train_iters // args.eval_interval) * args.eval_iters * args.eval_batch_size * args.gradient_accumulation_steps * world_size) // len(valid_ds))
            valid_ds = RandomMappingDataset(valid_ds, scale=scale)
            test_ds = RandomMappingDataset(test_ds)
        return train_ds, valid_ds, test_ds
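
A minimal usage sketch (not part of configure_data.py), assuming a hypothetical
map-style dataset class MyJsonlDataset and an args namespace that carries the
fields read above (iterable_dataset, iterable_dataset_eval, epochs, seed,
train_iters, batch_size, block_size, etc., with args.iterable_dataset left False):

    from sat.data_utils.configure_data import make_dataset_full

    def create_dataset_function(path, args):
        # hypothetical factory: build one indexable dataset per source path
        return MyJsonlDataset(path)

    # should_split([0.9, 0.05, 0.05]) is True, so a (train, valid, test) triple is
    # returned: each source is split first, then reweighted/concatenated, then random-mapped.
    train_ds, valid_ds, test_ds = make_dataset_full(
        ['data/a.jsonl', 'data/b.jsonl'],   # path must be a list of sources
        [0.9, 0.05, 0.05],                  # split proportions
        args,
        create_dataset_function,
        dataset_weights=[2, 1],             # per-source reweighting in ConcatDataset
        is_train_data=True,
    )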