in optimum/tpu/xla_model_parallel.py [0:0]
def my_reduce(input_: torch.Tensor, groups, world_size, rank) -> torch.Tensor:
"""All-reduce the the input tensor across model parallel group."""
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
# All-reduce.
if USE_CUDA:
input_ = torch.ops.c10d_functional.all_reduce(input_, "sum", TAG, RANKSET, GROUP_SIZE)
else:
input_ = xm.all_reduce(xm.REDUCE_SUM, input_, groups=groups)
return input_
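
For context, a minimal usage sketch of how a helper like my_reduce is typically invoked in Megatron-style model parallelism: after a row-parallel matmul, each rank holds only a partial sum of the output, and the all-reduce completes it. The row_parallel_matmul function and the shard shapes below are illustrative assumptions, not code from the repository.

import torch

def row_parallel_matmul(x_shard: torch.Tensor, w_shard: torch.Tensor,
                        groups, world_size, rank) -> torch.Tensor:
    # Hypothetical sketch (not from the repository):
    # x_shard: (batch, in_features // world_size) slice of the activations
    # w_shard: (in_features // world_size, out_features) slice of the weight
    partial = x_shard @ w_shard  # each rank computes a partial sum of the full output
    # Sum the partial outputs across the model parallel group.
    return my_reduce(partial, groups, world_size, rank)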