def traced_all_reduce()

in sample_workloads/megatron-gke/docker/monitor_collectives.py [0:0]


def traced_all_reduce(
    tensor, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
  """Intercepts invocations of torch.distributed.all_reduce.

  Let n := [Group Size]
  Calculate [Ring-B/W] = 2(n-1)/n * [Message Size] / [Kernel Time]

  https://images.nvidia.com/events/sc15/pdfs/NCCL-Woolley.pdf

  Args:
    tensor: Passed to torch.distributed.all_reduce
    op: Passed to torch.distributed.all_reduce
    group: Passed to torch.distributed.all_reduce
    async_op: Passed to torch.distributed.all_reduce

  Returns:
    Output of torch.distributed.all_reduce
  """
  if _should_rank_record_comm(group):
    message_size = tensor.nelement() * tensor.element_size()
    _emit_call_description('all_reduce', message_size, group)

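  # 'untraced_all_reduce' holds the original torch.distributed.all_reduce,
  # stashed on torch.distributed before this wrapper was patched in (see the
  # sketch below), so the real collective still runs after the call is
  # recorded.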
  return torch.distributed.untraced_all_reduce(
      tensor, op, group, async_op)