def traced_batch_isend_irecv()

in sample_workloads/lit-gpt-demo/utilities/monitor_collectives.py [0:0]


def traced_batch_isend_irecv(p2p_op_list):
  """Intercepts invocations of torch.distributed.batch_isend_irecv.

  For every op in the batch that this rank should record, emits a call
  description (api name, message size, group, peer) so that
  [P2P-B/W] = [Message Size]/[Kernel Time] can be calculated for each send
  and recv. All ops in one batch share a single correlation id.

  Args:
    p2p_op_list: The P2P op list passed to batch_isend_irecv. Each element
      is read for its `op`, `tensor`, `peer` and `group` attributes.

  Returns:
    Whatever torch.distributed.untraced_batch_isend_irecv returns for
    p2p_op_list.
  """
  # One id for the whole batch, so every op in it can be correlated.
  correlation_id = str(uuid.uuid4())

  # Log the batch size once per batch (previously printed once per op).
  if _SHOULD_PRINT:
    print(f"Num p2p ops in batch: {len(p2p_op_list)}")

  for p2p in p2p_op_list:
    if not _should_rank_record_comm(p2p.group, peer_rank=p2p.peer, is_ring=False):
      continue

    # NOTE(review): an op is labeled 'send' only if it is exactly
    # torch.distributed.untraced_isend; any other op (including a patched
    # isend, if callers build P2POp with one) is labeled 'recv'. Confirm
    # against how P2POp objects are constructed by callers.
    api = 'send' if p2p.op == torch.distributed.untraced_isend else 'recv'

    # Element count times per-element size (bytes, for a torch tensor).
    message_size = p2p.tensor.nelement() * p2p.tensor.element_size()
    _emit_call_description(
        api,
        message_size,
        group=p2p.group,
        peer_rank=p2p.peer,
        correlation_id=correlation_id,
    )

  return torch.distributed.untraced_batch_isend_irecv(p2p_op_list)