in src/accelerate/state.py [0:0]
def split_between_processes(self, inputs: list | tuple | dict | torch.Tensor, apply_padding: bool = False):
"""
Splits `input` between `self.num_processes` quickly and can be then used on that process. Useful when doing
distributed inference, such as with different prompts.
Note that when using a `dict`, all keys need to have the same number of elements.
Args:
inputs (`list`, `tuple`, `torch.Tensor`, `dict` of `list`/`tuple`/`torch.Tensor`, or `datasets.Dataset`):
The input to split between processes.
apply_padding (`bool`, `optional`, defaults to `False`):
Whether to apply padding by repeating the last element of the input so that all processes have the same
number of elements. Useful when trying to perform actions such as `gather()` on the outputs or passing
in less inputs than there are processes. If so, just remember to drop the padded elements afterwards.
Example:
```python
# Assume there are two processes
from accelerate import PartialState
state = PartialState()
with state.split_between_processes(["A", "B", "C"]) as inputs:
print(inputs)
# Process 0
["A", "B"]
# Process 1
["C"]
with state.split_between_processes(["A", "B", "C"], apply_padding=True) as inputs:
print(inputs)
# Process 0
["A", "B"]
# Process 1
["C", "C"]
```
"""
if self.num_processes == 1:
yield inputs
return
length = len(inputs)
# Nested dictionary of any types
if isinstance(inputs, dict):
length = len(inputs[list(inputs.keys())[0]])
if not all(len(v) == length for v in inputs.values()):
raise ValueError("All values in the dictionary must have the same length")
num_samples_per_process, num_extras = divmod(length, self.num_processes)
start_index = self.process_index * num_samples_per_process + min(self.process_index, num_extras)
end_index = start_index + num_samples_per_process + (1 if self.process_index < num_extras else 0)
def _split_values(inputs, start_index, end_index):
if isinstance(inputs, (list, tuple, torch.Tensor)):
if start_index >= len(inputs):
result = inputs[-1:]
else:
result = inputs[start_index:end_index]
if apply_padding:
if isinstance(result, torch.Tensor):
from accelerate.utils import pad_across_processes, send_to_device
# The tensor needs to be on the device before we can pad it
tensorized_result = send_to_device(result, self.device)
result = pad_across_processes(tensorized_result, pad_index=inputs[-1])
else:
result += [result[-1]] * (num_samples_per_process + (1 if num_extras > 0 else 0) - len(result))
return result
elif isinstance(inputs, dict):
for key in inputs.keys():
inputs[key] = _split_values(inputs[key], start_index, end_index)
return inputs
else:
if is_datasets_available():
from datasets import Dataset
if isinstance(inputs, Dataset):
if start_index >= len(inputs):
start_index = len(inputs) - 1
if end_index > len(inputs):
end_index = len(inputs)
result_idcs = list(range(start_index, end_index))
if apply_padding:
result_idcs += [end_index - 1] * (
num_samples_per_process + (1 if num_extras > 0 else 0) - len(result_idcs)
)
return inputs.select(result_idcs)
return inputs
yield _split_values(inputs, start_index, end_index)