def prepare_data_loader()

in src/accelerate/utils/megatron_lm.py
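Prepares a dataloader for the Megatron-LM integration. A regular PyTorch dataloader is rebuilt with a batch size scaled up to cover all micro batches and sharded across data-parallel ranks; a Megatron-built dataset instead produces train/valid/test data iterators.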


def prepare_data_loader(accelerator, dataloader):
    accelerator.print("Preparing dataloader")
    args = get_args()
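    # Two code paths: a regular PyTorch dataloader is re-batched and sharded
    # across data-parallel ranks; a Megatron-built dataset instead yields its
    # own train/valid/test iterators.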
    if not args.megatron_dataset_flag:
        from ..data_loader import _PYTORCH_DATALOADER_KWARGS, prepare_data_loader

        # One fetch must cover a full round of micro batches on each DP rank.
        micro_batch_size = args.micro_batch_size * args.num_micro_batches
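        # Copy the original dataloader's construction kwargs, falling back to
        # the PyTorch defaults for anything it does not expose.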
        kwargs = {k: getattr(dataloader, k, _PYTORCH_DATALOADER_KWARGS[k]) for k in _PYTORCH_DATALOADER_KWARGS}
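        # batch_size is None when batching is delegated to a (batch_)sampler:
        # update that sampler's batch size in place; otherwise drop the unused
        # batch_sampler and set batch_size directly.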
        if kwargs["batch_size"] is None:
            if isinstance(kwargs["sampler"], torch.utils.data.BatchSampler):
                kwargs["sampler"].batch_size = micro_batch_size
            else:
                del kwargs["sampler"]
                del kwargs["shuffle"]
                del kwargs["batch_size"]
                kwargs["batch_sampler"].batch_size = micro_batch_size
        else:
            del kwargs["batch_sampler"]
            kwargs["batch_size"] = micro_batch_size

        # Rebuild the dataloader over the same dataset with the adjusted kwargs.
        dataloader = torch.utils.data.DataLoader(dataloader.dataset, **kwargs)
        # split_batches=False:
        # Megatron only needs each data-parallel group to fetch different data;
        # there is no need to split batches further within a DP group.
        return prepare_data_loader(
            dataloader,
            accelerator.device,
            num_processes=mpu.get_data_parallel_world_size(),
            process_index=mpu.get_data_parallel_rank(),
            split_batches=False,
            put_on_device=True,
            rng_types=accelerator.rng_types.copy(),
            dispatch_batches=accelerator.dispatch_batches,
        )
    else:
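        # Restore the per-split sample counters (e.g. when resuming from a
        # checkpoint); otherwise start all three splits from zero.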
        if args.consumed_samples is not None:
            (
                args.consumed_train_samples,
                args.consumed_valid_samples,
                args.consumed_test_samples,
            ) = args.consumed_samples
        else:
            args.consumed_train_samples, args.consumed_valid_samples, args.consumed_test_samples = 0, 0, 0
        # To stay compatible with data in the transformers format, temporarily
        # scale the micro batch size up so each fetch yields one large batch;
        # it is split back into individual micro batches downstream.
        args.micro_batch_size = args.micro_batch_size * args.num_micro_batches
        (
            train_data_iterator,
            valid_data_iterator,
            test_data_iterator,
        ) = dataloader.build_train_valid_test_data_iterators(accelerator)
        # Restore the real micro batch size once the iterators are built.
        args.micro_batch_size = args.micro_batch_size // args.num_micro_batches

        # Post-process each split's iterator so every parallel rank gets
        # something it can iterate over (a missing iterator is replaced
        # with a stand-in).
        train_data_iterator = _handle_megatron_data_iterator(
            accelerator=accelerator, data_iterator=train_data_iterator
        )
        valid_data_iterator = _handle_megatron_data_iterator(
            accelerator=accelerator, data_iterator=valid_data_iterator
        )
        test_data_iterator = _handle_megatron_data_iterator(
            accelerator=accelerator, data_iterator=test_data_iterator
        )

        return train_data_iterator, valid_data_iterator, test_data_iterator
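
For context, a minimal usage sketch (not part of this file): with the Megatron-LM integration active, dataloaders passed to Accelerator.prepare() are routed through this function, so it is rarely called directly. The plugin values, model, and data below are hypothetical placeholders; an actual run also needs Megatron-LM installed and a matching distributed launch (e.g. via accelerate launch).

import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator
from accelerate.utils import MegatronLMPlugin

# Hypothetical configuration; num_micro_batches drives the batch-size
# scaling seen in prepare_data_loader() above.
plugin = MegatronLMPlugin(tp_degree=1, pp_degree=1, num_micro_batches=4)
accelerator = Accelerator(megatron_lm_plugin=plugin)

# Placeholder model and data; a real run would use a supported transformers model.
model = torch.nn.Linear(16, 16)
optimizer = torch.optim.AdamW(model.parameters())
dataloader = DataLoader(TensorDataset(torch.randn(64, 16)), batch_size=8)

# prepare() detects the Megatron-LM plugin and routes the dataloader
# through prepare_data_loader() above.
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)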