sample_workloads/lit-gpt-demo/utilities/monitor_collectives.py [79:396]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
def _shunt_torch_communication_calls():
  """Replaces torch.distributed.<target_collective> with a traced version.

  For every name in `target_collectives`, the original torch.distributed
  function is stashed as `torch.distributed.untraced_<name>` and the public
  name is rebound to this module's `traced_<name>`.

  The shunt is idempotent. If `untraced_<name>` already exists, that
  collective was shunted earlier and is left alone. The visible traced
  wrappers forward to `torch.distributed.untraced_<name>`. Re-shunting would
  stash the traced wrapper as its own "untraced" target, so every call would
  recurse until RecursionError.
  """
  target_collectives = [
      'barrier',
      'broadcast_object_list',
      'broadcast',
      'gather',
      'scatter',
      'reduce',
      'reduce_scatter',
      'reduce_scatter_tensor',
      'all_reduce',
      'all_gather',
      'all_gather_into_tensor',
      'all_to_all',
      'all_to_all_single',
      'batch_isend_irecv',
      'isend',
      'irecv',
      'send',
      'recv',
  ]

  this_module = sys.modules[__name__]
  for collective in target_collectives:
    untraced_name = 'untraced_' + collective
    if hasattr(torch.distributed, untraced_name):
      # Already shunted (e.g. this function ran twice); see docstring.
      continue
    original_fn = getattr(torch.distributed, collective)
    replaced_fn = getattr(this_module, 'traced_' + collective)
    setattr(torch.distributed, untraced_name, original_fn)
    setattr(torch.distributed, collective, replaced_fn)


def _shunt_torch_communication_objects():
  """Replaces torch.distributed.P2POp with this module's `_TracedP2POp`.

  The original class is preserved as `torch.distributed.UntracedP2POp`. If
  that attribute already exists, the shunt was applied before and is skipped.
  This prevents a second call from overwriting `UntracedP2POp` with
  `_TracedP2POp` and losing the original class.
  """
  if hasattr(torch.distributed, 'UntracedP2POp'):
    return
  original_p2p = torch.distributed.P2POp
  setattr(torch.distributed, 'UntracedP2POp', original_p2p)
  setattr(torch.distributed, 'P2POp', _TracedP2POp)


# Each 'traced_<comm>' defines a 'message_size' to compute B/W.
# Ref https://github.com/NVIDIA/nccl-tests/blob/master/doc/PERFORMANCE.md

# pylint: disable=g-doc-args,g-doc-return-or-yield
def traced_barrier(group=None, async_op=False, device_ids=None):
  """Records a torch.distributed.barrier call, then forwards it unchanged.

  Barriers are logged with a nominal message size of 1.
  """
  should_record = _should_rank_record_comm(group)
  if should_record:
    _emit_call_description('barrier', message_size=1, group=group)
  return torch.distributed.untraced_barrier(group, async_op, device_ids)


# pylint: disable=g-doc-args,g-doc-return-or-yield
def traced_broadcast_object_list(object_list, src=0, group=None, device=None):
  """Records a torch.distributed.broadcast_object_list call, then forwards it.

  broadcast_object_list pickles each object into tensor data and then issues
  a torch.distributed.broadcast. The recorded message size is therefore the
  summed byte length of every object's pickle.
  """

  if _should_rank_record_comm(group, root_rank=src):

    def pickled_nbytes(item):
      sink = io.BytesIO()
      pickle.Pickler(sink).dump(item)
      return sink.getbuffer().nbytes

    # The underlying call pickles these objects again. The duplicate work is
    # tolerated because this collective is not expected on the critical path.
    total_nbytes = sum(pickled_nbytes(item) for item in object_list)
    _emit_call_description(
        'broadcast_object_list', total_nbytes, group, root_rank=src)

  return torch.distributed.untraced_broadcast_object_list(
      object_list, src, group, device)


# pylint: disable=g-doc-args,g-doc-return-or-yield
def traced_broadcast(tensor, src, group=None, async_op=False):
  """Records a torch.distributed.broadcast call, then forwards it.

  Message Size is the byte size of `tensor`. For a large Message Size:
    [Ring-B/W] = [Message Size] / [Kernel Time]

  https://images.nvidia.com/events/sc15/pdfs/NCCL-Woolley.pdf
  """
  if _should_rank_record_comm(group, root_rank=src):
    nbytes = tensor.element_size() * tensor.nelement()
    _emit_call_description('broadcast', nbytes, group, root_rank=src)

  return torch.distributed.untraced_broadcast(tensor, src, group, async_op)


