def all_gather()

in graphlearn_torch/python/distributed/rpc.py [0:0]


def all_gather(obj, timeout=None):
  r""" Gathers objects only from the current role group in a list. This
  function blocks until all workers in the current role group have received
  the gathered results. The implementation of this method is refer to
  ``torch.distributed.rpc.api._all_gather``.
  """
  assert (
    _rpc_current_group_worker_names is not None
  ), "`_rpc_current_group_worker_names` is not initialized for `all_gather`."
  leader_name = sorted(_rpc_current_group_worker_names)[0]
  self_name = get_context().worker_name

  global _role_based_all_gather_sequence_id
  with _role_based_all_gather_dict_lock:
    sequence_id = _role_based_all_gather_sequence_id
    _role_based_all_gather_sequence_id += 1

  is_leader = leader_name == self_name
  if timeout is None:
    timeout = rpc.get_rpc_timeout()

  # Phase 1: Followers send it's object to the leader
  if is_leader:
    _role_based_gather_to_leader(sequence_id, self_name, obj)
  else:
    rpc.rpc_sync(
      leader_name,
      _role_based_gather_to_leader,
      args=(sequence_id, self_name, obj),
      timeout=timeout,
    )

  with _role_based_all_gather_dict_lock:
    states = _role_based_all_gather_sequence_id_to_states[sequence_id]
  states.proceed_signal.wait()

  # Phase 2: Leader broadcast gathered results to all followers
  # Leader's signal is the first to be unblocked, after receiving all
  # followers' data objects.
  if is_leader:
    worker_name_to_response_future_dict = {}
    for follower_name in _rpc_current_group_worker_names - {leader_name}:
      fut = rpc.rpc_async(
        follower_name,
        _role_based_broadcast_to_followers,
        args=(sequence_id, states.gathered_objects),
        timeout=timeout
      )
      worker_name_to_response_future_dict[follower_name] = fut

    errors = []
    for follower_name, fut in worker_name_to_response_future_dict.items():
      try:
        fut.wait()
      except RuntimeError as ex:
        errors.append((follower_name, ex))

    if errors:
      raise RuntimeError(
        f"Followers {[e[0] for e in errors]} timed out in all_gather "
        f"after {timeout:.2f} seconds. The first exception is {errors[0][1]}"
      )

  return states.gathered_objects