def traced_scatter()

in sample_workloads/lit-gpt-demo/utilities/monitor_collectives.py


def traced_scatter(
    tensor, scatter_list=None, src=0, group=None, async_op=False):
  """Intercepts invocations of torch.distributed.scatter.

  Let T := sum([Send Kernel Time from Rank i] for i != src)
  Calculate [P2P-B/W] = [Message Size]/T

  Each of (n-1) ranks receives a message from the root.
  There is no (n-1)/n factor because it is already folded into [Message Size].

  https://github.com/NVIDIA/nccl-tests/blob/1a5f551ffd6e/src/scatter.cu#L50
  https://github.com/pytorch/pytorch/blob/bfd995f0d6bf/torch/csrc/cuda/nccl.cpp#L1089
  """
  if _should_rank_record_comm(group, root_rank=src, is_ring=False):
    message_size = functools.reduce(
        lambda sz, x: sz + x.nelement() * x.element_size(), scatter_list, 0)
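    # Subtract the chunk the root keeps for itself; only the remaining (n-1)
    # chunks are actually sent, matching the [Message Size] definition above.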
    message_size -= tensor.nelement() * tensor.element_size()

    _emit_call_description('scatter', message_size, group, root_rank=src)

  return torch.distributed.untraced_scatter(
      tensor, scatter_list, src, group, async_op)
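
As a rough illustration of the [Message Size] and [P2P-B/W] definitions in the docstring, the sketch below recomputes them outside the tracer for a hypothetical 4-rank scatter. The tensor sizes and the send-time value T are made-up numbers for the example, not measurements from this workload.

import torch

# Hypothetical 4-rank scatter: each rank receives a 1 MiB fp16 chunk.
world_size = 4
chunk = torch.empty(512 * 1024, dtype=torch.float16)         # 1 MiB
scatter_list = [chunk.clone() for _ in range(world_size)]
local_out = torch.empty_like(chunk)

# Same formula as traced_scatter: total payload minus the root's own chunk.
total_bytes = sum(x.nelement() * x.element_size() for x in scatter_list)
message_size = total_bytes - local_out.nelement() * local_out.element_size()
assert message_size == 3 * 1024 * 1024                        # (n-1) chunks

# If the send kernels sum to T = 1.5 ms, [P2P-B/W] = [Message Size] / T.
t_seconds = 1.5e-3                                            # made-up value
p2p_bw = message_size / t_seconds                             # ~2.1e9 bytes/s
print(f'{message_size} bytes, {p2p_bw / 1e9:.2f} GB/s')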
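
The call to torch.distributed.untraced_scatter implies that the original collective is stashed under that name when the tracer is installed. A minimal sketch of such an install step follows, assuming a simple monkey-patch; the function name _install_scatter_tracer is hypothetical and not taken from this file.

import torch.distributed

def _install_scatter_tracer():
  # Assumption, not from the source: keep a handle to the original collective
  # under 'untraced_scatter', then route future torch.distributed.scatter
  # calls through traced_scatter.
  if not hasattr(torch.distributed, 'untraced_scatter'):
    torch.distributed.untraced_scatter = torch.distributed.scatter
    torch.distributed.scatter = traced_scatter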