in deep_ep/buffer.py [0:0]
def __init__(self, group: dist.ProcessGroup,
             num_nvl_bytes: int = 0, num_rdma_bytes: int = 0,
             low_latency_mode: bool = False, num_qps_per_rank: int = 12) -> None:
    """
    Set up the communication buffer and bring the CPP runtime online.

    The constructor creates the CPP runtime, then exchanges device IDs, IPC handles and (when RDMA or
        low-latency mode is in use) the NVSHMEM unique ID across `group` via `dist.all_gather_object`,
        and finally hands the gathered values to the runtime's `sync`.

    Arguments:
        group: the communication group.
        num_nvl_bytes: the buffer size for intranode NVLink communication.
        num_rdma_bytes: the buffer size for internode RDMA communication (also used for intranode
            communication in low-latency mode).
        low_latency_mode: whether to enable low-latency mode.
        num_qps_per_rank: the number of QPs for RDMA; asserted to be positive whenever the NVSHMEM path
            is taken, and the low-latency mode requires it to equal the number of local experts.
    """
    rank = group.rank()
    world_size = group.size()

    # Record the configuration on the instance, then create the CPP runtime
    self.rank = rank
    self.group_size = world_size
    self.group = group
    self.num_nvl_bytes = num_nvl_bytes
    self.num_rdma_bytes = num_rdma_bytes
    self.low_latency_mode = low_latency_mode
    self.runtime = deep_ep_cpp.Buffer(rank, world_size, num_nvl_bytes, num_rdma_bytes, low_latency_mode)

    def exchange(local_value):
        # All-gather one Python object per rank over `group` into a list with one slot per rank
        gathered = [None] * world_size
        dist.all_gather_object(gathered, local_value, group)
        return gathered

    # Device IDs first, IPC handles second (same order of collectives on every rank)
    device_ids = exchange(self.runtime.get_local_device_id())
    ipc_handles = exchange(self.runtime.get_local_ipc_handle())

    # The NVSHMEM unique ID stays `None` unless there are multiple RDMA ranks or low-latency mode is on
    root_unique_id = None
    uses_nvshmem = self.runtime.get_num_rdma_ranks() > 1 or low_latency_mode
    if uses_nvshmem:
        assert num_qps_per_rank > 0

        # Environment for NVSHMEM with IBGDA enabled; applied in the order listed
        nvshmem_env = {
            'NVSHMEM_DISABLE_P2P': '1',
            'NVSHMEM_IB_ENABLE_IBGDA': '1',
            'NVSHMEM_IBGDA_NIC_HANDLER': 'gpu',
            'NVSHMEM_IBGDA_NUM_RC_PER_PE': f'{num_qps_per_rank}',
            # Keep the QP depth larger than the number of on-flight WRs, so the WQ slot check can be skipped
            'NVSHMEM_QP_DEPTH': '1024',
            # NOTES: NVSHMEM initialization requires at least 256 MiB (the value set here is 2 ** 29 bytes)
            'NVSHMEM_CUMEM_GRANULARITY': f'{2 ** 29}',
        }
        os.environ.update(nvshmem_env)

        # Only a root rank creates a unique ID: global rank 0 in low-latency mode,
        # otherwise any rank whose RDMA rank is 0; all other ranks contribute `None`
        if low_latency_mode:
            is_root = rank == 0
        else:
            is_root = self.runtime.get_rdma_rank() == 0
        local_unique_id = self.runtime.get_local_nvshmem_unique_id() if is_root else None

        # Every rank takes part in the gather, then picks the entry published by its root
        unique_ids = exchange(local_unique_id)
        root_index = 0 if low_latency_mode else self.runtime.get_root_rdma_rank(True)
        root_unique_id = unique_ids[root_index]

    # Hand everything to the CPP runtime and check that it reports itself available
    self.runtime.sync(device_ids, ipc_handles, root_unique_id)
    assert self.runtime.is_available()