def my_reduce()

in optimum/tpu/xla_model_parallel.py [0:0]


def my_reduce(input_: torch.Tensor, groups, world_size, rank) -> torch.Tensor:
    """All-reduce the the input tensor across model parallel group."""
    # Bypass the collective entirely when there is only one device.
    if world_size == 1:
        return input_

    # All-reduce. On CUDA, use the c10d functional collective (TAG, RANKSET and
    # GROUP_SIZE are module-level constants defined elsewhere in this file);
    # on TPU, use torch_xla's xm.all_reduce over the given replica groups.
    if USE_CUDA:
        input_ = torch.ops.c10d_functional.all_reduce(input_, "sum", TAG, RANKSET, GROUP_SIZE)
    else:
        input_ = xm.all_reduce(xm.REDUCE_SUM, input_, groups=groups)

    return input_
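
For context, a minimal, hypothetical usage sketch (not part of xla_model_parallel.py): it assumes a single model-parallel group spanning all devices, builds the groups/world_size/rank arguments from torch_xla runtime helpers, and all-reduces a per-device partial result so every rank ends up with the same summed tensor.

import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

world_size = xr.world_size()          # number of participating devices
rank = xr.global_ordinal()            # this device's index
groups = [list(range(world_size))]    # one replica group covering all devices

# Each device holds a partial result; after my_reduce every device holds the sum.
partial = torch.full((4,), float(rank), device=xm.xla_device())
reduced = my_reduce(partial, groups, world_size, rank)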