def init_rpc()

in graphlearn_torch/python/distributed/rpc.py [0:0]


def init_rpc(master_addr: str,
             master_port: int,
             num_rpc_threads: int = 16,
             rpc_timeout: float = 180,
             is_dynamic: bool = False):
  r""" Initialize rpc on the current process.
  """
  with _rpc_init_lock:
    if rpc_is_initialized() is True:
      return
    if rpc_is_initialized() is None:
      raise RuntimeError("'init_rpc': Try to re-init rpc after shutdown.")

    ctx = get_context()
    if ctx is None:
      raise RuntimeError("'init_rpc': Distributed context has not been set.")

    options = rpc.TensorPipeRpcBackendOptions(
      # _transports=['ibv', 'uv'],
      _transports=['uv'],
      _channels=['mpt_uv', 'basic'],
      num_worker_threads=num_rpc_threads,
      rpc_timeout=rpc_timeout,
      init_method=f'tcp://{master_addr}:{master_port}'
    )
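    # In dynamic mode the RPC world size is passed as None, so participants
    # (servers and clients) can join the RPC group without a fixed membership.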
    
    rpc.init_rpc(
      name=ctx.worker_name,
      rank=ctx.global_rank,
      world_size=None if is_dynamic else ctx.global_world_size,
      rpc_backend_options=options
    )

    global _rpc_inited
    _rpc_inited = True

    global _rpc_current_group_worker_names
    global _rpc_worker_names
    _rpc_worker_names = {}
    
    if is_dynamic:
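      # Dynamic mode: worker names are not known up front, so each participant
      # probes the servers over RPC until they all become reachable.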
      _rpc_worker_names[DistRole.SERVER] = []
      _rpc_worker_names[DistRole.CLIENT] = [] 
     
      if ctx.is_server():
        # ensure all servers are initialized
        for server_rank in range(ctx.world_size):  
          if server_rank == ctx.rank:
            _rpc_worker_names[DistRole.SERVER].append(ctx.group_name + '_' + str(server_rank))
            continue
          times = 0
          is_avail = False
          while not is_avail:
            try:
              is_avail = rpc_global_request_by_rank(server_rank, rpc.is_available)
            except Exception:
              time.sleep(SERVER_INIT_CHECK_INTERVAL)
              logging.info(f"RETRY {times}: server {ctx.rank} is waiting for server {server_rank}...")
            times += 1
            if not is_avail and times >= MAX_RETYR_TIMES:
              raise RuntimeError(f"TIMEOUT: server {ctx.rank} timed out waiting for server {server_rank}. "
                                 f"Check if server {server_rank} is ready.")
          _rpc_worker_names[DistRole.SERVER].append(ctx.group_name + '_' + str(server_rank))
        _rpc_current_group_worker_names = set(_rpc_worker_names[DistRole.SERVER])
        return
      if ctx.is_client():
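        # Clients are ranked after all servers in the global ordering, so
        # ``ctx.global_rank - ctx.rank`` is the number of servers to wait for.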
        for server_rank in range(ctx.global_rank - ctx.rank):
          times = 0
          is_avail = False
          while not is_avail:
            try:
              is_avail = rpc_global_request_by_rank(server_rank, rpc.is_available)
            except Exception:
              time.sleep(SERVER_INIT_CHECK_INTERVAL)
              logging.info(f"RETRY {times}: client {ctx.rank} is waiting for server {server_rank}...")
            times += 1
            if not is_avail and times >= MAX_RETYR_TIMES:
              raise RuntimeError(f"TIMEOUT: client {ctx.rank} timed out waiting for server {server_rank}. "
                                 f"Check if server {server_rank} is ready.")
          server_name = rpc_global_request_by_rank(server_rank, rpc.get_worker_info).name
          _rpc_worker_names[DistRole.SERVER].append(server_name)
        _rpc_current_group_worker_names = set(
          ctx.group_name + '_' + str(client_rank)
          for client_rank in range(ctx.world_size)
        )
        return
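    # Static mode: gather (role, world_size, rank) from every participant and
    # build the worker-name table for each role.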
    
    gathered_results = global_all_gather(
      obj=(ctx.role, ctx.world_size, ctx.rank), timeout=rpc_timeout
    )
    for worker_name, (role, role_size, role_rank) in gathered_results.items():
      worker_list = _rpc_worker_names.get(role, None)
      if worker_list is None:
        worker_list = [None for _ in range(role_size)]
      else:
        if len(worker_list) != role_size:
          raise RuntimeError(f"'init_rpc': world size of role {role} gathered "
                             f"from {worker_name} is inconsistent with others.")
      if worker_list[role_rank] is not None:
        raise RuntimeError(f"'init_rpc': try to set worker name twice with "
                           f"the same rank {role_rank} of role {role}")
      worker_list[role_rank] = worker_name
      _rpc_worker_names[role] = worker_list

    _rpc_current_group_worker_names = set(_rpc_worker_names[ctx.role])

    global_barrier(timeout=rpc_timeout) 
    # TODO(hongyi): in server-client mode, if "torch.distributed.init_process_group" follows "global_barrier", 
    # some participants may randomly hang
    time.sleep(1)
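
A minimal usage sketch follows. The import path is inferred from the file location, and the context-setup step is shown only as a placeholder comment because the context-setting helper is defined elsewhere; treat every name outside init_rpc itself as an assumption rather than the confirmed API.

# Hypothetical usage sketch for init_rpc (not taken from this file).
from graphlearn_torch.distributed import init_rpc  # import path assumed

# The distributed context must already be set for this process, e.g. via the
# package's worker-group / context initialization helper (name assumed here).

# Initialize RPC against the rendezvous endpoint shared by all participants.
init_rpc(
  master_addr='127.0.0.1',
  master_port=29500,
  num_rpc_threads=16,
  rpc_timeout=180,
  is_dynamic=False,
)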