in graphlearn_torch/python/distributed/rpc.py [0:0]
def rpc_register(callee: RpcCalleeBase):
  r""" Register a callee for RPC requests within the current role group only.
  This method blocks until all local and remote RPC processes of the
  current role group reach it.
  """
  global _rpc_callee_id, _rpc_callee_pool
  # Allocate the next callee id and publish the callee under the lock, so
  # concurrent registrations on this process get distinct ids.
  with _rpc_callee_lock:
    callee_id = _rpc_callee_id
    _rpc_callee_id += 1
    if callee_id in _rpc_callee_pool:
      raise RuntimeError(f"'rpc_register': try to register with the "
                         f"callee id {callee_id} twice.")
    _rpc_callee_pool[callee_id] = callee
  # Gather (role, callee_id) from every process in the role group and check
  # that all of them agree; this both synchronizes the group and guarantees
  # the returned id refers to the same callee everywhere.
  current_role = get_context().role
  callee_ids = all_gather((current_role, callee_id))
  for name, (role, cid) in callee_ids.items():
    if role != current_role:
      raise RuntimeError(f"'rpc_register': got inconsistent role '{role}' "
                         f"from {name}, current role is '{current_role}'.")
    if cid != callee_id:
      raise RuntimeError(f"'rpc_register': got inconsistent callee id '{cid}' "
                         f"from {name}, current callee id is '{callee_id}'.")
  return callee_id
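Because the returned id is fixed by registration order across the whole role group, every process must call rpc_register at the same point in its startup sequence. Below is a minimal usage sketch; it assumes RpcCalleeBase exposes a call(*args, **kwargs) hook that the RPC layer dispatches to by callee id (suggested by the _rpc_callee_pool lookup, but the exact interface and import path should be verified against rpc.py).

# Hypothetical usage sketch; import path is an assumption.
from graphlearn_torch.distributed.rpc import RpcCalleeBase, rpc_register

class EchoCallee(RpcCalleeBase):
  def call(self, *args, **kwargs):
    # Assumed dispatch hook: runs on the callee side when a peer
    # sends a request targeting this callee's id.
    return args, kwargs

# Must run on every process of the role group, in the same order relative
# to any other rpc_register() calls; the call blocks until the whole group
# arrives, and the all_gather check raises if the ids or roles diverge.
callee_id = rpc_register(EchoCallee())
# callee_id is now identical on all processes of the role group.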