def make_data_loader()

in modules/SwissArmyTransformer/sat/data_utils/configure_data.py


def make_data_loader(dataset, batch_size, args, split, collate_fn=None):
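    """Build the DataLoader for one split, sharded over the data-parallel group.

    IterableDataset inputs are assumed to be pre-sharded and are wrapped in a
    plain DataLoader. For map-style datasets, a (Distributed)BatchSampler is
    used, and the shape of the incomplete final batch is recorded on `args`
    (`*_last_shape`, `*_drop_number`) for the 'val' and 'test' splits so it can
    be handled later during evaluation.
    """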

    world_size = torch.distributed.get_world_size(
        group=mpu.get_data_parallel_group())
    rank = torch.distributed.get_rank(group=mpu.get_data_parallel_group())
    distributed = world_size > 1

    # If the dataset is an IterableDataset, assume it is already properly configured (pre-sharded).
    if isinstance(dataset, IterableDataset):
        if split in ['val', 'test'] and args.strict_eval:
            raise ValueError('IterableDataset cannot be used for validation or testing if `args.strict_eval=True`, because we cannot infer the length of the final batch before reading it out.')
        args.val_last_shape = [1] * world_size # just fake it, not actually used
        args.val_drop_number = 0
        args.test_last_shape = [1] * world_size
        args.test_drop_number = 0
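        # batch_size=None disables the DataLoader's automatic batching; presumably a
        # 'custom' iterable dataset is expected to yield already-batched samples itself.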
        per_rank_batch_size = None if args.iterable_dataset == 'custom' else batch_size//world_size
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=per_rank_batch_size,
            num_workers=args.num_workers,
            pin_memory=True,
            collate_fn=collate_fn,
            prefetch_factor=args.prefetch_factor if args.num_workers > 0 else None,
            timeout=10
            )

    sampler = torch.utils.data.SequentialSampler(dataset)

    drop_last = False # not needed here: the incomplete last batch is handled by the last_shape / drop_number logic below.

    # the GPUs in the same model parallel group receive the same data
    if distributed: # TODO reformat this, but it is not urgent
        gradient_accumulation_steps = getattr(args, 'gradient_accumulation_steps', 1)
        batch_sampler = DistributedBatchSampler(sampler,
                                                batch_size,
                                                drop_last,
                                                rank,
                                                world_size,
                                                gradient_accumulation_steps=gradient_accumulation_steps)
    else:
        batch_sampler = torch.utils.data.BatchSampler(sampler,
                                                      batch_size,
                                                      drop_last)
    last_len = len(dataset) % batch_size
    batch_per_worker = batch_size // world_size
    last_shape = [batch_per_worker] * (last_len//batch_per_worker) # some ranks get a full per-rank batch
    if last_len != 0:
        if last_len % batch_per_worker != 0:
            last_shape.append(last_len % batch_per_worker) # one rank gets the remainder (< a full per-rank batch)
        drop_number = world_size - ((last_len-1)//batch_per_worker + 1)
        # the remaining ranks get nothing; append a fake length of 1 so they can still run,
        # and drop their outputs later according to drop_number.
        for j in range(drop_number):
            last_shape.append(1)
    else:
        drop_number = 0
    if split=='val':
        args.val_last_shape = last_shape
        args.val_drop_number = drop_number
    elif split=='test':
        args.test_last_shape = last_shape
        args.test_drop_number = drop_number
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_sampler=batch_sampler,
                                              num_workers=args.num_workers,
                                              pin_memory=True,
                                              collate_fn=collate_fn,
                                              prefetch_factor=args.prefetch_factor if args.num_workers > 0 else None,
                                              )
    return data_loader
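
The bookkeeping around last_shape / drop_number is the trickiest part of the function. The following self-contained sketch (with hypothetical sizes, not part of the original file) reproduces that arithmetic so the resulting shapes can be inspected without a distributed setup:

def last_batch_bookkeeping(dataset_len, batch_size, world_size):
    # Mirrors the last_shape / drop_number computation in make_data_loader().
    last_len = dataset_len % batch_size          # samples left over in the final global batch
    batch_per_worker = batch_size // world_size  # per-rank batch size
    last_shape = [batch_per_worker] * (last_len // batch_per_worker)  # ranks that get a full per-rank batch
    if last_len != 0:
        if last_len % batch_per_worker != 0:
            last_shape.append(last_len % batch_per_worker)  # one rank gets the remainder
        drop_number = world_size - ((last_len - 1) // batch_per_worker + 1)
        last_shape.extend([1] * drop_number)  # fake samples for the idle ranks, dropped later
    else:
        drop_number = 0
    return last_shape, drop_number

# e.g. 10 samples, global batch size 8, 4 data-parallel ranks:
# last_len = 2, batch_per_worker = 2 -> two samples on rank 0, fake samples on the other three ranks
print(last_batch_bookkeeping(10, 8, 4))   # ([2, 1, 1, 1], 3)
print(last_batch_bookkeeping(11, 8, 4))   # ([2, 1, 1, 1], 2)
print(last_batch_bookkeeping(16, 8, 4))   # ([], 0)  -- dataset divides evenly, nothing to drop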