in src/accelerate/data_loader.py [0:0]
def __iter__(self):
    """Yield this process's share of every global batch read from ``self.dataset``.

    Elements are accumulated until a global batch is complete; the elements
    whose positions fall in this process's index window are then yielded.
    The global batch holds ``self.batch_size`` elements when
    ``self.split_batches`` is set, and ``self.batch_size * self.num_processes``
    elements otherwise.

    When ``self.drop_last`` is falsy and a partial batch remains at the end,
    it is padded by repeatedly appending the first complete batch (or a copy
    of the partial batch itself if no batch was ever completed) until it
    reaches the global batch size, and the window is yielded once more.
    """
    source = self.dataset

    # Datasets without `set_epoch` that carry a torch generator are reseeded
    # with the current epoch number.
    if not hasattr(source, "set_epoch") and hasattr(source, "generator"):
        rng = source.generator
        if isinstance(rng, torch.Generator):
            rng.manual_seed(self.epoch)

    if self.split_batches:
        global_size = self.batch_size
        shard_size = self.batch_size // self.num_processes
    else:
        shard_size = self.batch_size
        global_size = self.batch_size * self.num_processes

    # Positions inside a global batch that belong to this process.
    offset = self.process_index * shard_size
    shard_indices = range(offset, offset + shard_size)

    filler = None  # first batch seen, reused to pad the final partial batch
    pending = []
    for sample in source:
        pending.append(sample)
        # Keep accumulating until the global batch is exactly full.
        if len(pending) != global_size:
            continue
        for idx in shard_indices:
            yield pending[idx]
        if filler is None:
            filler = list(pending)
        pending = []

    # Nothing more to emit when leftovers are dropped or there are none.
    if self.drop_last or not pending:
        return

    if filler is None:
        filler = list(pending)
    # Wrap around to the beginning until the last batch is large enough.
    while len(pending) < global_size:
        pending.extend(filler)
    for idx in shard_indices:
        yield pending[idx]