in graphlearn_torch/python/distributed/rpc.py [0:0]
def all_gather(obj, timeout=None):
  r""" Gathers objects only from the current role group into a list. This
  function blocks until all workers in the current role group have received
  the gathered results. The implementation of this method follows
  ``torch.distributed.rpc.api._all_gather``.
  """
  assert (
    _rpc_current_group_worker_names is not None
  ), "`_rpc_current_group_worker_names` is not initialized for `all_gather`."
  # The worker with the lexicographically smallest name acts as the leader.
  leader_name = sorted(_rpc_current_group_worker_names)[0]
  self_name = get_context().worker_name
  global _role_based_all_gather_sequence_id
  with _role_based_all_gather_dict_lock:
    # Allocate a unique sequence id for this all_gather round so that
    # concurrent rounds do not mix their gathered objects.
    sequence_id = _role_based_all_gather_sequence_id
    _role_based_all_gather_sequence_id += 1
  is_leader = leader_name == self_name
  if timeout is None:
    timeout = rpc.get_rpc_timeout()
  # Phase 1: Followers send their objects to the leader.
  if is_leader:
    _role_based_gather_to_leader(sequence_id, self_name, obj)
  else:
    rpc.rpc_sync(
      leader_name,
      _role_based_gather_to_leader,
      args=(sequence_id, self_name, obj),
      timeout=timeout,
    )
  with _role_based_all_gather_dict_lock:
    states = _role_based_all_gather_sequence_id_to_states[sequence_id]
  states.proceed_signal.wait()
  # Phase 2: The leader broadcasts the gathered results to all followers.
  # The leader's proceed signal is the first to be unblocked, right after
  # it has received all followers' data objects.
  if is_leader:
    worker_name_to_response_future_dict = {}
    for follower_name in _rpc_current_group_worker_names - {leader_name}:
      fut = rpc.rpc_async(
        follower_name,
        _role_based_broadcast_to_followers,
        args=(sequence_id, states.gathered_objects),
        timeout=timeout
      )
      worker_name_to_response_future_dict[follower_name] = fut
    errors = []
    for follower_name, fut in worker_name_to_response_future_dict.items():
      try:
        fut.wait()
      except RuntimeError as ex:
        errors.append((follower_name, ex))
    if errors:
      raise RuntimeError(
        f"Followers {[e[0] for e in errors]} timed out in all_gather "
        f"after {timeout:.2f} seconds. The first exception is {errors[0][1]}"
      )
  return states.gathered_objects
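
The helpers and module-level state referenced above (`_role_based_all_gather_dict_lock`, `_role_based_all_gather_sequence_id_to_states`, `_role_based_gather_to_leader`, `_role_based_broadcast_to_followers`) are defined elsewhere in rpc.py. The sketch below shows one plausible shape for them, closely modeled on ``torch.distributed.rpc.api._all_gather``; the class name `_RoleBasedAllGatherStates` and the exact bodies are illustrative assumptions, not the module's actual code.

import collections
import threading

class _RoleBasedAllGatherStates(object):
  """Illustrative per-sequence-id state (assumption, mirrors torch's
  AllGatherStates)."""
  def __init__(self):
    # Maps worker name -> object gathered for one all_gather round.
    self.gathered_objects = {}
    # Set when this worker may proceed: on the leader, once every worker
    # in the role group has reported; on a follower, once the leader's
    # broadcast has arrived.
    self.proceed_signal = threading.Event()

_rpc_current_group_worker_names = set()  # filled during RPC initialization
_role_based_all_gather_dict_lock = threading.RLock()
_role_based_all_gather_sequence_id = 0
_role_based_all_gather_sequence_id_to_states = collections.defaultdict(
  _RoleBasedAllGatherStates)

def _role_based_gather_to_leader(sequence_id, worker_name, obj):
  # Runs on the leader (directly, or via rpc_sync from a follower) and
  # records one worker's object; unblocks the leader once all workers in
  # the current role group have reported for this sequence id.
  with _role_based_all_gather_dict_lock:
    states = _role_based_all_gather_sequence_id_to_states[sequence_id]
    states.gathered_objects[worker_name] = obj
    if set(states.gathered_objects.keys()) == _rpc_current_group_worker_names:
      states.proceed_signal.set()

def _role_based_broadcast_to_followers(sequence_id, objects_map):
  # Runs on each follower via rpc_async from the leader: installs the full
  # gathered results locally and unblocks the follower.
  with _role_based_all_gather_dict_lock:
    states = _role_based_all_gather_sequence_id_to_states[sequence_id]
  states.gathered_objects = objects_map
  states.proceed_signal.set()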
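
A minimal usage sketch, assuming this module's RPC context has already been set up on every worker of the role group; `num_local_items` and the timeout value are made up for illustration, and the return value is assumed to be a dict keyed by worker name, as in torch's reference implementation.

# Executed on every worker of one role group; all names below except
# `all_gather` itself are illustrative.
num_local_items = 1234  # some per-worker value, e.g. a local partition size
gathered = all_gather(num_local_items, timeout=60.0)
# Assuming a dict keyed by worker name (as in torch's `_all_gather`), the
# result is identical on every worker once the call returns.
total_items = sum(gathered.values())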