in modules/SwissArmyTransformer/sat/data_utils/configure_data.py
def make_data_loader(dataset, batch_size, args, split, collate_fn=None):
    world_size = torch.distributed.get_world_size(
        group=mpu.get_data_parallel_group())
    rank = torch.distributed.get_rank(group=mpu.get_data_parallel_group())
    distributed = world_size > 1

    # If the dataset is an IterableDataset, assume it is already properly
    # configured (pre-sharded across ranks).
    if isinstance(dataset, IterableDataset):
        if split in ['val', 'test'] and args.strict_eval:
            raise ValueError(
                'IterableDataset cannot be used for validation or testing when '
                '`args.strict_eval=True`, because the length of the final batch '
                'cannot be inferred before reading it out.')
        args.val_last_shape = [1] * world_size  # placeholder, not actually used
        args.val_drop_number = 0
        args.test_last_shape = [1] * world_size
        args.test_drop_number = 0
        per_rank_batch_size = None if args.iterable_dataset == 'custom' else batch_size // world_size
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=per_rank_batch_size,
            num_workers=args.num_workers,
            pin_memory=True,
            collate_fn=collate_fn,
            prefetch_factor=args.prefetch_factor if args.num_workers > 0 else None,
            timeout=10
        )
    # Map-style dataset: iterate sequentially; sharding across data-parallel
    # ranks is handled by the batch sampler below.
    sampler = torch.utils.data.SequentialSampler(dataset)
    drop_last = False  # incomplete final batches are handled by the last_shape / drop_number bookkeeping below

    # the GPUs in the same model parallel group receive the same data
    if distributed:  # TODO reformat this, but it is not urgent
        gradient_accumulation_steps = getattr(args, 'gradient_accumulation_steps', 1)
        batch_sampler = DistributedBatchSampler(sampler,
                                                batch_size,
                                                drop_last,
                                                rank,
                                                world_size,
                                                gradient_accumulation_steps=gradient_accumulation_steps)
    else:
        batch_sampler = torch.utils.data.BatchSampler(sampler,
                                                      batch_size,
                                                      drop_last)
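    # The block below works out how the final (possibly incomplete) global
    # batch is split across data-parallel ranks, so evaluation can later drop
    # the padded samples. Illustrative example (added for clarity, not in the
    # original source): with batch_size == 16, world_size == 4 (4 samples per
    # rank) and len(dataset) == 18, the last global batch holds 18 % 16 == 2
    # samples; the code records last_shape == [2, 1, 1, 1] and drop_number == 3,
    # i.e. one rank gets the 2 real leftover samples and the other three ranks
    # each get a single padded sample that is discarded afterwards.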
    last_len = len(dataset) % batch_size
    batch_per_worker = batch_size // world_size
    last_shape = [batch_per_worker] * (last_len // batch_per_worker)  # these ranks get a full per-rank batch
    if last_len != 0:
        if last_len % batch_per_worker != 0:
            last_shape.append(last_len % batch_per_worker)  # one rank gets the remainder (< one per-rank batch)
        drop_number = world_size - ((last_len - 1) // batch_per_worker + 1)
        # the remaining ranks get nothing real; append 1 so they still run, and
        # drop their outputs later according to drop_number
        for j in range(drop_number):
            last_shape.append(1)
    else:
        drop_number = 0
    if split == 'val':
        args.val_last_shape = last_shape
        args.val_drop_number = drop_number
    elif split == 'test':
        args.test_last_shape = last_shape
        args.test_drop_number = drop_number
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_sampler=batch_sampler,
                                              num_workers=args.num_workers,
                                              pin_memory=True,
                                              collate_fn=collate_fn,
                                              prefetch_factor=args.prefetch_factor if args.num_workers > 0 else None,
                                              )
    return data_loader
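
A minimal usage sketch (not from the original file), assuming torch.distributed and SAT's mpu have already been initialized by the usual training entry point; the Namespace fields mirror the attributes the function reads (num_workers, prefetch_factor, strict_eval), and the dataset size and batch size are made up for illustration:

from argparse import Namespace

import torch
from torch.utils.data import TensorDataset

args = Namespace(num_workers=2, prefetch_factor=4, strict_eval=True)
dataset = TensorDataset(torch.arange(18).float())  # 18 map-style samples

# Builds the loader; for split='val' it also records args.val_last_shape and
# args.val_drop_number, which the evaluation loop uses to drop the padded
# samples of the final incomplete batch before reducing metrics across ranks.
val_loader = make_data_loader(dataset, batch_size=16, args=args, split='val')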