in src/accelerate/data_loader.py [0:0]
def skip_first_batches(dataloader, num_batches=0):
    """
    Rebuild `dataloader` as a new dataloader that starts after its first `num_batches` batches. Should not be used
    if the original dataloader is a `StatefulDataLoader`.

    Args:
        dataloader:
            The dataloader to resume from. When the distributed type is XLA, it is expected to be a wrapper exposing
            `.device` and the underlying dataloader as `.dataloader`.
        num_batches (`int`, *optional*, defaults to 0):
            How many leading batches to skip.

    Returns:
        A new dataloader over the same dataset: a `DataLoaderDispatcher` or `DataLoaderShard` when the input is one
        of those, otherwise a `SkipDataLoader` (iterable datasets) or a plain `DataLoader`. Under XLA the result is
        wrapped in `MpDeviceLoaderWrapper` again.
    """
    state = PartialState()
    on_xla = state.distributed_type == DistributedType.XLA
    if on_xla:
        # Remember the target device and work on the wrapped dataloader; it is re-wrapped before returning.
        device = dataloader.device
        dataloader = dataloader.dataloader

    dataset = dataloader.dataset
    if isinstance(dataset, IterableDataset):
        # Iterable datasets have no batch sampler to wrap, so skipping is delegated to the dataloader itself.
        sampler_is_batch_sampler = False
        new_batch_sampler = None
    else:
        # A `BatchSampler` may have been supplied through the `sampler` slot; wrap whichever one yields batches.
        sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
        if sampler_is_batch_sampler:
            source_batch_sampler = dataloader.sampler
        else:
            source_batch_sampler = dataloader.batch_sampler
        new_batch_sampler = SkipBatchSampler(source_batch_sampler, skip_batches=num_batches)
    manual_skip = new_batch_sampler is None

    # These options are all covered by the new batch sampler, so they are not copied over by default.
    handled_by_sampler = ("batch_size", "shuffle", "sampler", "batch_sampler", "drop_last")
    kwargs = {}
    for option in _PYTORCH_DATALOADER_KWARGS:
        if option in handled_by_sampler:
            continue
        # Fall back to the default value when the dataloader does not carry the attribute.
        kwargs[option] = getattr(dataloader, option, _PYTORCH_DATALOADER_KWARGS[option])

    if manual_skip:
        # Without a batch sampler, batching has to be configured explicitly.
        kwargs["drop_last"] = dataloader.drop_last
        kwargs["batch_size"] = dataloader.batch_size

    if isinstance(dataloader, DataLoaderDispatcher):
        if manual_skip:
            # The dataloader has to skip the batches on its own.
            kwargs["skip_batches"] = num_batches
        rebuilt = DataLoaderDispatcher(
            dataset,
            split_batches=dataloader.split_batches,
            batch_sampler=new_batch_sampler,
            _drop_last=dataloader._drop_last,
            **kwargs,
        )
    elif isinstance(dataloader, DataLoaderShard):
        if manual_skip:
            # The dataloader has to skip the batches on its own.
            kwargs["skip_batches"] = num_batches
        elif sampler_is_batch_sampler:
            # Keep the wrapper in the same `sampler` slot the original batch sampler occupied.
            kwargs["sampler"] = new_batch_sampler
            kwargs["batch_size"] = dataloader.batch_size
        else:
            kwargs["batch_sampler"] = new_batch_sampler
        rebuilt = DataLoaderShard(
            dataset,
            device=dataloader.device,
            rng_types=dataloader.rng_types,
            synchronized_generator=dataloader.synchronized_generator,
            **kwargs,
        )
    elif manual_skip:
        # Plain dataloader over an iterable dataset: the dataloader has to skip the batches on its own.
        rebuilt = SkipDataLoader(dataset, skip_batches=num_batches, **kwargs)
    else:
        rebuilt = DataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs)

    if on_xla:
        return MpDeviceLoaderWrapper(rebuilt, device)
    return rebuilt