in src/accelerate/utils/megatron_lm.py
def prepare_data_loader(accelerator, dataloader):
    accelerator.print("Preparing dataloader")
    args = get_args()
    if not args.megatron_dataset_flag:
        from ..data_loader import _PYTORCH_DATALOADER_KWARGS, prepare_data_loader

        # Megatron consumes one combined batch per step and splits it into
        # micro-batches internally, so the rebuilt loader uses
        # micro_batch_size * num_micro_batches as its batch size.
        micro_batch_size = args.micro_batch_size * args.num_micro_batches
        # Copy the user's DataLoader settings, falling back to the defaults.
        kwargs = {k: getattr(dataloader, k, _PYTORCH_DATALOADER_KWARGS[k]) for k in _PYTORCH_DATALOADER_KWARGS}
        if kwargs["batch_size"] is None:
            # The loader was built around a sampler rather than a batch size.
            if isinstance(kwargs["sampler"], torch.utils.data.BatchSampler):
                kwargs["sampler"].batch_size = micro_batch_size
            else:
                del kwargs["sampler"]
                del kwargs["shuffle"]
                del kwargs["batch_size"]
                kwargs["batch_sampler"].batch_size = micro_batch_size
        else:
            # Plain batch-size loader: drop the batch sampler and resize.
            del kwargs["batch_sampler"]
            kwargs["batch_size"] = micro_batch_size

        dataloader = torch.utils.data.DataLoader(dataloader.dataset, **kwargs)
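        # Illustrative example (hypothetical values, not from the source): a user
        # loader created with batch_size=2 under num_micro_batches=4 is rebuilt
        # here with batch_size=8, so each fetch covers one full set of
        # micro-batches for a single optimizer step.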
        # split_batches=False:
        # Megatron-LM only needs different data across data-parallel (dp) groups;
        # it does not need to split batches further within a dp group.
        return prepare_data_loader(
            dataloader,
            accelerator.device,
            num_processes=mpu.get_data_parallel_world_size(),
            process_index=mpu.get_data_parallel_rank(),
            split_batches=False,
            put_on_device=True,
            rng_types=accelerator.rng_types.copy(),
            dispatch_batches=accelerator.dispatch_batches,
        )
    else:
        # Restore dataset progress when resuming; otherwise start from zero.
        if args.consumed_samples is not None:
            (
                args.consumed_train_samples,
                args.consumed_valid_samples,
                args.consumed_test_samples,
            ) = args.consumed_samples
        else:
            args.consumed_train_samples, args.consumed_valid_samples, args.consumed_test_samples = 0, 0, 0
        args.micro_batch_size = args.micro_batch_size * args.num_micro_batches
        # For compatibility with data in the transformers format, the micro-batch
        # size (mbs) is scaled up first so that one large batch is fetched, which
        # is then split back into individual micro-batches.
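        # Worked example (hypothetical values): with micro_batch_size=4 and
        # num_micro_batches=2, the iterators below are built with batches of
        # 8 samples, and micro_batch_size is restored to 4 once they exist.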
        (
            train_data_iterator,
            valid_data_iterator,
            test_data_iterator,
        ) = dataloader.build_train_valid_test_data_iterators(accelerator)
        args.micro_batch_size = args.micro_batch_size // args.num_micro_batches
        train_data_iterator = _handle_megatron_data_iterator(
            accelerator=accelerator, data_iterator=train_data_iterator
        )
        valid_data_iterator = _handle_megatron_data_iterator(
            accelerator=accelerator, data_iterator=valid_data_iterator
        )
        test_data_iterator = _handle_megatron_data_iterator(accelerator=accelerator, data_iterator=test_data_iterator)

        return train_data_iterator, valid_data_iterator, test_data_iterator
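
# Usage sketch (illustrative, not part of this file): this helper is normally
# invoked for the user by `Accelerator.prepare(...)` when a Megatron-LM plugin
# is active; a direct call on a regular (non-Megatron) dataset would look
# roughly like the commented example below. `build_dataset()` is a hypothetical
# stand-in for user dataset code.
#
#     from torch.utils.data import DataLoader
#
#     dataloader = DataLoader(build_dataset(), batch_size=8, shuffle=True)
#     dataloader = prepare_data_loader(accelerator, dataloader)
#     for batch in dataloader:
#         ...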