in graphlearn_torch/python/distributed/rpc.py [0:0]
def _wait_until_server_available(server_rank: int, waiter: str) -> None:
    r""" Block until the server at ``server_rank`` answers an rpc probe.

    The probe is ``rpc.is_available`` executed through
    ``rpc_global_request_by_rank``. Every failed attempt (the request raised,
    or it returned a falsy value) is followed by a sleep of
    ``SERVER_INIT_CHECK_INTERVAL`` seconds and counted against
    ``MAX_RETYR_TIMES``.

    Args:
      server_rank: Rank handed to ``rpc_global_request_by_rank`` to address
        the server being waited for.
      waiter: Human readable description of the waiting process (for example
        ``"client 3"``), used only in log and error messages.

    Raises:
      RuntimeError: If the server did not become available within
        ``MAX_RETYR_TIMES`` attempts. The last rpc failure, if any, is
        attached as the cause.
    """
    attempts = 0
    while True:
        try:
            if rpc_global_request_by_rank(server_rank, rpc.is_available):
                return
            failure = None
        except Exception as err:
            # Only ordinary errors are retried, so that KeyboardInterrupt and
            # SystemExit can still abort a process stuck waiting for a server.
            failure = err
        time.sleep(SERVER_INIT_CHECK_INTERVAL)
        logging.info(f"RETRY {attempts}: {waiter} waits server {server_rank}...")
        attempts += 1
        if attempts >= MAX_RETYR_TIMES:
            raise RuntimeError(
                f"TIMEOUT: {waiter} waits server {server_rank} timeout. "
                f"Check if server {server_rank} is ready."
            ) from failure


def init_rpc(master_addr: str,
             master_port: int,
             num_rpc_threads: int = 16,
             rpc_timeout: float = 180,
             is_dynamic: bool = False):
    r""" Initialize rpc on the current process.

    The call is a no-op when rpc is already initialized. Otherwise it starts
    the TensorPipe rpc agent for the worker described by the current
    distributed context and fills the module-level worker name tables.

    Args:
      master_addr: Host used to build the ``tcp://`` init method.
      master_port: Port used to build the ``tcp://`` init method.
      num_rpc_threads: Number of worker threads of the rpc agent.
      rpc_timeout: Timeout passed to the rpc backend options; it is also used
        as the timeout of the all-gather and the barrier performed in
        non-dynamic mode.
      is_dynamic: If True, rpc is initialized without a fixed world size.
        Servers then wait for every other server, and clients wait for every
        server, by polling them instead of using a global all-gather.

    Raises:
      RuntimeError: If rpc was already shut down, if the distributed context
        has not been set, if a polled server does not become available in
        time, or if the gathered role information is inconsistent.
    """
    global _rpc_inited
    global _rpc_current_group_worker_names
    global _rpc_worker_names

    with _rpc_init_lock:
        init_state = rpc_is_initialized()
        if init_state is True:
            return
        if init_state is None:
            raise RuntimeError("'init_rpc': Try to re-init rpc after shutdown.")

        ctx = get_context()
        if ctx is None:
            raise RuntimeError("'init_rpc': Distributed context has not been set.")

        options = rpc.TensorPipeRpcBackendOptions(
            # _transports=['ibv', 'uv'],
            _transports=['uv'],
            _channels=['mpt_uv', 'basic'],
            num_worker_threads=num_rpc_threads,
            rpc_timeout=rpc_timeout,
            init_method=f'tcp://{master_addr}:{master_port}'
        )
        rpc.init_rpc(
            name=ctx.worker_name,
            rank=ctx.global_rank,
            world_size=None if is_dynamic else ctx.global_world_size,
            rpc_backend_options=options
        )
        _rpc_inited = True
        _rpc_worker_names = {}

        if is_dynamic:
            server_names = []
            _rpc_worker_names[DistRole.SERVER] = server_names
            _rpc_worker_names[DistRole.CLIENT] = []

            if ctx.is_server():
                # Ensure that all servers are initialized before going on.
                for server_rank in range(ctx.world_size):
                    if server_rank != ctx.rank:
                        _wait_until_server_available(
                            server_rank, f"server {ctx.rank}"
                        )
                    server_names.append(ctx.group_name + '_' + str(server_rank))
                _rpc_current_group_worker_names = set(server_names)
                return

            if ctx.is_client():
                # Ranks below (global_rank - rank) are polled as servers;
                # presumably servers occupy the lowest global ranks - confirm
                # against the context setup.
                for server_rank in range(ctx.global_rank - ctx.rank):
                    _wait_until_server_available(
                        server_rank, f"client {ctx.rank}"
                    )
                    server_info = rpc_global_request_by_rank(
                        server_rank, rpc.get_worker_info
                    )
                    server_names.append(server_info.name)
                _rpc_current_group_worker_names = {
                    ctx.group_name + '_' + str(client_rank)
                    for client_rank in range(ctx.world_size)
                }
                return

        # Static mode (or a dynamic context that is neither server nor
        # client): collect (role, role world size, role rank) of every worker
        # and build the per-role list of worker names indexed by role rank.
        gathered_results = global_all_gather(
            obj=(ctx.role, ctx.world_size, ctx.rank), timeout=rpc_timeout
        )
        for worker_name, (role, role_size, role_rank) in gathered_results.items():
            worker_list = _rpc_worker_names.get(role, None)
            if worker_list is None:
                worker_list = [None for _ in range(role_size)]
            elif len(worker_list) != role_size:
                raise RuntimeError(f"'init_rpc': world size of role {role} gathered "
                                   f"from {worker_name} is inconsistent with others.")
            if worker_list[role_rank] is not None:
                raise RuntimeError(f"'init_rpc': try to set worker name twice with "
                                   f"the same rank {role_rank} of role {role}")
            worker_list[role_rank] = worker_name
            _rpc_worker_names[role] = worker_list

        _rpc_current_group_worker_names = set(_rpc_worker_names[ctx.role])

        global_barrier(timeout=rpc_timeout)
        # TODO(hongyi): in server-client mode, if
        # "torch.distributed.init_process_group" follows "global_barrier",
        # some participants may randomly hang; the sleep works around that.
        time.sleep(1)