graphlearn_torch/python/distributed/dist_client.py (44 lines of code) (raw):

# Copyright 2022 Alibaba Group Holding Limited. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== import logging from typing import Optional from .dist_context import DistRole, get_context, _set_client_context from .dist_server import DistServer, _call_func_on_server from .rpc import init_rpc, shutdown_rpc, rpc_global_request_async, barrier def init_client(num_servers: int, num_clients: int, client_rank: int, master_addr: str, master_port: int, num_rpc_threads: int = 4, client_group_name: Optional[str] = None, is_dynamic: bool = False): r""" Initialize the current process as a client and establish connections with all other servers and clients. Note that this method should be called only in the server-client distribution mode. Args: num_servers (int): Number of processes participating in the server group. num_clients (int): Number of processes participating in the client group. client_rank (int): Rank of the current process withing the client group (it should be a number between 0 and ``num_clients``-1). master_addr (str): The master TCP address for RPC connection between all servers and clients, the value of this parameter should be same for all servers and clients. master_port (int): The master TCP port for RPC connection between all servers and clients, the value of this parameter should be same for all servers and clients. num_rpc_threads (int): The number of RPC worker threads used for the current client. (Default: ``4``). client_group_name (str): A unique name of the client group that current process belongs to. If set to ``None``, a default name will be used. (Default: ``None``). is_dynamic (bool): Whether the world size is dynamic. (Default: ``False``). """ if client_group_name: client_group_name = client_group_name.replace('-', '_') _set_client_context(num_servers, num_clients, client_rank, client_group_name) # Note that a client RPC agent will never remote requests, thus set the # number of rpc threads to ``1`` is enough. init_rpc(master_addr, master_port, num_rpc_threads=num_rpc_threads, is_dynamic=is_dynamic) def shutdown_client(): r""" Shutdown the client on the current process, notify other servers to exit, and destroy all connections. """ current_context = get_context() if current_context is None: logging.warning("'shutdown_client': try to shutdown client when the " "current process has not been initialized as a client.") return if not current_context.is_client(): raise RuntimeError(f"'shutdown_client': role type of the current process " f"context is not a client, got {current_context.role}.") # step 1: synchronize with all other clients. barrier() # step 2: use client-0 to notify all servers to exit after all clients # have reached here. current_context = get_context() if current_context.rank == 0: for server_rank in range(current_context.num_servers()): exit_status = request_server(server_rank, DistServer.exit) assert exit_status is True, f"Failed to exit server {server_rank}" # step 3: shutdown rpc across all servers and clients. shutdown_rpc() def async_request_server(server_rank: int, func, *args, **kwargs): r""" The entry to perform an asynchronous request on a remote server, calling on the client side. """ args = [func] + list(args) return rpc_global_request_async( target_role=DistRole.SERVER, role_rank=server_rank, func=_call_func_on_server, args=args, kwargs=kwargs, ) def request_server(server_rank: int, func, *args, **kwargs): r""" The entry to perform a synchronous request on a remote server, calling on the client side. """ fut = async_request_server(server_rank, func, *args, **kwargs) return fut.wait()