def _fetch_batches()

in src/accelerate/data_loader.py


    def _fetch_batches(self, iterator):
        batches, batch = None, None
        # On process 0, we gather the batch to dispatch.
        if self.state.process_index == 0:
            # The TP-only case is simpler: we dispatch the same batch of samples
            # to all ranks, which avoids handling multiple TP rank groups when
            # TP is combined with DP.

            try:
                # In the TP case, avoid split_batches: it would require the
                # dataloader to emit duplicate batches.
                if self.split_batches:
                    # One batch of the main iterator is dispatched and split.
                    if self.submesh_tp:
                        logger.warning(
                            "Use of split_batches for TP would need the dataloader to produce duplicate batches,"
                            "otherwise, use dispatch_batches=True instead."
                        )
                    self._update_state_dict()
                    batch = next(iterator)
                else:
                    # num_processes batches of the main iterator are concatenated then dispatched and split.
                    # We add the batches one by one so we have the remainder available when drop_last=False.
                    batches = []
                    if self.submesh_tp:
                        # When using TP, fetch a single batch and replicate it for every process.
                        self._update_state_dict()
                        batch = next(iterator)
                        batches = [batch] * self.state.num_processes
                    else:
                        for _ in range(self.state.num_processes):
                            self._update_state_dict()
                            batches.append(next(iterator))
                    try:
                        batch = concatenate(batches, dim=0)
                    except RuntimeError as e:
                        raise RuntimeError(
                            "You can't use batches of different size with `dispatch_batches=True` or when using an `IterableDataset`."
                            "either pass `dispatch_batches=False` and have each process fetch its own batch "
                            " or pass `split_batches=True`. By doing so, the main process will fetch a full batch and "
                            "slice it into `num_processes` batches for each process."
                        ) from e
                # In both cases, we need the structure of the batch, which we will broadcast
                # to the other processes so they can initialize tensors with the right shape.
                # data_structure, stop_iteration
                batch_info = [get_data_structure(batch), False]
            except StopIteration:
                batch_info = [None, True]
        else:
            batch_info = [None, self._stop_iteration]
        # This broadcast is in place, so after this instruction, every process has the same `batch_info` as process 0.
        broadcast_object_list(batch_info)
        self._stop_iteration = batch_info[1]
        if self._stop_iteration:
            # If drop_last is False and split_batches is False, we may have a remainder to take care of.
            if not self.split_batches and not self._drop_last:
                if self.state.process_index == 0 and len(batches) > 0:
                    batch = concatenate(batches, dim=0)
                    batch_info = [get_data_structure(batch), False]
                else:
                    batch_info = [None, True]
                broadcast_object_list(batch_info)
        return batch, batch_info
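
The method's core pattern, in which process 0 gathers and concatenates the per-process batches and then broadcasts only the batch metadata so every rank agrees on what is coming, can be sketched with plain torch.distributed. This is a minimal sketch of the control flow, not Accelerate's implementation: the helper name fetch_and_announce is hypothetical, shape/dtype stands in for get_data_structure, and a torchrun launch is assumed.

    import torch
    import torch.distributed as dist


    def fetch_and_announce(iterator, num_processes):
        # Rank 0 pulls `num_processes` batches and concatenates them along dim 0;
        # the other ranks only need the metadata before any tensor broadcast.
        batch = None
        if dist.get_rank() == 0:
            try:
                batches = [next(iterator) for _ in range(num_processes)]
                batch = torch.cat(batches, dim=0)
                # Announce shape/dtype instead of the data itself (a stand-in for
                # Accelerate's get_data_structure): the other ranks only need enough
                # information to allocate receive buffers later.
                batch_info = [(tuple(batch.shape), batch.dtype), False]
            except StopIteration:
                batch_info = [None, True]
        else:
            batch_info = [None, False]

        # broadcast_object_list works in place: after this call every rank holds
        # the same batch_info as rank 0.
        dist.broadcast_object_list(batch_info, src=0)
        return batch, batch_info


    if __name__ == "__main__":
        # Assumes launch via: torchrun --nproc_per_node=2 this_script.py
        dist.init_process_group("gloo")
        micro_batches = iter(torch.randn(8, 4).split(2))  # four micro-batches of 2 samples
        batch, (structure, stop) = fetch_and_announce(micro_batches, dist.get_world_size())
        print(f"rank {dist.get_rank()}: structure={structure}, stop={stop}")
        dist.destroy_process_group()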
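For the TP branch, replacing the per-process fetch with [batch] * num_processes means that, after concatenation and the usual per-rank slicing, every shard equals the original batch, so all TP ranks see identical samples. A small single-process illustration (the slicing below is a hypothetical stand-in for what the dispatching dataloader does later, not Accelerate code):

    import torch

    num_processes = 4
    batch = torch.arange(12).reshape(3, 4)                     # one batch of 3 samples

    # TP path: replicate the single batch once per process, then concatenate,
    # mirroring `batches = [batch] * self.state.num_processes` above.
    replicated = torch.cat([batch] * num_processes, dim=0)     # shape (12, 4)

    # Per-rank slicing, as the dispatching dataloader would do after the broadcast
    # (shown here only to demonstrate the effect of replication).
    shard_size = replicated.shape[0] // num_processes
    shards = [replicated[i * shard_size:(i + 1) * shard_size] for i in range(num_processes)]

    # Every shard equals the original batch, so all TP ranks see the same data.
    assert all(torch.equal(shard, batch) for shard in shards)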
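The tail of the method handles the remainder: if StopIteration is hit before num_processes batches have been collected and drop_last is False, the batches gathered so far are still concatenated and dispatched rather than dropped. A single-process sketch of that logic, using a hypothetical helper gather_with_remainder:

    import torch


    def gather_with_remainder(iterator, num_processes, drop_last):
        # Collect up to `num_processes` batches; on exhaustion, keep what we have
        # unless drop_last is set or nothing was collected.
        batches = []
        try:
            for _ in range(num_processes):
                batches.append(next(iterator))
        except StopIteration:
            if drop_last or not batches:
                return None
        return torch.cat(batches, dim=0)


    micro_batches = iter(torch.randn(5, 2).split(2))  # batches of 2, 2 and 1 samples
    remainder = gather_with_remainder(micro_batches, num_processes=4, drop_last=False)
    print(remainder.shape)  # torch.Size([5, 2]) -- the last, smaller batch is kept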