def traced_all_to_all()

in sample_workloads/megatron-gke/docker/monitor_collectives.py [0:0]


def traced_all_to_all(
    output_tensor_list, input_tensor_list, group=None, async_op=False):
  """Intercepts invocations of torch.distributed.all_to_all.

  Let S := sum([Message Size on Rank i] for i = 1..n) where n := [Group Size]
  Let T := [End of last Receive last rank] - [Start of first Send first rank]
  Calculate [Algo B/W] = S / T.

  There is no n/(n-1) correction factor as we factor it in [Message Size].

  https://github.com/NVIDIA/nccl-tests/blob/1a5f551ffd6e/src/alltoall.cu#L57
  https://github.com/pytorch/pytorch/blob/bfd995f0d6bf/torch/csrc/cuda/nccl.cpp#L911

  Args:
    output_tensor_list: Passed to torch.distributed.all_to_all.
    input_tensor_list: Passed to torch.distributed.all_to_all.
    group: Passed to torch.distributed.all_to_all.
    async_op: Passed to torch.distributed.all_to_all.

  Returns:
    Output of torch.distributed.all_to_all
  """
  if _should_rank_record_comm(group):
    # Total bytes this rank contributes to the all-to-all: the sum of
    # nelement * element_size over every input shard.
    message_size = sum(
        t.nelement() * t.element_size() for t in input_tensor_list)

    # Omit bytes corresponding to send and receive on the same rank:
    # the shard destined for our own rank never crosses the wire.
    self_tensor = input_tensor_list[torch.distributed.get_rank(group)]
    message_size -= self_tensor.nelement() * self_tensor.element_size()

    _emit_call_description('all_to_all', message_size, group)

  # 'untraced_all_to_all' is the original torch.distributed.all_to_all,
  # stashed under this name when the interception hook was installed.
  return torch.distributed.untraced_all_to_all(
      output_tensor_list, input_tensor_list, group, async_op)