def _should_rank_record_comm()

in sample_workloads/megatron-gke/docker/monitor_collectives.py


def _should_rank_record_comm(
    group=None, peer_rank=None, root_rank=None, is_ring=True):
  """Decides whether a given torch.distributed collective should be recorded.

  Args:
    group: The torch process group (i.e. participating GPUs) in this collective.
    peer_rank: In direct peer-to-peer operations, the global rank of the peer.
    root_rank: The global rank of the root GPU, for collectives with a root.
    is_ring: Whether the default NCCL implementation uses a ring algorithm.
      Specifying 'peer_rank' together with 'is_ring=True' is incompatible.

  Returns:
    Whether to record a descriptive NVTX marker, and possibly print a log trace.
  """
  # Only ranks that participate in this collective should record it.
  if not _is_current_process_in_group(group):
    return False
  # In 'crossnode' trace mode, skip traffic that stays within a single node.
  if _TRACE_MODE == 'crossnode' and not _is_crossnode_comm(group, peer_rank):
    return False
  # For rooted, non-ring collectives, record once: on the root rank only.
  if not is_ring and root_rank is not None:
    return torch.distributed.get_rank() == root_rank

  return True
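
A minimal sketch of how a caller might consult this function, assuming the surrounding module wraps torch.distributed collectives. The wrapper name _traced_all_reduce and the use of torch.cuda.nvtx ranges are illustrative assumptions, not code from this file.

import torch
import torch.distributed


def _traced_all_reduce(tensor, group=None):
  """Hypothetical wrapper: records an NVTX range around an all_reduce."""
  # all_reduce has no root rank; treat it as ring (the function's default).
  record = _should_rank_record_comm(group=group, is_ring=True)
  if record:
    torch.cuda.nvtx.range_push(
        'all_reduce bytes=%d' % (tensor.element_size() * tensor.nelement()))
  try:
    torch.distributed.all_reduce(tensor, group=group)
  finally:
    if record:
      torch.cuda.nvtx.range_pop()

Gating the NVTX push/pop on a single upfront decision keeps the marker logic out of the hot path for ranks that are not recording, and the try/finally ensures the range is popped even if the collective raises.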