in src/accelerate/data_loader.py [0:0]
def _iter_with_no_split(self):
    """Yield the batches of `self.batch_sampler` that belong to this process.

    Batches are dealt round-robin: the batch at position `idx` belongs to the process whose
    `self.process_index` equals `idx % self.num_processes`. Inside the main loop a batch is only released
    once the last batch of its round (position `self.num_processes - 1` within the round) has been seen and
    is full, or `self.batch_size` is None.

    How a trailing, incomplete round is handled depends on the flags:

    - `self.drop_last` is true: the round is dropped.
    - `self.drop_last` is false and `self.even_batches` is false: the process yields the batch it saved
      from that round, if it has one, whatever its size.
    - `self.drop_last` is false and `self.even_batches` is true: the round is completed with samples taken
      from the first `self.num_processes` batches (repeated as needed for very small datasets), so the
      total number of batches yielded across processes is a multiple of `self.num_processes` and each of
      those batches has `self.batch_size` samples.

    The batches produced by `self.batch_sampler` are never modified in place: a padded batch is always a
    new list, so a sampler that keeps references to its batches yields the same data on every pass.

    Yields:
        The batches assigned to `self.process_index`, in order.
    """
    initial_data = []
    batch_to_yield = []
    for idx, batch in enumerate(self.batch_sampler):
        # We gather the initial indices in case we need to circle back at the end.
        if not self.drop_last and idx < self.num_processes:
            initial_data += batch
        # We identify the batch to yield but wait until we are sure every process gets a full batch before
        # actually yielding it.
        if idx % self.num_processes == self.process_index:
            batch_to_yield = batch
        if idx % self.num_processes == self.num_processes - 1 and (
            self.batch_size is None or len(batch) == self.batch_size
        ):
            yield batch_to_yield
            batch_to_yield = []

    # If drop_last is True, iteration is over. `initial_data` is also empty when the sampler produced no
    # samples, in which case `idx` and `batch` were never bound and must not be touched below.
    if self.drop_last or len(initial_data) == 0:
        return

    if not self.even_batches:
        # Uneven mode: hand out whatever this process saved from the trailing round, without padding.
        if len(batch_to_yield) > 0:
            yield batch_to_yield
        return

    # ... we yield the complete batch we had saved before if it has the proper length
    if len(batch_to_yield) == self.batch_size:
        yield batch_to_yield

    # For degenerate cases where the dataset has less than num_process * batch_size samples
    while len(initial_data) < self.num_processes * self.batch_size:
        initial_data += initial_data

    # If the last batch seen was of the proper size, it has been yielded by its process so we move to the next
    if len(batch) == self.batch_size:
        batch = []
        idx += 1

    # Make sure we yield a multiple of self.num_processes batches
    cycle_index = 0
    while idx % self.num_processes != 0 or len(batch) > 0:
        end_index = cycle_index + self.batch_size - len(batch)
        # Build a new list instead of `batch += ...`: on the first pass `batch` is still the object that
        # came out of the sampler, and extending it in place would corrupt the sampler's own data (and
        # fail outright for tuple batches).
        batch = list(batch) + initial_data[cycle_index:end_index]
        if idx % self.num_processes == self.process_index:
            yield batch
        cycle_index = end_index
        batch = []
        idx += 1