# maga_transformer/ops/comm/nccl_op.py
from typing import List
import torch
from maga_transformer.distribute.worker_info import g_parallel_info, g_master_info


def singleton(cls):
    """Decorator that caches and reuses a single shared instance per class."""
    _instance = {}

    def inner():
        if cls not in _instance:
            _instance[cls] = cls()
        return _instance[cls]
    return inner
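

# Note: inner() accepts no arguments, so this decorator only suits classes
# with zero-argument constructors; NcclOp below is constructed that way.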
@singleton
class NcclOp:
    """Python wrapper around the RtpLlm.NcclOp custom torch class.

    Declared a singleton so the underlying op (and whatever NCCL state it
    holds) is constructed at most once per process.
    """

    def __init__(self):
        self.ft_op_ = torch.classes.RtpLlm.NcclOp(  # type: ignore
            g_parallel_info.tp_size,
            g_parallel_info.pp_size,
            g_master_info.ip,
            g_master_info.nccl_op_port)

    def broadcast_tp(self, tensors: List[torch.Tensor], root: int = 0):
        # Broadcast from `root` across the tensor-parallel group; the trailing
        # flag appears to request the op's own sync (barrier() passes False).
        self.ft_op_.broadcast_tp(tensors, root, True)

    def barrier(self, device: torch.device, root: int = 0):
        # Emulate a barrier: broadcast a throwaway tensor without the op's
        # internal sync, then wait on the current CUDA stream so the
        # collective has completed on this rank before returning.
        dummy_tensor = torch.zeros(1, device=device)
        self.ft_op_.broadcast_tp([dummy_tensor], root, False)
        torch.cuda.current_stream().synchronize()
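

# A minimal usage sketch (hypothetical, not part of this file). It assumes the
# RtpLlm extension is loaded, a CUDA device is available, and the distributed
# workers are already up so the constructor's rendezvous at
# g_master_info.ip:nccl_op_port can succeed:
#
#     op = NcclOp()                       # singleton: repeated calls return the same instance
#     t = torch.ones(4, device="cuda")
#     op.broadcast_tp([t])                # rank 0's values overwrite every TP rank's copy
#     op.barrier(torch.device("cuda"))    # wait until the collective has finished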