def split_between_processes()

in src/accelerate/state.py [0:0]


    def split_between_processes(self, inputs: list | tuple | dict | torch.Tensor, apply_padding: bool = False):
        """
        Splits `input` between `self.num_processes` quickly and can be then used on that process. Useful when doing
        distributed inference, such as with different prompts.

        Note that when using a `dict`, all keys need to have the same number of elements.

        Args:
            inputs (`list`, `tuple`, `torch.Tensor`, `dict` of `list`/`tuple`/`torch.Tensor`, or `datasets.Dataset`):
                The input to split between processes.
            apply_padding (`bool`, `optional`, defaults to `False`):
                Whether to apply padding by repeating the last element of the input so that all processes have the same
                number of elements. Useful when trying to perform actions such as `gather()` on the outputs or passing
                in less inputs than there are processes. If so, just remember to drop the padded elements afterwards.


        Example:

        ```python
        # Assume there are two processes
        from accelerate import PartialState

        state = PartialState()
        with state.split_between_processes(["A", "B", "C"]) as inputs:
            print(inputs)
        # Process 0
        ["A", "B"]
        # Process 1
        ["C"]

        with state.split_between_processes(["A", "B", "C"], apply_padding=True) as inputs:
            print(inputs)
        # Process 0
        ["A", "B"]
        # Process 1
        ["C", "C"]
        ```
        """
        if self.num_processes == 1:
            yield inputs
            return
        length = len(inputs)
        # Nested dictionary of any types
        if isinstance(inputs, dict):
            length = len(inputs[list(inputs.keys())[0]])
            if not all(len(v) == length for v in inputs.values()):
                raise ValueError("All values in the dictionary must have the same length")
        num_samples_per_process, num_extras = divmod(length, self.num_processes)
        start_index = self.process_index * num_samples_per_process + min(self.process_index, num_extras)
        end_index = start_index + num_samples_per_process + (1 if self.process_index < num_extras else 0)

        def _split_values(inputs, start_index, end_index):
            if isinstance(inputs, (list, tuple, torch.Tensor)):
                if start_index >= len(inputs):
                    result = inputs[-1:]
                else:
                    result = inputs[start_index:end_index]
                if apply_padding:
                    if isinstance(result, torch.Tensor):
                        from accelerate.utils import pad_across_processes, send_to_device

                        # The tensor needs to be on the device before we can pad it
                        tensorized_result = send_to_device(result, self.device)
                        result = pad_across_processes(tensorized_result, pad_index=inputs[-1])
                    else:
                        result += [result[-1]] * (num_samples_per_process + (1 if num_extras > 0 else 0) - len(result))
                return result
            elif isinstance(inputs, dict):
                for key in inputs.keys():
                    inputs[key] = _split_values(inputs[key], start_index, end_index)
                return inputs
            else:
                if is_datasets_available():
                    from datasets import Dataset

                    if isinstance(inputs, Dataset):
                        if start_index >= len(inputs):
                            start_index = len(inputs) - 1
                        if end_index > len(inputs):
                            end_index = len(inputs)
                        result_idcs = list(range(start_index, end_index))
                        if apply_padding:
                            result_idcs += [end_index - 1] * (
                                num_samples_per_process + (1 if num_extras > 0 else 0) - len(result_idcs)
                            )
                        return inputs.select(result_idcs)
                return inputs

        yield _split_values(inputs, start_index, end_index)