in deep_ep/buffer.py [0:0]
def internode_dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
                       handle: Optional[Tuple] = None,
                       num_tokens_per_rank: Optional[torch.Tensor] = None,
                       num_tokens_per_rdma_rank: Optional[torch.Tensor] = None,
                       is_token_in_rank: Optional[torch.Tensor] = None,
                       num_tokens_per_expert: Optional[torch.Tensor] = None,
                       topk_idx: Optional[torch.Tensor] = None,
                       topk_weights: Optional[torch.Tensor] = None,
                       expert_alignment: int = 1,
                       config: Optional[Config] = None,
                       previous_event: Optional[EventOverlap] = None,
                       async_finish: bool = False,
                       allocate_on_comm_stream: bool = False) -> \
        Tuple[Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor], Optional[torch.Tensor],
              Optional[torch.Tensor], Optional[List[int]], Optional[Tuple], EventOverlap]:
    """
    Internode dispatch implementation, for more details, please refer to the `dispatch` docs.
    Normally, you should not directly call this function.

    The call runs in one of two modes, selected by `handle`:

    - Cached mode (`handle` is not None): the layout stored in `handle` is reused.
      `topk_idx` and `topk_weights` must both be None, and only the received data
      and the event are returned (the other return slots are None).
    - Non-cached mode (`handle` is None): `num_tokens_per_rank`, `is_token_in_rank`
      and `num_tokens_per_expert` must be provided, and a fresh handle is returned.

    Arguments:
        x: either a single tensor, or a `(tensor, scales)` tuple.
        handle: a 10-element tuple previously returned by this function in
            non-cached mode, or None.
        num_tokens_per_rank, num_tokens_per_rdma_rank, is_token_in_rank,
        num_tokens_per_expert: layout tensors forwarded to the runtime in
            non-cached mode. In cached mode only `is_token_in_rank` is used and it
            is taken from `handle`, overriding the argument.
        topk_idx, topk_weights: forwarded to the runtime as-is.
        expert_alignment: forwarded to the runtime as-is.
        config: required (must not be None); forwarded to the runtime.
        previous_event: if given, its `event` attribute is forwarded to the runtime.
        async_finish, allocate_on_comm_stream: forwarded to the runtime as-is.

    Returns:
        A 6-tuple `(recv_x, recv_topk_idx, recv_topk_weights,
        num_recv_tokens_per_expert_list, handle, event)` where:

        - `recv_x` is a `(recv_x, recv_x_scales)` tuple when non-None scales were
          supplied with `x`, otherwise a single tensor.
        - `recv_topk_idx`, `recv_topk_weights`, `num_recv_tokens_per_expert_list`
          and `handle` are always None in cached mode.
        - `event` is the runtime event wrapped in `EventOverlap`.
    """
    assert config is not None

    # Split the optional scales off the input; a bare tensor has no scales.
    x, x_scales = x if isinstance(x, tuple) else (x, None)
    # `previous_event` may be None, in which case no event is forwarded.
    prev_event = getattr(previous_event, 'event', None)

    # Launch the kernel with cached or non-cached mode
    if handle is not None:
        assert topk_idx is None and topk_weights is None
        # The handle layout must match the tuple built in the non-cached path below.
        # Underscore-prefixed entries are not needed for a cached dispatch, but are
        # still unpacked so that a malformed handle fails here.
        (is_token_in_rank,
         rdma_channel_prefix_matrix, gbl_channel_prefix_matrix,
         _recv_rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum,
         _recv_gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum,
         recv_src_meta, _send_rdma_head, send_nvl_head) = handle
        # Receive counts are read from the first dimension of the cached tensors.
        num_recv_tokens = recv_src_meta.size(0)
        num_rdma_recv_tokens = send_nvl_head.size(0)
        # Only the first two outputs and the trailing event are needed here.
        recv_x, recv_x_scales, *_, event = self.runtime.internode_dispatch(
            x, x_scales, topk_idx, topk_weights,
            None, None, is_token_in_rank, None,
            num_recv_tokens, num_rdma_recv_tokens,
            rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum,
            gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum,
            expert_alignment, config, prev_event, async_finish, allocate_on_comm_stream)
        recv = (recv_x, recv_x_scales) if x_scales is not None else recv_x
        return recv, None, None, None, None, EventOverlap(event)

    # NOTE(review): `num_tokens_per_rdma_rank` is not checked here; presumably the
    # runtime validates it -- confirm.
    assert num_tokens_per_rank is not None and is_token_in_rank is not None and num_tokens_per_expert is not None
    (recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list,
     rdma_channel_prefix_matrix, gbl_channel_prefix_matrix,
     recv_rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum,
     recv_gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum,
     recv_src_meta, send_rdma_head, send_nvl_head, event) = self.runtime.internode_dispatch(
        x, x_scales, topk_idx, topk_weights,
        num_tokens_per_rank, num_tokens_per_rdma_rank, is_token_in_rank, num_tokens_per_expert,
        0, 0, None, None, None, None,
        expert_alignment, config, prev_event, async_finish, allocate_on_comm_stream)
    # Pack everything a later cached-mode call unpacks (same order as above).
    handle = (is_token_in_rank,
              rdma_channel_prefix_matrix, gbl_channel_prefix_matrix,
              recv_rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum,
              recv_gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum,
              recv_src_meta, send_rdma_head, send_nvl_head)
    recv = (recv_x, recv_x_scales) if x_scales is not None else recv_x
    return recv, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, EventOverlap(event)