in sample_workloads/megatron-gke/docker/monitor_collectives.py [0:0]
def traced_all_to_all_single(
    output,
    input,
    output_split_sizes=None,
    input_split_sizes=None,
    group=None,
    async_op=False):
  """Intercepts invocations of torch.distributed.all_to_all_single.

  Similar to 'traced_all_to_all'.

  Args:
    output: Passed to torch.distributed.all_to_all_single.
    input: Passed to torch.distributed.all_to_all_single.
    output_split_sizes: Passed to torch.distributed.all_to_all_single.
    input_split_sizes: Passed to torch.distributed.all_to_all_single.
    group: Passed to torch.distributed.all_to_all_single.
    async_op: Passed to torch.distributed.all_to_all_single.

  Returns:
    Output of torch.distributed.all_to_all_single.
  """
  if _should_rank_record_comm(group):
    self_rank = torch.distributed.get_rank(group)
    if input_split_sizes is not None:
      self_slice = input_split_sizes[self_rank]
    else:
      self_slice = input.size(dim=0) // torch.distributed.get_world_size(group)
    # Number of elements in a single slice of length 1 along dim 0.
    slice_nelement = input.nelement() // input.size(dim=0)
    # Count every byte of the input except the slice this rank keeps for itself,
    # which never crosses the network.
    message_size = input.nelement() * input.element_size()
    message_size -= self_slice * slice_nelement * input.element_size()
    _emit_call_description('all_to_all_single', message_size, group)
  return torch.distributed.untraced_all_to_all_single(
      output, input, output_split_sizes, input_split_sizes, group, async_op)
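

# A minimal sketch (not part of the original file) of how this wrapper might be
# installed. The original collective is assumed to be stashed on torch.distributed
# under the 'untraced_' name that the body above calls, and the traced wrapper
# takes its place. The helper name _patch_all_to_all_single is hypothetical.
def _patch_all_to_all_single():
  torch.distributed.untraced_all_to_all_single = (
      torch.distributed.all_to_all_single)
  torch.distributed.all_to_all_single = traced_all_to_all_single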