in src/accelerate/data_loader.py
def __iter__(self):
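    """
    Iterate the wrapped dataloader with batch dispatching: only process 0 actually
    fetches batches; each batch is broadcast to all processes and then sliced so
    that every process receives its own shard.
    """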
    self.begin()
    self.set_epoch(self.iteration)
    main_iterator = None
    if is_torch_version(">=", "2.0.1"):
        # NOTE: PyTorch >= 2.0.1 adds forward compatibility for DataPipes to the DataLoader, which
        # broadcasts a shared seed to all dist processes, so the iterator must be created on every
        # dist process. We still only iterate through the DataLoader on process 0.
        main_iterator = self.base_dataloader.__iter__()
    elif self.state.process_index == 0:
        main_iterator = self.base_dataloader.__iter__()
    stop_iteration = False
    self._stop_iteration = False
    first_batch = None
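    # Prefetch one batch ahead of the one being yielded, so the end of the dataloader can be
    # detected before the current batch is handed out.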
    next_batch, next_batch_info = self._fetch_batches(main_iterator)
    batch_index = 0
    while not stop_iteration:
        batch, batch_info = next_batch, next_batch_info
        if self.state.process_index != 0:
            # Initialize tensors on processes other than process 0.
            batch = initialize_tensors(batch_info[0])
        batch = send_to_device(batch, self.state.device, non_blocking=self._non_blocking)
        # Broadcast the batch before splitting it.
        batch = broadcast(batch, from_process=0)
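        # After the broadcast, every process holds an identical copy of the full
        # (un-split) batch that process 0 fetched.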
        if not self._drop_last and first_batch is None:
            # Keep at least `num_processes` elements of the first batch so we can complete the
            # last batch if it turns out to be short.
            first_batch = self.slice_fn(
                batch,
                slice(0, self.state.num_processes),
                process_index=self.state.process_index,
                num_processes=self.state.num_processes,
            )
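        # For example, with num_processes=4, `first_batch` keeps elements 0-3 of the very first
        # batch; they are recycled below to pad an incomplete final batch.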
        if batch is None:
            raise ValueError(
                f"Batch does not contain any data (`{batch}`). The end of the iterable data was reached before the expected stop iteration."
            )
        observed_batch_size = find_batch_size(batch)
        batch_size = observed_batch_size // self.state.num_processes
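        # e.g. observed_batch_size=32 on 4 processes -> each process will receive 8 elements.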
        stop_iteration = self._stop_iteration
        if not stop_iteration:
            # We may still be at the end of the dataloader without knowing it yet: this happens
            # when the number of batches is an exact multiple of the number of processes, so
            # peek at the next batch to find out.
            next_batch, next_batch_info = self._fetch_batches(main_iterator)
            # next_batch_info[0] is None when there are no more batches; otherwise we still need to process them.
            if self._stop_iteration and next_batch_info[0] is None:
                stop_iteration = True
        if not self._drop_last and stop_iteration and observed_batch_size % self.state.num_processes != 0:
            # If the last batch is not complete, pad it with elements from the first batch.
            batch = concatenate([batch, first_batch], dim=0)
            # The per-process batch size computed above is off by 1 for this padded batch, so fix it.
            batch_size += 1
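        # Worked example: a final batch of 6 elements on 4 processes gives batch_size = 6 // 4 + 1 = 2;
        # after concatenating the 4 saved elements, the padded batch has 10 elements, each process
        # slices out 2, and the 2 leftover elements are simply never selected by any slice.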
        data_slice = slice(self.state.process_index * batch_size, (self.state.process_index + 1) * batch_size)
        batch = self.slice_fn(
            batch,
            data_slice,
            process_index=self.state.process_index,
            num_processes=self.state.num_processes,
        )
        if stop_iteration:
            self.end_of_dataloader = True
            self._update_state_dict()
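            # Record the observed size of the last batch (before any padding); downstream
            # utilities such as `gather_for_metrics` use it to drop duplicated samples.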
            self.remainder = observed_batch_size
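        # `skip_batches` supports resuming mid-epoch (e.g. via `skip_first_batches`): already-seen
        # batches are fetched and split but not yielded.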
        if batch_index >= self.skip_batches:
            yield batch
        batch_index += 1
    self.iteration += 1
    self.end()
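
# Usage sketch (an assumption for illustration, not part of data_loader.py): this
# `__iter__` runs when iterating a dataloader that Accelerate wraps with batch
# dispatching enabled. In recent Accelerate versions that looks roughly like:
#
#     from accelerate import Accelerator
#     from accelerate.utils import DataLoaderConfiguration
#
#     accelerator = Accelerator(dataloader_config=DataLoaderConfiguration(dispatch_batches=True))
#     dataloader = accelerator.prepare(dataloader)  # wrapped as a DataLoaderDispatcher
#     for batch in dataloader:
#         ...  # process 0 fetched the batch; every rank got its own slice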