in deep_ep/buffer.py [0:0]
def internode_combine(self, x: torch.Tensor, handle: Union[tuple, list],
                      topk_weights: Optional[torch.Tensor] = None,
                      config: Optional[Config] = None,
                      previous_event: Optional[EventOverlap] = None, async_finish: bool = False,
                      allocate_on_comm_stream: bool = False) -> \
        Tuple[torch.Tensor, Optional[torch.Tensor], EventOverlap]:
"""
Internode combine implementation, for more details, please refer to the `combine` docs.
Normally, you should not directly call this function.
"""
    assert config is not None

    # Unpack the communication handle returned by the matching dispatch call;
    # the fields that combine does not need are skipped
    is_combined_token_in_rank, \
        _, _, \
        rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix, gbl_rank_prefix_sum, \
        src_meta, send_rdma_head, send_nvl_head = handle
    # Launch the kernel; if `previous_event` is given, its underlying CUDA event is
    # waited on before execution so the combine can overlap with prior work
    combined_x, combined_topk_weights, event = self.runtime.internode_combine(
        x, topk_weights,
        src_meta, is_combined_token_in_rank,
        rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix,
        send_rdma_head, send_nvl_head, config, getattr(previous_event, 'event', None),
        async_finish, allocate_on_comm_stream)
    return combined_x, combined_topk_weights, EventOverlap(event)
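
# A minimal usage sketch, kept as comments since this is an internal entry point.
# In typical DeepEP usage you call the public `combine` wrapper with the handle
# returned by the earlier `dispatch`; the wrapper routes to this internode
# implementation when more than one RDMA rank participates. Names such as
# `buffer`, `run_experts`, and `expert_out` are hypothetical placeholders,
# not definitions from this file:
#
#     recv_x, recv_topk_idx, recv_topk_weights, _, handle, event = \
#         buffer.dispatch(x, topk_idx=topk_idx, topk_weights=topk_weights, ...)
#     expert_out = run_experts(recv_x)  # hypothetical expert computation
#     combined_x, combined_topk_weights, event = \
#         buffer.combine(expert_out, handle, topk_weights=recv_topk_weights)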