# pylint: disable=g-doc-args,g-doc-return-or-yield
def traced_gather(
    tensor, gather_list=None, dst=0, group=None, async_op=False):
  """Records a torch.distributed.gather call, then forwards it.

  Message Size is the total byte size of `gather_list` minus the bytes of the
  local `tensor`. In other words, it excludes data already resident on `dst`
  and counts only what the other (n-1) ranks send to the root. With
  T := sum([Receive Kernel Time from Rank i] for i != dst):
    [P2P-B/W] = [Message Size] / T

  Because the destination's own contribution is left out of Message Size, no
  (n-1)/n bus-bandwidth correction is needed. NCCL tests instead assume equal
  message sizes and include the destination's resident data, which is why
  their bus bandwidth carries an (n-1)/n factor. Either way, the goal is to
  compare the achieved bus transfer rate against peak bus bandwidth. See
  https://github.com/NVIDIA/nccl-tests/blob/master/doc/PERFORMANCE.md.

  https://github.com/NVIDIA/nccl-tests/blob/1a5f551ffd6e/src/gather.cu#L54
  https://github.com/pytorch/pytorch/blob/bfd995f0d6bf/torch/csrc/cuda/nccl.cpp#L1040

  NOTE(review): iterating a None `gather_list` raises TypeError. Confirm
  `_should_rank_record_comm(..., is_ring=False)` only returns True on `dst`,
  where `gather_list` is supplied.
  """
  if _should_rank_record_comm(group, root_rank=dst, is_ring=False):
    gathered_bytes = sum(t.nelement() * t.element_size() for t in gather_list)
    local_bytes = tensor.nelement() * tensor.element_size()
    _emit_call_description(
        'gather', gathered_bytes - local_bytes, group, root_rank=dst)

  return torch.distributed.untraced_gather(
      tensor, gather_list, dst, group, async_op)


# pylint: disable=g-doc-args,g-doc-return-or-yield
def traced_scatter(
    tensor, scatter_list=None, src=0, group=None, async_op=False):
  """Records a torch.distributed.scatter call, then forwards it.

  Message Size is the total byte size of `scatter_list` minus the bytes of
  the local `tensor`. This counts only what `src` sends to the other (n-1)
  ranks, so the (n-1)/n factor is already folded into Message Size. With
  T := sum([Send Kernel Time from Rank i] for i != src):
    [P2P-B/W] = [Message Size] / T

  https://github.com/NVIDIA/nccl-tests/blob/1a5f551ffd6e/src/scatter.cu#L50
  https://github.com/pytorch/pytorch/blob/bfd995f0d6bf/torch/csrc/cuda/nccl.cpp#L1089

  NOTE(review): iterating a None `scatter_list` raises TypeError. Confirm
  `_should_rank_record_comm(..., is_ring=False)` only returns True on `src`.
  """
  if _should_rank_record_comm(group, root_rank=src, is_ring=False):
    total_bytes = sum(t.nelement() * t.element_size() for t in scatter_list)
    kept_bytes = tensor.nelement() * tensor.element_size()
    _emit_call_description(
        'scatter', total_bytes - kept_bytes, group, root_rank=src)

  return torch.distributed.untraced_scatter(
      tensor, scatter_list, src, group, async_op)


