in sample_workloads/lit-gpt-demo/utilities/monitor_collectives.py [0:0]
def _should_rank_record_comm(
    group=None, peer_rank=None, root_rank=None, is_ring=True):
"""Decides whether a given torch.distributed collective should be recorded.
Args:
group: The torch process group (i.e. participating GPUs) in this collective.
peer_rank: In direct peer to peer operations, the global rank of the peer.
root_rank: The global rank of the root GPU, for collectives with a root.
as_ring: Whether the default NCCL implementation uses a ring algorithm.
Specifying 'peer_rank' and 'is_ring=True' are incompatible.
Returns:
Whether to record a descriptive NVTX marker, and possibly print a log trace.
"""
  if not _is_current_process_in_group(group):
    return False
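  # In 'crossnode' trace mode, only record communication that crosses node
  # boundaries.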
  if _TRACE_MODE == 'crossnode' and not _is_crossnode_comm(group, peer_rank):
    return False
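  # For rooted collectives that do not use a ring algorithm, record only on
  # the root rank.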
  if not is_ring and root_rank is not None:
    return torch.distributed.get_rank() == root_rank
  return True
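

# Illustrative sketch (not part of the original module): one way a wrapped
# collective could consult _should_rank_record_comm before emitting an NVTX
# range. The wrapper name, the marker text, and the is_ring=False choice are
# assumptions for illustration only.
def _traced_broadcast(tensor, src, group=None):
  if _should_rank_record_comm(group=group, root_rank=src, is_ring=False):
    # Emit a descriptive NVTX range around the collective so profilers such as
    # Nsight Systems can attribute the communication time.
    torch.cuda.nvtx.range_push(
        f'broadcast bytes={tensor.element_size() * tensor.numel()} src={src}')
    try:
      return torch.distributed.broadcast(tensor, src=src, group=group)
    finally:
      torch.cuda.nvtx.range_pop()
  return torch.distributed.broadcast(tensor, src=src, group=group)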