in src/accelerate/data_loader.py [0:0]
def _fetch_batches(self, iterator):
    """Fetch the next batch on process 0 and broadcast its description to all processes.

    Only process 0 (``self.state.process_index == 0``) reads from ``iterator``;
    ``self._update_state_dict()`` is called before every ``next(iterator)``.

    * ``split_batches=True``: a single batch is taken from ``iterator``.
    * ``split_batches=False``: ``self.state.num_processes`` batches are taken and
      concatenated along dim 0. When ``self.submesh_tp`` is set, a single batch is
      taken instead and repeated ``num_processes`` times before concatenation.

    ``batch_info`` is a two-item list ``[data_structure, stop_iteration]`` that is
    passed to ``broadcast_object_list``. When the iterator is exhausted and neither
    ``split_batches`` nor ``_drop_last`` is set, the batches already collected on
    process 0 (if any) are concatenated as a remainder and ``batch_info`` is
    broadcast a second time.

    Args:
        iterator: The iterator process 0 pulls batches from with ``next``.

    Returns:
        A ``(batch, batch_info)`` tuple. ``batch`` is only ever assigned on
        process 0, so it is ``None`` on every other process (and on process 0
        when nothing could be fetched).

    Raises:
        RuntimeError: If concatenating the fetched batches fails on process 0.
    """
    batches, batch = None, None
    # On process 0, we gather the batch to dispatch.
    if self.state.process_index == 0:
        # Procedure to support TP only is simpler since we want to dispatch the same batch of samples
        # across all ranks; this removes the complexity of handling multiple TP rank groups when a
        # TP + DP combination is involved.
        try:
            if self.split_batches:
                # One batch of the main iterator is dispatched and split.
                # For the TP case avoid using split_batches, since it would mean that the dataloader
                # should be producing duplicates of batches.
                if self.submesh_tp:
                    logger.warning(
                        "Use of split_batches for TP would need the dataloader to produce duplicate batches, "
                        "otherwise, use dispatch_batches=True instead."
                    )
                self._update_state_dict()
                batch = next(iterator)
            else:
                # num_processes batches of the main iterator are concatenated then dispatched and split.
                # We add the batches one by one so we have the remainder available when drop_last=False.
                batches = []
                if self.submesh_tp:
                    # With TP, extract a single batch and then replicate it. `batches` is only
                    # replaced after `next` succeeded, so it stays empty on StopIteration.
                    self._update_state_dict()
                    batch = next(iterator)
                    batches = [batch] * self.state.num_processes
                else:
                    for _ in range(self.state.num_processes):
                        self._update_state_dict()
                        batches.append(next(iterator))
                try:
                    batch = concatenate(batches, dim=0)
                except RuntimeError as e:
                    raise RuntimeError(
                        "You can't use batches of different size with `dispatch_batches=True` or when using an "
                        "`IterableDataset`. Either pass `dispatch_batches=False` and have each process fetch its "
                        "own batch or pass `split_batches=True`. By doing so, the main process will fetch a full "
                        "batch and slice it into `num_processes` batches for each process."
                    ) from e
            # In both cases, we need to get the structure of the batch that we will broadcast on other
            # processes to initialize the tensors with the right shape.
            # data_structure, stop_iteration
            batch_info = [get_data_structure(batch), False]
        except StopIteration:
            batch_info = [None, True]
    else:
        batch_info = [None, self._stop_iteration]

    # This is inplace, so after this instruction, every process has the same `batch_info` as process 0.
    broadcast_object_list(batch_info)
    self._stop_iteration = batch_info[1]

    # If drop_last is False and split_batches is False, we may have a remainder to take care of.
    if self._stop_iteration and not self.split_batches and not self._drop_last:
        # `batches` is a list on process 0 whenever split_batches is False; the process check
        # short-circuits first so the `None` held by other processes is never inspected.
        if self.state.process_index == 0 and batches:
            batch = concatenate(batches, dim=0)
            batch_info = [get_data_structure(batch), False]
        else:
            batch_info = [None, True]
        # Second broadcast: every process must reach this call so they stay in step.
        broadcast_object_list(batch_info)
    return batch, batch_info