maga_transformer/utils/nccl_util.py (9 lines of code) (raw):
import torch
import torch.distributed as dist
from maga_transformer.distribute.worker_info import g_parallel_info
def all_gather_tp(output: torch.Tensor) -> torch.Tensor:
    """Gather ``output`` from every tensor-parallel rank and join the pieces.

    One receive buffer is allocated per rank (``g_parallel_info.tp_size`` in
    total), the slot belonging to this rank is pointed at ``output`` itself,
    and ``torch.distributed.all_gather`` fills the list. No ``group`` argument
    is passed, so the collective runs on the default process group.

    Args:
        output: This rank's local tensor.
            NOTE(review): presumably identical in shape and dtype on every
            rank, since the receive buffers are built with ``empty_like`` --
            confirm against callers.

    Returns:
        A contiguous tensor made by concatenating the per-rank tensors, in
        list order, along the last dimension of ``output``.
    """
    shards = []
    for _ in range(g_parallel_info.tp_size):
        shards.append(torch.empty_like(output))

    # The local slot aliases the input tensor rather than a fresh buffer.
    shards[g_parallel_info.tp_rank] = output

    dist.all_gather(shards, output)

    last_axis = output.ndim - 1
    gathered = torch.cat(shards, dim=last_axis)
    return gathered.contiguous()