def my_gather()

in optimum/tpu/xla_model_parallel.py [0:0]


def my_gather(input_: torch.Tensor, groups, world_size, rank) -> torch.Tensor:
    """Gather tensors and concatinate along the last dimension."""
    # Bypass the function if we are using only 1 device.
    if world_size == 1:
        return input_

    if USE_CUDA:
        last_dim = input_.dim() - 1

        # Use all_reduce to emulate all_gather, since torch.ops.c10d_functional.all_gather_into_tensor
        # is buggy with 16-bit dtypes. Each rank zero-pads its shard into its own slot along the last
        # dimension, so summing the padded tensors across ranks reconstructs the concatenated result.
        size = input_.size(last_dim)
        padding = [0] * (2 * input_.dim())
        ordinal = rank
        left, right = ordinal, world_size - 1 - ordinal
        idx = input_.dim() - 1 - last_dim
        padding[2 * idx] = left * size
        padding[2 * idx + 1] = right * size
        output = torch.ops.c10d_functional.all_reduce(F.pad(input_, padding), "sum", TAG, RANKSET, GROUP_SIZE)
    else:
        output = xm.all_gather(input_, dim=-1, groups=groups)

    return output
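
To see why the padded all_reduce reproduces an all_gather, here is a minimal single-process sketch (plain torch, no distributed backend; the world size and shard size are made-up illustration values). It sums the zero-padded shards directly, which is exactly what all_reduce does across ranks:

import torch
import torch.nn.functional as F

# Hypothetical illustration: 4 "ranks", each holding a shard of width 3
# along the last dimension.
world_size, shard_size = 4, 3
shards = [torch.randn(2, shard_size) for _ in range(world_size)]

padded = []
for rank, shard in enumerate(shards):
    left = rank * shard_size                      # zeros before this rank's slot
    right = (world_size - 1 - rank) * shard_size  # zeros after this rank's slot
    # F.pad with (left, right) pads the last dimension, placing the shard
    # in its own slot of the full gathered width.
    padded.append(F.pad(shard, (left, right)))

# Summing the padded tensors (the cross-rank sum all_reduce performs)
# yields the concatenation along the last dimension.
summed = torch.stack(padded).sum(dim=0)
assert torch.equal(summed, torch.cat(shards, dim=-1))

In the real function each rank pads only its own shard, and torch.ops.c10d_functional.all_reduce provides the sum across devices.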