# rpc_register — from graphlearn_torch/python/distributed/rpc.py


def rpc_register(callee: RpcCalleeBase):
  r""" Register a callee for rpc requests only in the current role group,
  this method will block until all local and remote RPC processes of the
  current role group reach this method.
  """
  global _rpc_callee_id, _rpc_callee_pool
  # Atomically claim the next callee id and install the callee under it.
  with _rpc_callee_lock:
    assigned_id = _rpc_callee_id
    _rpc_callee_id += 1
    if assigned_id in _rpc_callee_pool:
      raise RuntimeError(f"'rpc_register': try to register with the "
                         f"callee id {assigned_id} twice.")
    _rpc_callee_pool[assigned_id] = callee

  # Exchange (role, id) with every peer; all_gather blocks until all
  # processes of the role group have reached this registration point.
  my_role = get_context().role
  gathered = all_gather((my_role, assigned_id))
  # Every peer must report the same role and the same callee id,
  # otherwise registration ordering diverged across processes.
  for peer_name, (peer_role, peer_id) in gathered.items():
    if peer_role != my_role:
      raise RuntimeError(f"'rpc_register': get inconsistent role '{peer_role}' "
                         f"from {peer_name}, current role is '{my_role}'.")
    if peer_id != assigned_id:
      raise RuntimeError(f"'rpc_register': get inconsistent callee id '{peer_id}' "
                         f"from {peer_name}, current callee id is '{assigned_id}'.")

  return assigned_id