# sample_workloads/megatron-gke/docker/monitor_collectives.py
def traced_reduce_scatter(
    output,
    input_list,
    op=torch.distributed.ReduceOp.SUM,
    group=None,
    async_op=False):
"""Intercepts invocations of torch.distributed.reduce_scatter.
Let n := [Group Size].
Calculate [Ring-B/W] = (n-1)/n * [Message Size]/[Kernel Time]
Assumes equal tensor sizes. It's the same as first half of ring All-Reduce.
Args:
output: Passed to torch.distributed.reduce_scatter
input_list: Passed to torch.distributed.reduce_scatter
op: Passed to torch.distributed.reduce_scatter
group: Passed to torch.distributed.reduce_scatter
async_op: Passed to torch.distributed.reduce_scatter
Returns:
Output of torch.distributed.reduce_scatter
"""
  if _should_rank_record_comm(group):
    # Per-rank message size in bytes: each rank receives one reduced shard.
    message_size = output.nelement() * output.element_size()
    _emit_call_description('reduce_scatter', message_size, group)
  return torch.distributed.untraced_reduce_scatter(
      output, input_list, op, group, async_op)
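

# Worked example of the Ring-B/W formula documented above. This is a minimal
# sketch for illustration only: the helper name and the numbers are
# hypothetical and do not appear in the original file.
def _example_ring_bw():
  n = 8                            # [Group Size]
  message_size = 64 * 1024 * 1024  # [Message Size]: a 64 MiB shard, in bytes
  kernel_time = 4e-3               # [Kernel Time]: 4 ms, in seconds
  # (n-1)/n * [Message Size] / [Kernel Time]
  return (n - 1) / n * message_size / kernel_time  # ~1.47e10 bytes/s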
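

# A minimal sketch of how this wrapper is assumed to be installed: stash the
# original collective under the `untraced_` name called above, then swap in
# the traced version. The installer function itself is hypothetical; the
# actual hook-up elsewhere in this file may differ.
def _install_reduce_scatter_tracer():
  if not hasattr(torch.distributed, 'untraced_reduce_scatter'):
    torch.distributed.untraced_reduce_scatter = torch.distributed.reduce_scatter
    torch.distributed.reduce_scatter = traced_reduce_scatter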