graphlearn_torch/python/distributed/dist_context.py (141 lines of code) (raw):

# Copyright 2022 Alibaba Group Holding Limited. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== from enum import Enum from typing import Optional, List class DistRole(Enum): r""" Role types for distributed context groups. """ WORKER = 1 # As a worker in a distributed worker group (non-server mode) SERVER = 2 # As a server in a distributed server group (server-client mode) CLIENT = 3 # As a client in a distributed client group (server-client mode) _DEFAULT_WORKER_GROUP = '_default_worker' _DEFAULT_SERVER_GROUP = '_default_server' _DEFAULT_CLIENT_GROUP = '_default_client' class DistContext(object): r""" Distributed context info of the current process. Args: role (DistRole): The role type of the current context group. group_name (str): A unique name of the current role group. world_size (int): The number of processes in the current role group. rank (int): The current process rank within the current role group. global_world_size (int): The total number of processes in all role groups. global_rank (int): The current process rank within all role groups. """ def __init__(self, role: DistRole, group_name: str, world_size: int, rank: int, global_world_size: int, global_rank: int): assert world_size > 0 and rank in range(world_size) assert global_world_size > 0 and global_rank in range(global_world_size) assert world_size <= global_world_size self.role = role self.group_name = group_name self.world_size = world_size self.rank = rank self.global_world_size = global_world_size self.global_rank = global_rank def __repr__(self) -> str: cls = self.__class__.__name__ info = [] for key, value in self.__dict__.items(): info.append(f"{key}: {value}") info = ", ".join(info) return f"{cls}({info})" def __eq__(self, obj): if not isinstance(obj, DistContext): return False for key, value in self.__dict__.items(): if value != obj.__dict__[key]: return False return True def is_worker(self) -> bool: return self.role == DistRole.WORKER def is_server(self) -> bool: return self.role == DistRole.SERVER def is_client(self) -> bool: return self.role == DistRole.CLIENT def num_servers(self) -> int: if self.role == DistRole.SERVER: return self.world_size if self.role == DistRole.CLIENT: return self.global_world_size - self.world_size return 0 def num_clients(self) -> int: if self.role == DistRole.CLIENT: return self.world_size if self.role == DistRole.SERVER: return self.global_world_size - self.world_size return 0 @property def worker_name(self) -> str: r""" Get worker name of the current process of this context. """ return f"{self.group_name}_{self.rank}" _dist_context: DistContext = None r""" Distributed context on the current process. """ _clients_to_servers: dict = None r""" A dict mapping from client rank to server ranks. int -> List[int]""" def get_context() -> DistContext: r""" Get distributed context info of the current process. """ return _dist_context def get_clients_to_servers() -> dict: r""" Get client to servers mapping. """ return _clients_to_servers def _set_worker_context(world_size: int, rank: int, group_name: Optional[str] = None): r""" Set distributed context info as a non-server worker on the current process. """ global _dist_context _dist_context = DistContext( role=DistRole.WORKER, group_name=(group_name if group_name is not None else _DEFAULT_WORKER_GROUP), world_size=world_size, rank=rank, global_world_size=world_size, global_rank=rank ) def _set_server_context(num_servers: int, server_rank: int, server_group_name: Optional[str] = None, num_clients: int = 0): r""" Set distributed context info as a server on the current process. """ assert num_servers > 0 global _dist_context _dist_context = DistContext( role=DistRole.SERVER, group_name=(server_group_name if server_group_name is not None else _DEFAULT_SERVER_GROUP), world_size=num_servers, rank=server_rank, global_world_size=num_servers+num_clients, global_rank=server_rank ) def _set_client_context(num_servers: int, num_clients: int, client_rank: int, client_group_name: Optional[str] = None): r""" Set distributed context info as a client on the current process. """ assert num_servers > 0 and num_clients > 0 global _dist_context _dist_context = DistContext( role=DistRole.CLIENT, group_name=(client_group_name if client_group_name is not None else _DEFAULT_CLIENT_GROUP), world_size=num_clients, rank=client_rank, global_world_size=num_servers+num_clients, global_rank=num_servers+client_rank ) assign_server_by_order() def assign_server_by_order(): r"""Assign servers to each client in turn. e.g. 2 clients and 4 servers, then the assignment is: {0: [0, 1], 1: [2, 3]}, 5 clients and 2 servers, then the assignment is: {0: [0], 1: [1], 2: [0], 3: [1], 4: [0]}.""" ctx = get_context() assert ctx is not None and ctx.is_client() client_num, server_num = ctx.world_size, ctx.global_world_size - ctx.world_size global _clients_to_servers _clients_to_servers = {} cur_server = 0 for i in range(client_num): if i not in _clients_to_servers: _clients_to_servers[i] = [] for j in range(server_num // client_num): _clients_to_servers[i].append(cur_server) cur_server = (cur_server + 1) % server_num if i < server_num % client_num: _clients_to_servers[i].append(cur_server) cur_server = (cur_server + 1) % server_num if len(_clients_to_servers[i]) == 0: _clients_to_servers[i].append(cur_server) cur_server = (cur_server + 1) % server_num return _clients_to_servers[ctx.rank] def init_worker_group(world_size: int, rank: int, group_name: Optional[str] = None): r""" Initialize a simple worker group on the current process, this method should be called only in a non-server distribution mode with a group of parallel workers. Args: world_size (int): Number of all processes participating in the distributed worker group. rank (int): Rank of the current process withing the distributed group (it should be a number between 0 and ``world_size``-1). group_name (str): A unique name of the distributed group that current process belongs to. If set to ``None``, a default name will be used. """ _set_worker_context(world_size, rank, group_name)