# pylint: disable=g-doc-args,g-doc-return-or-yield
def traced_reduce(
    tensor, dst, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
  """Records a torch.distributed.reduce call, then forwards it.

  Message Size is the byte size of `tensor`. As with 'traced_broadcast', for
  a large Message Size:
    [Ring-B/W] = [Message Size] / [Kernel Time]
  """
  if _should_rank_record_comm(group, root_rank=dst):
    tensor_bytes = tensor.element_size() * tensor.nelement()
    _emit_call_description('reduce', tensor_bytes, group, root_rank=dst)

  return torch.distributed.untraced_reduce(tensor, dst, op, group, async_op)


# pylint: disable=g-doc-args,g-doc-return-or-yield
def traced_reduce_scatter(
    output,
    input_list,
    op=torch.distributed.ReduceOp.SUM,
    group=None,
    async_op=False):
  """Records a torch.distributed.reduce_scatter call, then forwards it.

  Message Size is the byte size of this rank's `output` shard. With
  n := [Group Size], and assuming equal shard sizes:
    [Ring-B/W] = (n-1)/n * [Message Size] / [Kernel Time]
  This is the same traffic pattern as the first half of a ring All-Reduce.

  NOTE(review): 'traced_all_gather' records the full gathered size, but this
  records a single shard. Confirm the downstream B/W math accounts for that.
  """
  if _should_rank_record_comm(group):
    shard_bytes = output.element_size() * output.nelement()
    _emit_call_description('reduce_scatter', shard_bytes, group)

  return torch.distributed.untraced_reduce_scatter(
      output, input_list, op, group, async_op)


# pylint: disable=redefined-builtin,g-doc-args,g-doc-return-or-yield
def traced_reduce_scatter_tensor(
    output,
    input,
    op=torch.distributed.ReduceOp.SUM,
    group=None,
    async_op=False):
  """Records a torch.distributed.reduce_scatter_tensor call, then forwards it.

  Message Size is computed exactly as in 'traced_reduce_scatter': the byte
  size of `output`. The call is also logged under the same 'reduce_scatter'
  label.
  """
  if _should_rank_record_comm(group):
    shard_bytes = output.element_size() * output.nelement()
    _emit_call_description('reduce_scatter', shard_bytes, group)

  return torch.distributed.untraced_reduce_scatter_tensor(
      output, input, op, group, async_op)


# pylint: disable=g-doc-args,g-doc-return-or-yield
def traced_all_reduce(
    tensor, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
  """Records a torch.distributed.all_reduce call, then forwards it.

  Message Size is the byte size of `tensor`. With n := [Group Size]:
    [Ring-B/W] = 2(n-1)/n * [Message Size] / [Kernel Time]

  https://images.nvidia.com/events/sc15/pdfs/NCCL-Woolley.pdf
  """
  if _should_rank_record_comm(group):
    payload_bytes = tensor.element_size() * tensor.nelement()
    _emit_call_description('all_reduce', payload_bytes, group)

  return torch.distributed.untraced_all_reduce(tensor, op, group, async_op)


# pylint: disable=g-doc-args,g-doc-return-or-yield
def traced_all_gather(tensor_list, tensor, group=None, async_op=False):
  """Records a torch.distributed.all_gather call, then forwards it.

  Message Size is the summed byte size of every tensor in `tensor_list`, i.e.
  the full gathered output. With n := [Group Size], and assuming equal tensor
  sizes:
    [Ring-B/W] = (n-1)/n * [Message Size] / [Kernel Time]
  """
  if _should_rank_record_comm(group):
    gathered_bytes = sum(t.nelement() * t.element_size() for t in tensor_list)
    _emit_call_description('all_gather', gathered_bytes, group)

  return torch.distributed.untraced_all_gather(
      tensor_list, tensor, group, async_op)


# pylint: disable=g-doc-args,g-doc-return-or-yield
def traced_all_gather_into_tensor(
    output_tensor, input_tensor, group=None, async_op=False):
  """Records a torch.distributed.all_gather_into_tensor call, then forwards it.

  Similar to 'traced_all_gather'. Message Size is the byte size of the
  gathered `output_tensor`, and the call is logged under the 'all_gather'
  label.
  """
  if _should_rank_record_comm(group):
    gathered_bytes = output_tensor.element_size() * output_tensor.nelement()
    _emit_call_description('all_gather', gathered_bytes, group)

  return torch.distributed.untraced_all_gather_into_tensor(
      output_tensor, input_tensor, group, async_op)


# Note: The TCP Direct team intends to implement a custom version of AllToAll.
# pylint: disable=g-doc-args,g-doc-return-or-yield
def traced_all_to_all(
    output_tensor_list, input_tensor_list, group=None, async_op=False):
  """Records a torch.distributed.all_to_all call, then forwards it.

  Message Size is the summed byte size of `input_tensor_list`, excluding the
  entry this rank sends to itself. Because the self entry is already dropped,
  no n/(n-1) correction factor is needed. With n := [Group Size],
  S := sum([Message Size on Rank i] for i = 1..n) and
  T := [End of last Receive last rank] - [Start of first Send first rank]:
    [Algo B/W] = S / T

  https://github.com/NVIDIA/nccl-tests/blob/1a5f551ffd6e/src/alltoall.cu#L57
  https://github.com/pytorch/pytorch/blob/bfd995f0d6bf/torch/csrc/cuda/nccl.cpp#L911
  """
  if _should_rank_record_comm(group):
    per_peer_bytes = [
        t.nelement() * t.element_size() for t in input_tensor_list]
    own_rank = torch.distributed.get_rank(group)
    # The chunk addressed to this rank is a send/receive on the same rank.
    cross_rank_bytes = sum(per_peer_bytes) - per_peer_bytes[own_rank]
    _emit_call_description('all_to_all', cross_rank_bytes, group)

  return torch.distributed.untraced_all_to_all(
      output_tensor_list, input_tensor_list, group, async_op)


# pylint: disable=g-doc-args,g-doc-return-or-yield,redefined-builtin
def traced_all_to_all_single(
    output,
    input,
    output_split_sizes=None,
    input_split_sizes=None,
    group=None,
    async_op=False):
  """Intercepts invocations of torch.distributed.all_to_all_single.

  Similar to 'traced_all_to_all'. Message Size is the byte size of `input`
  minus the dim-0 slice this rank sends to itself. The self slice is
  `input_split_sizes[rank]` rows. When no split sizes are given it is
  `input.size(0) // world_size` rows. As in torch.distributed, an empty
  `input_split_sizes` is treated like None (even split).

  The size is computed with integer arithmetic, so the recorded value is an
  exact byte count. A zero-row `input` records 0 bytes instead of dividing by
  zero.
  """
  if _should_rank_record_comm(group):
    self_rank = torch.distributed.get_rank(group)
    num_rows = input.size(dim=0)

    # torch treats both None and [] as "split dim 0 evenly across the group".
    if input_split_sizes:
      self_rows = input_split_sizes[self_rank]
    else:
      self_rows = num_rows // torch.distributed.get_world_size(group)

    # nelement is an exact multiple of num_rows, so // is lossless here.
    row_nelement = input.nelement() // num_rows if num_rows else 0
    message_size = (num_rows - self_rows) * row_nelement * input.element_size()

    _emit_call_description('all_to_all_single', message_size, group)

  return torch.distributed.untraced_all_to_all_single(
      output, input, output_split_sizes, input_split_sizes, group, async_op)


# Note: Each send and receive occurs on independent CUDA streams
# pylint: disable=g-doc-args,g-doc-return-or-yield
def traced_batch_isend_irecv(p2p_op_list):
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



sample_workloads/megatron-gke/docker/monitor_collectives.py [54:486]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
def _shunt_torch_communication_calls():
  """Replaces torch.distributed.<target_collective> with a traced version.

  For every name in `target_collectives`, the original torch.distributed
  function is stashed as `torch.distributed.untraced_<name>` and the public
  name is rebound to this module's `traced_<name>`.

  The shunt is idempotent. If `untraced_<name>` already exists, that
  collective was shunted earlier and is left alone. The visible traced
  wrappers forward to `torch.distributed.untraced_<name>`. Re-shunting would
  stash the traced wrapper as its own "untraced" target, so every call would
  recurse until RecursionError.
  """
  target_collectives = [
      'barrier',
      'broadcast_object_list',
      'broadcast',
      'gather',
      'scatter',
      'reduce',
      'reduce_scatter',
      'reduce_scatter_tensor',
      'all_reduce',
      'all_gather',
      'all_gather_into_tensor',
      'all_to_all',
      'all_to_all_single',
      'batch_isend_irecv',
      'isend',
      'irecv',
      'send',
      'recv',
  ]

  this_module = sys.modules[__name__]
  for collective in target_collectives:
    untraced_name = 'untraced_' + collective
    if hasattr(torch.distributed, untraced_name):
      # Already shunted (e.g. this function ran twice); see docstring.
      continue
    original_fn = getattr(torch.distributed, collective)
    replaced_fn = getattr(this_module, 'traced_' + collective)
    setattr(torch.distributed, untraced_name, original_fn)
    setattr(torch.distributed, collective, replaced_fn)


def _shunt_torch_communication_objects():
  """Replaces torch.distributed.P2POp with this module's `_TracedP2POp`.

  The original class is preserved as `torch.distributed.UntracedP2POp`. If
  that attribute already exists, the shunt was applied before and is skipped.
  This prevents a second call from overwriting `UntracedP2POp` with
  `_TracedP2POp` and losing the original class.
  """
  if hasattr(torch.distributed, 'UntracedP2POp'):
    return
  original_p2p = torch.distributed.P2POp
  setattr(torch.distributed, 'UntracedP2POp', original_p2p)
  setattr(torch.distributed, 'P2POp', _TracedP2POp)


# Each 'traced_<comm>' defines a 'message_size' to compute B/W.
# Ref https://github.com/NVIDIA/nccl-tests/blob/master/doc/PERFORMANCE.md


# For each 'traced_<comm>' the corresponding API docs (including args, return)
# are available at https://pytorch.org/docs/stable/distributed.html


def traced_barrier(group=None, async_op=False, device_ids=None):
  """Records a torch.distributed.barrier call, then forwards it unchanged.

  Barriers are logged with a nominal message size of 1.

  Args:
    group: Forwarded to torch.distributed.barrier
    async_op: Forwarded to torch.distributed.barrier
    device_ids: Forwarded to torch.distributed.barrier

  Returns:
    Whatever torch.distributed.barrier returns
  """
  should_record = _should_rank_record_comm(group)
  if should_record:
    _emit_call_description('barrier', message_size=1, group=group)
  return torch.distributed.untraced_barrier(group, async_op, device_ids)


def traced_broadcast_object_list(object_list, src=0, group=None, device=None):
  """Records a torch.distributed.broadcast_object_list call, then forwards it.

  broadcast_object_list pickles each object into tensor data and then issues
  a torch.distributed.broadcast. The recorded message size is therefore the
  summed byte length of every object's pickle.

  Args:
    object_list: Forwarded to torch.distributed.broadcast_object_list
    src: Forwarded to torch.distributed.broadcast_object_list
    group: Forwarded to torch.distributed.broadcast_object_list
    device: Forwarded to torch.distributed.broadcast_object_list

  Returns:
    Whatever torch.distributed.broadcast_object_list returns
  """

  if _should_rank_record_comm(group, root_rank=src):

    def pickled_nbytes(item):
      sink = io.BytesIO()
      pickle.Pickler(sink).dump(item)
      return sink.getbuffer().nbytes

    # The underlying call pickles these objects again. The duplicate work is
    # tolerated because this collective is not expected on the critical path.
    total_nbytes = sum(pickled_nbytes(item) for item in object_list)
    _emit_call_description(
        'broadcast_object_list', total_nbytes, group, root_rank=src)

  return torch.distributed.untraced_broadcast_object_list(
      object_list, src, group, device)


def traced_broadcast(tensor, src, group=None, async_op=False):
  """Records a torch.distributed.broadcast call, then forwards it.

  Message Size is the byte size of `tensor`. For a large Message Size:
    [Ring-B/W] = [Message Size] / [Kernel Time]

  https://images.nvidia.com/events/sc15/pdfs/NCCL-Woolley.pdf

  Args:
    tensor: Forwarded to torch.distributed.broadcast
    src: Forwarded to torch.distributed.broadcast
    group: Forwarded to torch.distributed.broadcast
    async_op: Forwarded to torch.distributed.broadcast

  Returns:
    Whatever torch.distributed.broadcast returns
  """
  if _should_rank_record_comm(group, root_rank=src):
    nbytes = tensor.element_size() * tensor.nelement()
    _emit_call_description('broadcast', nbytes, group, root_rank=src)

  return torch.distributed.untraced_broadcast(tensor, src, group, async_op)


def traced_gather(
    tensor, gather_list=None, dst=0, group=None, async_op=False):
  """Records a torch.distributed.gather call, then forwards it.

  Message Size is the total byte size of `gather_list` minus the bytes of the
  local `tensor`. In other words, it excludes data already resident on `dst`
  and counts only what the other (n-1) ranks send to the root. With
  T := sum([Receive Kernel Time from Rank i] for i != dst):
    [P2P-B/W] = [Message Size] / T

  Because the destination's own contribution is left out of Message Size, no
  (n-1)/n bus-bandwidth correction is needed. NCCL tests instead assume equal
  message sizes and include the destination's resident data, which is why
  their bus bandwidth carries an (n-1)/n factor. Either way, the goal is to
  compare the achieved bus transfer rate against peak bus bandwidth. See
  https://github.com/NVIDIA/nccl-tests/blob/master/doc/PERFORMANCE.md.

  https://github.com/NVIDIA/nccl-tests/blob/1a5f551ffd6e/src/gather.cu#L54
  https://github.com/pytorch/pytorch/blob/bfd995f0d6bf/torch/csrc/cuda/nccl.cpp#L1040

  NOTE(review): iterating a None `gather_list` raises TypeError. Confirm
  `_should_rank_record_comm(..., is_ring=False)` only returns True on `dst`,
  where `gather_list` is supplied.

  Args:
    tensor: Forwarded to torch.distributed.gather
    gather_list: Forwarded to torch.distributed.gather
    dst: Forwarded to torch.distributed.gather
    group: Forwarded to torch.distributed.gather
    async_op: Forwarded to torch.distributed.gather

  Returns:
    Whatever torch.distributed.gather returns
  """
  if _should_rank_record_comm(group, root_rank=dst, is_ring=False):
    gathered_bytes = sum(t.nelement() * t.element_size() for t in gather_list)
    local_bytes = tensor.nelement() * tensor.element_size()
    _emit_call_description(
        'gather', gathered_bytes - local_bytes, group, root_rank=dst)

  return torch.distributed.untraced_gather(
      tensor, gather_list, dst, group, async_op)


def traced_scatter(
    tensor, scatter_list=None, src=0, group=None, async_op=False):
  """Records a torch.distributed.scatter call, then forwards it.

  Message Size is the total byte size of `scatter_list` minus the bytes of
  the local `tensor`. This counts only what `src` sends to the other (n-1)
  ranks, so the (n-1)/n factor is already folded into Message Size. With
  T := sum([Send Kernel Time from Rank i] for i != src):
    [P2P-B/W] = [Message Size] / T

  https://github.com/NVIDIA/nccl-tests/blob/1a5f551ffd6e/src/scatter.cu#L50
  https://github.com/pytorch/pytorch/blob/bfd995f0d6bf/torch/csrc/cuda/nccl.cpp#L1089

  NOTE(review): iterating a None `scatter_list` raises TypeError. Confirm
  `_should_rank_record_comm(..., is_ring=False)` only returns True on `src`.

  Args:
    tensor: Forwarded to torch.distributed.scatter
    scatter_list: Forwarded to torch.distributed.scatter
    src: Forwarded to torch.distributed.scatter
    group: Forwarded to torch.distributed.scatter
    async_op: Forwarded to torch.distributed.scatter

  Returns:
    Whatever torch.distributed.scatter returns
  """
  if _should_rank_record_comm(group, root_rank=src, is_ring=False):
    total_bytes = sum(t.nelement() * t.element_size() for t in scatter_list)
    kept_bytes = tensor.nelement() * tensor.element_size()
    _emit_call_description(
        'scatter', total_bytes - kept_bytes, group, root_rank=src)

  return torch.distributed.untraced_scatter(
      tensor, scatter_list, src, group, async_op)


def traced_reduce(
    tensor, dst, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
  """Records a torch.distributed.reduce call, then forwards it.

  Message Size is the byte size of `tensor`. As with 'traced_broadcast', for
  a large Message Size:
    [Ring-B/W] = [Message Size] / [Kernel Time]

  Args:
    tensor: Forwarded to torch.distributed.reduce
    dst: Forwarded to torch.distributed.reduce
    op: Forwarded to torch.distributed.reduce
    group: Forwarded to torch.distributed.reduce
    async_op: Forwarded to torch.distributed.reduce

  Returns:
    Whatever torch.distributed.reduce returns
  """
  if _should_rank_record_comm(group, root_rank=dst):
    tensor_bytes = tensor.element_size() * tensor.nelement()
    _emit_call_description('reduce', tensor_bytes, group, root_rank=dst)

  return torch.distributed.untraced_reduce(tensor, dst, op, group, async_op)


def traced_reduce_scatter(
    output,
    input_list,
    op=torch.distributed.ReduceOp.SUM,
    group=None,
    async_op=False):
  """Records a torch.distributed.reduce_scatter call, then forwards it.

  Message Size is the byte size of this rank's `output` shard. With
  n := [Group Size], and assuming equal shard sizes:
    [Ring-B/W] = (n-1)/n * [Message Size] / [Kernel Time]
  This is the same traffic pattern as the first half of a ring All-Reduce.

  NOTE(review): 'traced_all_gather' records the full gathered size, but this
  records a single shard. Confirm the downstream B/W math accounts for that.

  Args:
    output: Forwarded to torch.distributed.reduce_scatter
    input_list: Forwarded to torch.distributed.reduce_scatter
    op: Forwarded to torch.distributed.reduce_scatter
    group: Forwarded to torch.distributed.reduce_scatter
    async_op: Forwarded to torch.distributed.reduce_scatter

  Returns:
    Whatever torch.distributed.reduce_scatter returns
  """
  if _should_rank_record_comm(group):
    shard_bytes = output.element_size() * output.nelement()
    _emit_call_description('reduce_scatter', shard_bytes, group)

  return torch.distributed.untraced_reduce_scatter(
      output, input_list, op, group, async_op)


# pylint: disable=redefined-builtin
def traced_reduce_scatter_tensor(
    output,
    input,
    op=torch.distributed.ReduceOp.SUM,
    group=None,
    async_op=False):
  """Records a torch.distributed.reduce_scatter_tensor call, then forwards it.

  Message Size is computed exactly as in 'traced_reduce_scatter': the byte
  size of `output`. The call is also logged under the same 'reduce_scatter'
  label.

  Args:
    output: Forwarded to torch.distributed.reduce_scatter_tensor
    input: Forwarded to torch.distributed.reduce_scatter_tensor
    op: Forwarded to torch.distributed.reduce_scatter_tensor
    group: Forwarded to torch.distributed.reduce_scatter_tensor
    async_op: Forwarded to torch.distributed.reduce_scatter_tensor

  Returns:
    Whatever torch.distributed.reduce_scatter_tensor returns
  """
  if _should_rank_record_comm(group):
    shard_bytes = output.element_size() * output.nelement()
    _emit_call_description('reduce_scatter', shard_bytes, group)

  return torch.distributed.untraced_reduce_scatter_tensor(
      output, input, op, group, async_op)


def traced_all_reduce(
    tensor, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
  """Records a torch.distributed.all_reduce call, then forwards it.

  Message Size is the byte size of `tensor`. With n := [Group Size]:
    [Ring-B/W] = 2(n-1)/n * [Message Size] / [Kernel Time]

  https://images.nvidia.com/events/sc15/pdfs/NCCL-Woolley.pdf

  Args:
    tensor: Forwarded to torch.distributed.all_reduce
    op: Forwarded to torch.distributed.all_reduce
    group: Forwarded to torch.distributed.all_reduce
    async_op: Forwarded to torch.distributed.all_reduce

  Returns:
    Whatever torch.distributed.all_reduce returns
  """
  if _should_rank_record_comm(group):
    payload_bytes = tensor.element_size() * tensor.nelement()
    _emit_call_description('all_reduce', payload_bytes, group)

  return torch.distributed.untraced_all_reduce(tensor, op, group, async_op)


def traced_all_gather(tensor_list, tensor, group=None, async_op=False):
  """Records a torch.distributed.all_gather call, then forwards it.

  Message Size is the summed byte size of every tensor in `tensor_list`, i.e.
  the full gathered output. With n := [Group Size], and assuming equal tensor
  sizes:
    [Ring-B/W] = (n-1)/n * [Message Size] / [Kernel Time]

  Args:
    tensor_list: Forwarded to torch.distributed.all_gather
    tensor: Forwarded to torch.distributed.all_gather
    group: Forwarded to torch.distributed.all_gather
    async_op: Forwarded to torch.distributed.all_gather

  Returns:
    Whatever torch.distributed.all_gather returns
  """
  if _should_rank_record_comm(group):
    gathered_bytes = sum(t.nelement() * t.element_size() for t in tensor_list)
    _emit_call_description('all_gather', gathered_bytes, group)

  return torch.distributed.untraced_all_gather(
      tensor_list, tensor, group, async_op)


def traced_all_gather_into_tensor(
    output_tensor, input_tensor, group=None, async_op=False):
  """Records a torch.distributed.all_gather_into_tensor call, then forwards it.

  Similar to 'traced_all_gather'. Message Size is the byte size of the
  gathered `output_tensor`, and the call is logged under the 'all_gather'
  label.

  Args:
    output_tensor: Forwarded to torch.distributed.all_gather_into_tensor
    input_tensor: Forwarded to torch.distributed.all_gather_into_tensor
    group: Forwarded to torch.distributed.all_gather_into_tensor
    async_op: Forwarded to torch.distributed.all_gather_into_tensor

  Returns:
    Whatever torch.distributed.all_gather_into_tensor returns
  """
  if _should_rank_record_comm(group):
    gathered_bytes = output_tensor.element_size() * output_tensor.nelement()
    _emit_call_description('all_gather', gathered_bytes, group)

  return torch.distributed.untraced_all_gather_into_tensor(
      output_tensor, input_tensor, group, async_op)


# Note: The TCP Direct team intends to implement a custom version of AllToAll.
def traced_all_to_all(
    output_tensor_list, input_tensor_list, group=None, async_op=False):
  """Records a torch.distributed.all_to_all call, then forwards it.

  Message Size is the summed byte size of `input_tensor_list`, excluding the
  entry this rank sends to itself. Because the self entry is already dropped,
  no n/(n-1) correction factor is needed. With n := [Group Size],
  S := sum([Message Size on Rank i] for i = 1..n) and
  T := [End of last Receive last rank] - [Start of first Send first rank]:
    [Algo B/W] = S / T

  https://github.com/NVIDIA/nccl-tests/blob/1a5f551ffd6e/src/alltoall.cu#L57
  https://github.com/pytorch/pytorch/blob/bfd995f0d6bf/torch/csrc/cuda/nccl.cpp#L911

  Args:
    output_tensor_list: Forwarded to torch.distributed.all_to_all
    input_tensor_list: Forwarded to torch.distributed.all_to_all
    group: Forwarded to torch.distributed.all_to_all
    async_op: Forwarded to torch.distributed.all_to_all

  Returns:
    Whatever torch.distributed.all_to_all returns
  """
  if _should_rank_record_comm(group):
    per_peer_bytes = [
        t.nelement() * t.element_size() for t in input_tensor_list]
    own_rank = torch.distributed.get_rank(group)
    # The chunk addressed to this rank is a send/receive on the same rank.
    cross_rank_bytes = sum(per_peer_bytes) - per_peer_bytes[own_rank]
    _emit_call_description('all_to_all', cross_rank_bytes, group)

  return torch.distributed.untraced_all_to_all(
      output_tensor_list, input_tensor_list, group, async_op)


def traced_all_to_all_single(
    output,
    input,
    output_split_sizes=None,
    input_split_sizes=None,
    group=None,
    async_op=False):
  """Intercepts invocations of torch.distributed.all_to_all_single.

  Similar to 'traced_all_to_all'. Message Size is the byte size of `input`
  minus the dim-0 slice this rank sends to itself. The self slice is
  `input_split_sizes[rank]` rows. When no split sizes are given it is
  `input.size(0) // world_size` rows. As in torch.distributed, an empty
  `input_split_sizes` is treated like None (even split).

  The size is computed with integer arithmetic, so the recorded value is an
  exact byte count. A zero-row `input` records 0 bytes instead of dividing by
  zero.

  Args:
    output: Passed to torch.distributed.all_to_all_single.
    input: Passed to torch.distributed.all_to_all_single
    output_split_sizes: Passed to torch.distributed.all_to_all_single.
    input_split_sizes: Passed to torch.distributed.all_to_all_single
    group: Passed to torch.distributed.all_to_all_single
    async_op: Passed to torch.distributed.all_to_all_single

  Returns:
    Output of torch.distributed.all_to_all_single
  """
  if _should_rank_record_comm(group):
    self_rank = torch.distributed.get_rank(group)
    num_rows = input.size(dim=0)

    # torch treats both None and [] as "split dim 0 evenly across the group".
    if input_split_sizes:
      self_rows = input_split_sizes[self_rank]
    else:
      self_rows = num_rows // torch.distributed.get_world_size(group)

    # nelement is an exact multiple of num_rows, so // is lossless here.
    row_nelement = input.nelement() // num_rows if num_rows else 0
    message_size = (num_rows - self_rows) * row_nelement * input.element_size()

    _emit_call_description('all_to_all_single', message_size, group)

  return torch.distributed.untraced_all_to_all_single(
      output, input, output_split_sizes, input_split_sizes, group, async_op)


# Note: Each send and receive occurs on independent CUDA streams
def traced_batch_isend_irecv(p2p_op_list):
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



