in optimum/tpu/xla_model_parallel.py [0:0]
def my_gather(input_: torch.Tensor, groups, world_size, rank) -> torch.Tensor:
    """Gather tensors and concatenate along the last dimension."""
    # Bypass the collective entirely when there is only one device.
    if world_size == 1:
        return input_
    if USE_CUDA:
        last_dim = input_.dim() - 1
        # Use all_reduce to emulate all_gather, because torch.ops.c10d_functional.all_gather_into_tensor
        # is buggy in 16-bit precision.
        size = input_.size(last_dim)
        padding = [0] * (2 * input_.dim())
        ordinal = rank
        # Zero-pad this rank's shard into its own slot along the last dimension,
        # so a sum-reduce across ranks reconstructs the concatenation.
        left, right = ordinal, world_size - 1 - ordinal
        idx = input_.dim() - 1 - last_dim
        padding[2 * idx] = left * size
        padding[2 * idx + 1] = right * size
        output = torch.ops.c10d_functional.all_reduce(F.pad(input_, padding), "sum", TAG, RANKSET, GROUP_SIZE)
    else:
        output = xm.all_gather(input_, dim=-1, groups=groups)
    return output
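
The CUDA branch relies on a padding trick: each rank places its shard into its own slot of a zero-padded tensor, so summing the padded tensors across ranks yields the same result as an all_gather along the last dimension. Below is a minimal single-process sketch of that equivalence; the helper name gather_via_padded_all_reduce, the shard shapes, and the plain tensor sum standing in for the distributed all_reduce are illustrative assumptions, not part of the file.

    import torch
    import torch.nn.functional as F

    def gather_via_padded_all_reduce(shards):
        # shards: one tensor per rank, identical shapes, sharded along the last dim.
        world_size = len(shards)
        size = shards[0].size(-1)
        padded = []
        for rank, shard in enumerate(shards):
            left, right = rank * size, (world_size - 1 - rank) * size
            # F.pad pads the last dimension with (left, right) zeros,
            # mirroring the padding list built in my_gather.
            padded.append(F.pad(shard, (left, right)))
        # The sum plays the role of the "sum" all_reduce across ranks.
        return torch.stack(padded).sum(dim=0)

    shards = [torch.full((2, 3), float(r)) for r in range(4)]
    out = gather_via_padded_all_reduce(shards)
    assert torch.equal(out, torch.cat(shards, dim=-1))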