# traced_reduce — excerpt from
# sample_workloads/lit-gpt-demo/utilities/monitor_collectives.py


def traced_reduce(
    tensor, dst, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
  """Traced drop-in replacement for torch.distributed.reduce.

  When this rank is chosen to record the collective, emits a call
  description so that [Ring-B/W] = [Message Size]/[Kernel Time] can be
  derived for large message sizes. Also see 'traced_broadcast'.

  Forwards all arguments unchanged to the original (untraced) reduce.
  """
  should_record = _should_rank_record_comm(group, root_rank=dst)
  if should_record:
    # Message size in bytes: element count times per-element byte width.
    size_bytes = tensor.element_size() * tensor.nelement()
    _emit_call_description('reduce', size_bytes, group, root_rank=dst)
  return torch.distributed.untraced_reduce(tensor, dst, op, group, async_op)