def _is_crossnode_comm()

in sample_workloads/megatron-gke/docker/monitor_collectives.py


import torch
import torch.distributed


def _is_crossnode_comm(group=None, peer_rank=None):
  """Whether this collective involves communication across nodes.

  Args:
    group: The torch process group (i.e. participating GPUs) in this collective.
    peer_rank: In direct peer to peer operations, the global rank of the peer.

  Returns:
    Whether this collective involves communications across nodes.
  """
  # Number of GPUs (and therefore ranks) hosted on each node.
  count_per_node = torch.cuda.device_count()

  if peer_rank is not None:
    # Point-to-point case: compare the node index of this rank with the peer's.
    this_node = torch.distributed.get_rank() // count_per_node
    peer_node = peer_rank // count_per_node
    return this_node != peer_node
  else:
    # Collective case: gather the global ranks participating in the group
    # (or the whole world if no group was given).
    if group is not None:
      ranks = torch.distributed.get_process_group_ranks(group=group)
    else:
      ranks = list(range(torch.distributed.get_world_size()))

    # Map each rank to its node index; the collective crosses nodes if the
    # participants do not all live on the same node.
    nodes = [rank // count_per_node for rank in ranks]
    return any(node != nodes[0] for node in nodes)
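

A minimal sketch of how this helper might be used when instrumenting a collective, assuming torch.distributed has already been initialized; the wrapper name _tagged_all_reduce and the print-based reporting are illustrative, not part of monitor_collectives.py:

def _tagged_all_reduce(tensor, group=None):
  """Illustrative wrapper: label an all_reduce as intra- or cross-node."""
  scope = 'cross-node' if _is_crossnode_comm(group=group) else 'intra-node'
  print(f'all_reduce ({scope}), {tensor.numel()} elements')
  torch.distributed.all_reduce(tensor, group=group)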