in deep_ep/buffer.py [0:0]
def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
handle: Optional[Tuple] = None,
num_tokens_per_rank: Optional[torch.Tensor] = None, num_tokens_per_rdma_rank: Optional[torch.Tensor] = None,
is_token_in_rank: Optional[torch.Tensor] = None, num_tokens_per_expert: Optional[torch.Tensor] = None,
topk_idx: Optional[torch.Tensor] = None, topk_weights: Optional[torch.Tensor] = None, expert_alignment: int = 1,
config: Optional[Config] = None,
previous_event: Optional[EventOverlap] = None, async_finish: bool = False,
allocate_on_comm_stream: bool = False) -> \
Tuple[Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor], List[int], Tuple, EventOverlap]:
"""
Dispatch tokens to different ranks; both intranode and internode settings are supported.
Intranode kernels require all ranks to be visible to each other via NVLink.
Internode kernels require the ranks within a node to be visible via NVLink, and the ranks sharing the same GPU
index across nodes to be visible via RDMA.
Arguments:
x: `torch.Tensor` or tuple of `torch.Tensor`. For the first form, the shape must be `[num_tokens, hidden]`
and the dtype must be `torch.bfloat16`; for the second form, the first element of the tuple must be shaped as
`[num_tokens, hidden]` with dtype `torch.float8_e4m3fn`, and the second must be `[num_tokens, hidden // 128]`
(`hidden` must be divisible by 128) with dtype `torch.float`.
handle: an optional communication handle; if set, the CPU will reuse the cached layout information to save time.
num_tokens_per_rank: `[num_ranks]` with `torch.int`, the number of tokens to be sent to each rank.
num_tokens_per_rdma_rank: `[num_rdma_ranks]` with `torch.int`, the number of tokens to be sent to each RDMA
rank (i.e. the ranks sharing the same GPU index); `None` for intranode settings.
is_token_in_rank: `[num_tokens, num_ranks]` with `torch.bool`, whether a token is sent to a rank.
num_tokens_per_expert: `[num_experts]` with `torch.int`, the number of tokens to be sent to each expert.
topk_idx: `[num_tokens, num_topk]` with `torch.int64`, the expert indices selected by each token,
`-1` means no selection.
topk_weights: `[num_tokens, num_topk]` with `torch.float`, the expert weights of each token to dispatch.
expert_alignment: align the number of tokens received by each local expert to a multiple of this value.
config: the performance tuning config.
previous_event: the event to wait for before actually executing the kernel.
async_finish: if set, the current stream will not wait for the communication kernels to finish.
allocate_on_comm_stream: whether to place ownership of all allocated tensors on the communication stream.
Returns:
recv_x: the received tokens, with the same type (and tuple structure) as the input `x`, but with the number
of tokens equal to the received token count.
recv_topk_idx: received expert indices.
recv_topk_weights: received expert weights.
num_recv_tokens_per_expert_list: Python list shaped `[num_local_experts]`, the number of tokens received by
each local expert, aligned to the input `expert_alignment`.
handle: the returned communication handle, which can be passed back in for cached dispatch.
event: the event after executing the kernel (valid only if `async_finish` is set).
"""
# Default config
config = self.get_dispatch_config(self.group_size) if config is None else config
# Internode
if self.runtime.get_num_rdma_ranks() > 1:
return self.internode_dispatch(x, handle, num_tokens_per_rank, num_tokens_per_rdma_rank, is_token_in_rank, num_tokens_per_expert,
topk_idx, topk_weights, expert_alignment, config, previous_event, async_finish, allocate_on_comm_stream)
# Launch the kernel with cached or non-cached mode
x, x_scales = x if isinstance(x, tuple) else (x, None)
if handle is not None:
assert topk_idx is None and topk_weights is None
rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, is_token_in_rank, send_head = handle
num_recv_tokens = recv_src_idx.size(0)
recv_x, recv_x_scales, _, _, _, _, _, _, _, _, event = self.runtime.intranode_dispatch(
x, x_scales, None, None,
None, is_token_in_rank, None, num_recv_tokens, rank_prefix_matrix, channel_prefix_matrix,
expert_alignment, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
return (recv_x, recv_x_scales) if x_scales is not None else recv_x, None, None, None, None, EventOverlap(event)
else:
assert num_tokens_per_rank is not None and is_token_in_rank is not None and num_tokens_per_expert is not None
recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, send_head, event = \
self.runtime.intranode_dispatch(x, x_scales, topk_idx, topk_weights,
num_tokens_per_rank, is_token_in_rank, num_tokens_per_expert, 0, None, None,
expert_alignment, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
handle = (rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, is_token_in_rank, send_head)
return (recv_x, recv_x_scales) if x_scales is not None else recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, EventOverlap(event)
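
# --------------------------------------------------------------------------
# Usage sketch (illustrative, not part of buffer.py): driving the two dispatch
# modes above from a MoE layer. It assumes an already-constructed
# `deep_ep.Buffer` instance and its companion `Buffer.get_dispatch_layout`
# helper for producing the layout tensors; the function name `moe_dispatch`
# and all local variable names below are hypothetical.
import torch

def moe_dispatch(buffer, x: torch.Tensor, topk_idx: torch.Tensor,
                 topk_weights: torch.Tensor, num_experts: int):
    # Non-cached mode: compute the per-rank / per-expert token layout first.
    # `x` may also be an FP8 `(tensor, scales)` tuple as described in the docstring.
    num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, \
        is_token_in_rank, layout_event = buffer.get_dispatch_layout(topk_idx, num_experts)
    recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = \
        buffer.dispatch(x,
                        num_tokens_per_rank=num_tokens_per_rank,
                        num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
                        is_token_in_rank=is_token_in_rank,
                        num_tokens_per_expert=num_tokens_per_expert,
                        topk_idx=topk_idx, topk_weights=topk_weights)

    # Cached mode: passing the handle back reuses the stored layout (so the
    # routing must be unchanged); only `recv_x` and the event are meaningful,
    # the other four return values are `None`.
    recv_x_again, _, _, _, _, _ = buffer.dispatch(x, handle=handle)

    return recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle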