def traced_all_to_all_single()

in sample_workloads/lit-gpt-demo/utilities/monitor_collectives.py


import torch.distributed


def traced_all_to_all_single(
    output,
    input,
    output_split_sizes=None,
    input_split_sizes=None,
    group=None,
    async_op=False):
  """Intercepts invocations of torch.distributed.all_to_all_single.

  Similar to 'traced_all_to_all'
  """
  if _should_rank_record_comm(group):
    self_rank = torch.distributed.get_rank(group)

    # Rows of `input` kept by this rank: the explicit split size if given,
    # otherwise an even share of the first dimension.
    if input_split_sizes is not None:
      self_slice = input_split_sizes[self_rank]
    else:
      self_slice = input.size(dim=0) // torch.distributed.get_world_size(group)

    # Bytes sent to peers = total input bytes minus the self-slice bytes.
    # Integer division keeps the byte counts exact (`/` would yield floats).
    slice_nelement = input.nelement() // input.size(dim=0)
    message_size = input.nelement() * input.element_size()
    message_size -= self_slice * slice_nelement * input.element_size()

    _emit_call_description('all_to_all_single', message_size, group)

  # Delegate to the original collective, which the tracer stashed on
  # torch.distributed as 'untraced_all_to_all_single' when it was installed.
  return torch.distributed.untraced_all_to_all_single(
      output, input, output_split_sizes, input_split_sizes, group, async_op)
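
For context, a minimal sketch of how such a wrapper is typically installed.
The stash-then-replace pattern is implied by the call to
torch.distributed.untraced_all_to_all_single above; the install_tracing
name and the exact patching site are assumptions, not the module's
confirmed API.

import torch.distributed

def install_tracing():
  # Hypothetical installer: stash the original collective so the wrapper
  # can delegate to it, then route future calls through the tracer.
  torch.distributed.untraced_all_to_all_single = (
      torch.distributed.all_to_all_single)
  torch.distributed.all_to_all_single = traced_all_to_all_single

As a worked example of the size estimate: with a world size of 4 and a
float16 input of shape (8, 1024), the total is 8 * 1024 * 2 = 16384 bytes,
the self slice is 8 // 4 = 2 rows of 1024 * 2 bytes each, so message_size
is 16384 - 4096 = 12288 bytes.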