in modules/SwissArmyTransformer/sat/data_utils/configure_data.py [0:0]
def make_loaders(args, create_dataset_function, collate_fn=None):
    """Build the train/val/test dataloaders described by ``args``.

    Args:
        args.train_data, args.valid_data, args.test_data: str. Paths to the dataset.
        args.split: str. format: "8,1,1". how to split train_data.
        args.dataset_type: use to create the right datasets.

    Returns:
        (train, valid, test) dataloaders; any of them may be None. Also sets
        ``args.do_train`` / ``args.do_valid`` / ``args.do_test`` as side effects.
    """
    make_dataset = partial(
        make_dataset_full,
        create_dataset_function=create_dataset_function,
        batch_from_same_dataset=args.batch_from_same_dataset,
    )

    # Batch sizes are global across the data-parallel group.
    world_size = torch.distributed.get_world_size(
        group=mpu.get_data_parallel_group())
    batch_size = args.batch_size * world_size
    eval_batch_size = (
        batch_size if args.eval_batch_size is None
        else args.eval_batch_size * world_size
    )

    split = get_split(args)
    train_set_args = {
        'path': args.train_data,
        'split': split,
    }
    # Eval datasets are never split further.
    eval_set_args = copy.copy(train_set_args)
    eval_set_args['split'] = [1.]

    # Construct the dataset splits.
    train, valid, test = None, None, None
    if args.train_data is not None:
        train = make_dataset(**train_set_args, args=args,
                             dataset_weights=args.train_data_weights,
                             is_train_data=True)
        if should_split(split):
            # make_dataset returned a (train, valid, test) triple.
            train, valid, test = train
    # Explicit valid/test paths take over when splitting did not provide them.
    if valid is None and args.valid_data is not None:
        eval_set_args['path'] = args.valid_data
        valid = make_dataset(**eval_set_args, args=args,
                             random_mapping=not args.strict_eval)
    if test is None and args.test_data is not None:
        eval_set_args['path'] = args.test_data
        test = make_dataset(**eval_set_args, args=args,
                            random_mapping=not args.strict_eval)

    # Wrap datasets in loaders and record which phases are active.
    args.do_train = train is not None and args.batch_size > 0
    if args.do_train:
        train = make_data_loader(train, batch_size, args,
                                 split='train', collate_fn=collate_fn)
    if eval_batch_size == 0:
        # Guard against an explicit eval batch size of zero.
        eval_batch_size = batch_size
    args.do_valid = valid is not None
    if args.do_valid:
        valid = make_data_loader(valid, eval_batch_size, args,
                                 split='val', collate_fn=collate_fn)
    args.do_test = test is not None
    if args.do_test:
        test = make_data_loader(test, eval_batch_size, args,
                                split='test', collate_fn=collate_fn)
    return train, valid, test