in deep_ep/buffer.py
def combine(self, x: torch.Tensor, handle: Tuple,
topk_weights: Optional[torch.Tensor] = None,
config: Optional[Config] = None,
previous_event: Optional[EventOverlap] = None, async_finish: bool = False,
allocate_on_comm_stream: bool = False) -> \
Tuple[torch.Tensor, Optional[torch.Tensor], EventOverlap]:
"""
Combine (reduce) tokens (addition **without** weights) from different ranks; both intranode and internode
settings are supported.
Intranode kernels require all ranks to be visible to each other via NVLink.
Internode kernels require ranks within a node to be visible via NVLink, while ranks with the same GPU
index across nodes must be visible via RDMA.
Arguments:
x: `[num_tokens, hidden]` with `torch.bfloat16`, the tokens to send for reducing to their original ranks.
handle: the required communication handle, obtained from the dispatch function.
topk_weights: `[num_tokens, num_topk]` with `torch.float`, the tokens' top-k weights for reducing to their original ranks.
config: the performance tuning config.
previous_event: the event to wait before actually executing the kernel.
async_finish: if set, the current stream will not wait for the communication kernels to finish.
allocate_on_comm_stream: whether the allocated tensors should be owned by (allocated on) the communication stream.
Returns:
recv_x: the reduced tokens, gathered from the ranks they were dispatched to.
recv_topk_weights: the reduced top-k weights, gathered from the ranks they were dispatched to.
event: the event after executing the kernel (valid only if `async_finish` is set).
"""
# Default config
config = self.get_combine_config(self.group_size) if config is None else config
# Internode
if self.runtime.get_num_rdma_ranks() > 1:
return self.internode_combine(x, handle, topk_weights, config, previous_event, async_finish, allocate_on_comm_stream)
# NOTES: the second element of the handle (discarded as `_`) is the channel prefix matrix for the sending side, so the third (receiving-side) one is used here
rank_prefix_matrix, _, channel_prefix_matrix, src_idx, is_recv_token_in_rank, send_head = handle
# Launch the kernel
recv_x, recv_topk_weights, event = self.runtime.intranode_combine(
x, topk_weights,
src_idx, rank_prefix_matrix, channel_prefix_matrix, send_head, config,
getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
return recv_x, recv_topk_weights, EventOverlap(event)
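
# --- Hypothetical usage sketch (illustrative, not part of deep_ep/buffer.py) ---
# A minimal sketch of how `combine` is typically called after experts have run on
# dispatched tokens. It assumes `buffer` is an initialized Buffer bound to the
# process group, that `handle` and `recv_topk_weights` came from an earlier
# `buffer.dispatch(...)` call, and that `EventOverlap` exposes a
# `current_stream_wait()` method for stream synchronization; all names and shapes
# here are assumptions for illustration, not a definitive API reference.

import torch

def combine_after_experts(buffer, expert_out: torch.Tensor, handle,
                          recv_topk_weights: torch.Tensor):
    # Reduce (sum) the expert outputs back onto their original ranks; the top-k
    # weights are reduced alongside the tokens so the caller can apply them later.
    combined_x, combined_topk_weights, event = buffer.combine(
        expert_out, handle,
        topk_weights=recv_topk_weights,
        async_finish=True)
    # With `async_finish=True`, wait on the returned event before reading the
    # combined tensors on the current stream (assumed synchronization helper).
    event.current_stream_wait()
    return combined_x, combined_topk_weights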