def traced_reduce_scatter()

in sample_workloads/lit-gpt-demo/utilities/monitor_collectives.py


def traced_reduce_scatter(
    output,
    input_list,
    op=torch.distributed.ReduceOp.SUM,
    group=None,
    async_op=False):
  """Intercepts invocations of torch.distributed.reduce_scatter.

  Let n := [Group Size].
  Calculate [Ring-B/W] = (n-1)/n * [Message Size] / [Kernel Time].
  Assumes equal tensor sizes across ranks; this is the same traffic pattern as
  the first half of a ring All-Reduce.
  """
  if _should_rank_record_comm(group):
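    # Message size in bytes is taken from the output tensor, i.e. the
    # per-rank reduced shard that this call produces.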
    message_size = output.nelement() * output.element_size()
    _emit_call_description('reduce_scatter', message_size, group)

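  # Delegate to the original, un-traced collective so behavior is unchanged.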
  return torch.distributed.untraced_reduce_scatter(
      output, input_list, op, group, async_op)
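
The delegate torch.distributed.untraced_reduce_scatter above implies that, elsewhere in monitor_collectives.py, the original collective is stashed under that name before the traced wrapper is installed. A minimal sketch of that wiring, assuming the install step (only the attribute name untraced_reduce_scatter is taken from the code above):

import torch
import torch.distributed

# Keep a handle to the original collective so the wrapper can delegate to it,
# then route all reduce_scatter calls through the traced version.
torch.distributed.untraced_reduce_scatter = torch.distributed.reduce_scatter
torch.distributed.reduce_scatter = traced_reduce_scatter

As a worked illustration of the [Ring-B/W] formula in the docstring (the helper name and the numbers are hypothetical):

def ring_bandwidth(group_size, message_size_bytes, kernel_time_s):
  # [Ring-B/W] = (n-1)/n * [Message Size] / [Kernel Time]
  return (group_size - 1) / group_size * message_size_bytes / kernel_time_s

# 8 ranks, a 512 MiB per-rank output, and a 50 ms kernel give roughly
# 9.4e9 bytes/s (about 8.75 GiB/s) of estimated ring bandwidth.
print(ring_bandwidth(8, 512 * 1024 * 1024, 0.05))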