def traced_reduce_scatter_tensor()

in sample_workloads/megatron-gke/docker/monitor_collectives.py [0:0]


def traced_reduce_scatter_tensor(
    output,
    input,
    op=torch.distributed.ReduceOp.SUM,
    group=None,
    async_op=False):
  """Records a torch.distributed.reduce_scatter_tensor call, then forwards it.

  This is the single-tensor counterpart of 'traced_reduce_scatter'. When
  _should_rank_record_comm(group) is truthy for this rank, a call description
  is emitted under the name 'reduce_scatter'. Its message size is the byte
  size of `output` (element count times per-element size in bytes). The call
  is then delegated, with all arguments passed positionally and unchanged, to
  torch.distributed.untraced_reduce_scatter_tensor.

  NOTE(review): untraced_reduce_scatter_tensor is presumably the original,
  unpatched collective saved by this module before patching. It is not
  defined in this block, so confirm against the rest of the file.

  Args:
    output: Output tensor. It is forwarded as-is and also used to compute the
      recorded message size.
    input: Input tensor, forwarded as-is.
    op: Reduction operation, forwarded as-is.
    group: Process group. It is forwarded as-is and also used to decide
      whether this rank records the call.
    async_op: Forwarded as-is.

  Returns:
    Whatever torch.distributed.untraced_reduce_scatter_tensor returns.
  """
  rank_records = _should_rank_record_comm(group)
  if rank_records:
    # Size the message by the output buffer, matching the non-tensor
    # reduce_scatter tracer referenced in the docstring.
    payload_bytes = output.nelement() * output.element_size()
    _emit_call_description('reduce_scatter', payload_bytes, group)

  untraced_fn = torch.distributed.untraced_reduce_scatter_tensor
  return untraced_fn(output, input, op, group, async_op)