in dualpipe/utils.py [0:0]
def scatter(inputs, chunks, dim):
    """Split one or more tensors into micro-batches along ``dim``.

    Each item of ``inputs`` is split with ``chunk_tensor(x, chunks, dim)``.
    The pieces are then regrouped with ``zip`` so that the i-th micro-batch
    is a tuple holding the i-th piece of every input, in input order.

    Args:
        inputs: A single ``torch.Tensor``, or a tuple/list whose items are
            tensors or ``None``. A bare tensor is treated as a one-element
            tuple.
        chunks: Number of micro-batches requested; forwarded to
            ``chunk_tensor``.
        dim: Dimension along which each tensor is split; forwarded to
            ``chunk_tensor``.

    Returns:
        A list of tuples, one per micro-batch. If regrouping yields nothing
        (e.g. ``inputs`` is empty), the result is ``chunks`` empty tuples.

    Raises:
        TypeError: If ``inputs`` is not a tensor, tuple or list, or if any
            item is neither ``None`` nor a tensor.
    """
    # Explicit checks instead of ``assert`` so validation survives ``python -O``.
    if isinstance(inputs, torch.Tensor):
        inputs = (inputs,)
    elif not isinstance(inputs, (tuple, list)):
        raise TypeError(
            f"scatter expects a Tensor, tuple or list, got {type(inputs).__name__}"
        )
    for item in inputs:
        if item is not None and not isinstance(item, torch.Tensor):
            raise TypeError(
                f"scatter inputs must be Tensors or None, got {type(item).__name__}"
            )

    # NOTE(review): assumes chunk_tensor accepts None and returns exactly
    # `chunks` pieces per input. zip() silently truncates to the shortest
    # piece list otherwise.
    per_input_pieces = [chunk_tensor(item, chunks, dim) for item in inputs]
    microbatches = list(zip(*per_input_pieces))
    if not microbatches:
        # zip() over no inputs is empty; still hand back one (empty) tuple per
        # micro-batch so callers can index by micro-batch id.
        microbatches = [() for _ in range(chunks)]
    return microbatches