def traced_all_gather()

in sample_workloads/megatron-gke/docker/monitor_collectives.py


def traced_all_gather(tensor_list, tensor, group=None, async_op=False):
  """Intercepts invocations of torch.distributed.all_gather.

  Let n := [Group Size].
  Calculates [Ring-B/W] = (n-1)/n * [Message Size] / [Kernel Time],
  assuming every rank contributes a tensor of equal size.

  Args:
    tensor_list: Passed to torch.distributed.all_gather
    tensor: Passed to torch.distributed.all_gather
    group: Passed to torch.distributed.all_gather
    async_op: Passed to torch.distributed.all_gather

  Returns:
    Output of torch.distributed.all_gather
  """
  if _should_rank_record_comm(group):
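    # Sum the bytes of every tensor in the output list: the total gathered payload.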
    message_size = functools.reduce(
        lambda size, x: size + x.nelement() * x.element_size(), tensor_list, 0)
    _emit_call_description('all_gather', message_size, group)

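  # Delegate to the original all_gather, saved before the wrapper was installed.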
  return torch.distributed.untraced_all_gather(
      tensor_list, tensor, group, async_op)
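
For intuition on the docstring's formula: a ring all-gather moves (n-1)/n of the
total message over each link, so dividing that volume by the kernel time gives
the achieved bus bandwidth. A quick worked example with made-up numbers
(message_size corresponds to the sum over tensor_list computed above):

  # Hypothetical numbers, purely illustrative; not measurements from the script.
  n = 8                        # group size
  message_size = 256 * 2**20   # total bytes gathered across the group (256 MiB)
  kernel_time = 0.010          # observed all_gather kernel time in seconds

  # A ring all-gather moves (n-1)/n of the full message over each link,
  # so the achieved per-link (bus) bandwidth is:
  ring_bw = (n - 1) / n * message_size / kernel_time
  print(f"Ring-B/W = {ring_bw / 1e9:.1f} GB/s")  # 23.5 GB/s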
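
The call to torch.distributed.untraced_all_gather implies the script stashes the
original function under that name before installing the wrapper. A minimal
sketch of that wiring, assuming straightforward monkey-patching (the exact
mechanism in monitor_collectives.py may differ):

  import torch.distributed

  # Assumption: save the original so the wrapper can delegate to it under
  # the name torch.distributed.untraced_all_gather.
  torch.distributed.untraced_all_gather = torch.distributed.all_gather
  # Route every subsequent all_gather call through the tracing wrapper.
  torch.distributed.all_gather = traced_all_gather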