def traced_all_gather_into_tensor()

in sample_workloads/megatron-gke/docker/monitor_collectives.py [0:0]


def traced_all_gather_into_tensor(
    output_tensor, input_tensor, group=None, async_op=False):
  """Intercepts invocations of torch.distributed.all_gather_into_tensor.

  Similar to 'traced_all_gather', but for the variant that gathers into a
  single pre-allocated output tensor. When `_should_rank_record_comm(group)`
  is truthy, the call is reported through `_emit_call_description` under the
  collective name 'all_gather'. The reported message size is the byte size of
  `output_tensor`, which is the full gathered buffer, not the per-rank input.

  The call is then forwarded to
  torch.distributed.untraced_all_gather_into_tensor. That is presumably the
  original, un-patched torch function saved when tracing was installed;
  verify against the patching code.

  Args:
    output_tensor: Passed to torch.distributed.all_gather_into_tensor
    input_tensor: Passed to torch.distributed.all_gather_into_tensor
    group: Passed to torch.distributed.all_gather_into_tensor
    async_op: Passed to torch.distributed.all_gather_into_tensor

  Returns:
    Output of torch.distributed.all_gather_into_tensor
  """
  if _should_rank_record_comm(group):
    # Bytes in the gathered result: element count * bytes per element.
    message_size = output_tensor.nelement() * output_tensor.element_size()
    # Recorded with the same 'all_gather' label as the list-based collective.
    _emit_call_description('all_gather', message_size, group)

  # Forward the optional arguments by keyword so they cannot be mis-bound if
  # the wrapped function's positional parameter order ever differs.
  return torch.distributed.untraced_all_gather_into_tensor(
      output_tensor, input_tensor, group=group, async_op=async_op)