# graphlearn_torch/python/distributed/dist_options.py
# Copyright 2022 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os
from typing import List, Optional, Union, Literal
import torch
from ..utils import assign_device
from .dist_context import DistContext, assign_server_by_order


class _BasicDistSamplingWorkerOptions(object):
r""" Basic options to launch distributed sampling workers.

  Args:
    num_workers (int): How many workers to use for distributed neighbor
      sampling of the current process; must be the same for each process of
      the current context group. (default: ``1``).
    worker_devices (torch.device or List[torch.device], optional): List of
      devices assigned to the workers of this group. If set to ``None``, the
      devices to use will be assigned automatically (cuda devices are
      preferred if available). (default: ``None``).
    worker_concurrency (int): The max sampling concurrency over different
      seed batches for each sampling worker, which should not exceed 32.
      (default: ``1``).
    master_addr (str, optional): Master address for rpc initialization across
      all sampling workers. If set to ``None``, the environment variable
      ``MASTER_ADDR`` will be used. (default: ``None``).
    master_port (str or int, optional): Master port for rpc initialization
      across all sampling workers. If set to ``None``, the value of the
      environment variable ``MASTER_PORT`` increased by one will be used as
      the actual rpc port for sampling workers, in order to avoid conflicts
      with the master port already used by other modules (e.g., the
      ``init_process_group`` method of ``torch.distributed``). Otherwise, the
      provided port must be guaranteed to avoid such conflicts.
      (default: ``None``).
    num_rpc_threads (int, optional): Number of threads used by the rpc agent
      on each sampling worker. If set to ``None``, the number of rpc threads
      will be chosen according to the actual workload, but will not exceed
      16. (default: ``None``).
    rpc_timeout (float): The timeout in seconds for all rpc requests during
      distributed sampling and feature collection. (default: ``180``).
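
  Example (a sketch of the fallback behavior when ``master_addr`` and
  ``master_port`` are left unset; the address and port are illustrative
  values only):

    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '11111'
    # The rpc port resolves to 11112 (``MASTER_PORT`` + 1), avoiding a clash
    # with the port already used by ``torch.distributed.init_process_group``.
    options = CollocatedDistSamplingWorkerOptions()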
"""

  def __init__(self,
num_workers: int = 1,
worker_devices: Optional[List[torch.device]] = None,
worker_concurrency: int = 1,
master_addr: Optional[str] = None,
master_port: Optional[Union[str, int]] = None,
num_rpc_threads: Optional[int] = None,
rpc_timeout: float = 180):
self.num_workers = num_workers
    # Not known yet; these will be computed later by '_set_worker_ranks()'.
self.worker_world_size = None
self.worker_ranks = None
if worker_devices is None:
self.worker_devices = None
    elif isinstance(worker_devices, (list, tuple)):
assert len(worker_devices) == self.num_workers
self.worker_devices = list(worker_devices)
else:
self.worker_devices = [worker_devices] * self.num_workers
# Worker concurrency should not exceed 32.
self.worker_concurrency = max(worker_concurrency, 1)
self.worker_concurrency = min(self.worker_concurrency, 32)
if master_addr is not None:
self.master_addr = str(master_addr)
elif os.environ.get('MASTER_ADDR') is not None:
self.master_addr = os.environ['MASTER_ADDR']
else:
raise ValueError(f"'{self.__class__.__name__}': missing master address "
"for rpc communication, try to provide it or set it "
"with environment variable 'MASTER_ADDR'")
if master_port is not None:
self.master_port = int(master_port)
elif os.environ.get('MASTER_PORT') is not None:
self.master_port = int(os.environ['MASTER_PORT']) + 1
else:
raise ValueError(f"'{self.__class__.__name__}': missing master port "
"for rpc communication, try to provide it or set it "
"with environment variable 'MASTER_ADDR'")
self.num_rpc_threads = num_rpc_threads
if self.num_rpc_threads is not None:
assert self.num_rpc_threads > 0
self.rpc_timeout = rpc_timeout

  def _set_worker_ranks(self, current_ctx: DistContext):
self.worker_world_size = current_ctx.world_size * self.num_workers
self.worker_ranks = [
current_ctx.rank * self.num_workers + i
for i in range(self.num_workers)
]

  def _assign_worker_devices(self):
if self.worker_devices is not None:
return
self.worker_devices = [assign_device() for _ in range(self.num_workers)]


class CollocatedDistSamplingWorkerOptions(_BasicDistSamplingWorkerOptions):
r""" Options for launching a single distributed sampling worker collocated
with the current process.

  Args:
    master_addr (str, optional): Master address for rpc initialization across
      all sampling workers. (default: ``None``).
    master_port (str or int, optional): Master port for rpc initialization
      across all sampling workers. (default: ``None``).
    num_rpc_threads (int, optional): Number of threads used by the rpc agent
      on each sampling worker. (default: ``None``).
    rpc_timeout (float): The timeout in seconds for rpc requests.
      (default: ``180``).
    use_all2all (bool): Whether to use all2all to collect distributed
      node/edge features instead of p2p rpc. (default: ``False``).

  Please refer to ``_BasicDistSamplingWorkerOptions`` for more detailed
  comments on the related input arguments.
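
  Example (a minimal construction sketch; the address and port are
  illustrative values only, and handing the options to a distributed loader
  is an assumption about downstream usage, not part of this module):

    options = CollocatedDistSamplingWorkerOptions(
      master_addr='localhost',
      master_port=11234,
      num_rpc_threads=4,
      rpc_timeout=120,
    )
    # e.g. (assumed loader API): loader = DistNeighborLoader(
    #   ..., worker_options=options)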
"""

  def __init__(self,
master_addr: Optional[str] = None,
master_port: Optional[Union[str, int]] = None,
num_rpc_threads: Optional[int] = None,
rpc_timeout: float = 180,
use_all2all: bool = False):
super().__init__(1, None, 1, master_addr, master_port,
num_rpc_threads, rpc_timeout)
self.use_all2all = use_all2all


class MpDistSamplingWorkerOptions(_BasicDistSamplingWorkerOptions):
r""" Options for launching distributed sampling workers with multiprocessing.
  Note that if ``MpDistSamplingWorkerOptions`` is used, all sampling workers
  will be launched in subprocesses spawned by ``torch.multiprocessing``.
  Thus, a shared-memory based channel will be created for passing the sampled
  results, which are produced by those sampling subprocesses and consumed by
  the current process.

  Args:
    num_workers (int): How many workers (i.e. subprocesses to spawn) to use
      for distributed neighbor sampling of the current process.
      (default: ``1``).
    worker_devices (torch.device or List[torch.device], optional): List of
      devices assigned to the workers of this group. (default: ``None``).
    worker_concurrency (int): The max sampling concurrency for each sampling
      worker. (default: ``4``).
    master_addr (str, optional): Master address for rpc initialization across
      all sampling workers. (default: ``None``).
    master_port (str or int, optional): Master port for rpc initialization
      across all sampling workers. (default: ``None``).
    num_rpc_threads (int, optional): Number of threads used by the rpc agent
      on each sampling worker. (default: ``None``).
    rpc_timeout (float): The timeout in seconds for rpc requests.
      (default: ``180``).
    channel_size (int or str): The shared-memory buffer size (in bytes)
      allocated for the channel. A size of ``num_workers * 64MB`` will be
      used if set to ``None``. (default: ``None``).
    pin_memory (bool): Set to ``True`` to register the underlying shared
      memory with cuda, which achieves better performance if the loaded data
      is copied from the channel to a cuda device. (default: ``False``).
    use_all2all (bool): Whether to use all2all to collect distributed
      node/edge features instead of p2p rpc. (default: ``False``).

  Please refer to ``_BasicDistSamplingWorkerOptions`` for more detailed
  comments on the related input arguments.
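
  Example (a minimal construction sketch; the devices, port and channel size
  are illustrative values only, and handing the options to a distributed
  loader is an assumption about downstream usage, not part of this module):

    options = MpDistSamplingWorkerOptions(
      num_workers=2,
      worker_devices=[torch.device('cuda', 0), torch.device('cuda', 1)],
      worker_concurrency=4,
      master_addr='localhost',
      master_port=11234,
      channel_size='1GB',
      pin_memory=True,
    )
    # e.g. (assumed loader API): loader = DistNeighborLoader(
    #   ..., worker_options=options)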
"""

  def __init__(self,
num_workers: int = 1,
worker_devices: Optional[List[torch.device]] = None,
worker_concurrency: int = 4,
master_addr: Optional[str] = None,
master_port: Optional[Union[str, int]] = None,
num_rpc_threads: Optional[int] = None,
rpc_timeout: float = 180,
channel_size: Optional[Union[int, str]] = None,
pin_memory: bool = False,
use_all2all: bool = False):
super().__init__(num_workers, worker_devices, worker_concurrency,
master_addr, master_port, num_rpc_threads, rpc_timeout)
self.channel_capacity = self.num_workers * self.worker_concurrency
if channel_size is None:
self.channel_size = f'{self.num_workers * 64}MB'
else:
self.channel_size = channel_size
self.pin_memory = pin_memory
self.use_all2all = use_all2all


class RemoteDistSamplingWorkerOptions(_BasicDistSamplingWorkerOptions):
r""" Options for launching distributed sampling workers on remote servers.
Note that if ``RemoteDistSamplingWorkerOptions`` is used, all sampling
workers will be launched on remote servers. Thus, a cross-machine based
channel will be created for message passing of sampled results, which are
produced by those remote sampling workers and consumed by the current process.

  Args:
    server_rank (int or List[int], optional): The rank(s) of the server(s) on
      which to launch sampling workers; multiple ranks may be given. If set
      to ``None``, the server rank will be assigned automatically.
      (default: ``None``).
    num_workers (int): How many workers to launch on the remote server for
      distributed neighbor sampling of the current process. (default: ``1``).
    worker_devices (torch.device or List[torch.device], optional): List of
      devices assigned to the workers of this group. (default: ``None``).
    worker_concurrency (int): The max sampling concurrency for each sampling
      worker. (default: ``4``).
    master_addr (str, optional): Master address for rpc initialization across
      all sampling workers. (default: ``None``).
    master_port (str or int, optional): Master port for rpc initialization
      across all sampling workers. (default: ``None``).
    num_rpc_threads (int, optional): Number of threads used by the rpc agent
      on each sampling worker. (default: ``None``).
    rpc_timeout (float): The timeout in seconds for rpc requests.
      (default: ``180``).
    buffer_size (int or str): The size (in bytes) allocated for the
      server-side buffer. A size of ``num_workers * 64MB`` will be used if
      set to ``None``. (default: ``None``).
    prefetch_size (int): The max number of prefetched sampled messages for
      consumption on the client side. (default: ``4``).
    worker_key (str, optional): A string key identifying this group of remote
      sampling workers. If ``glt_graph`` is provided, it will be derived from
      the selected master port. (default: ``None``).
    glt_graph: The graph object on the GraphScope side, used to retrieve rpc
      parameters such as the master address and loader ports.
      (default: ``None``).
    workload_type: Indicates the type of workload (``'train'``,
      ``'validate'`` or ``'test'``) on the GraphScope side, and selects the
      corresponding loader master port. This field must be set when
      ``glt_graph`` is not ``None``. (default: ``None``).
    use_all2all (bool): Whether to use all2all to collect distributed
      node/edge features instead of p2p rpc. (default: ``False``).
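
  Example (a minimal construction sketch on the client side; the server
  ranks, port and buffer size are illustrative values only, and handing the
  options to a distributed loader is an assumption about downstream usage,
  not part of this module):

    options = RemoteDistSamplingWorkerOptions(
      server_rank=[0, 1],
      num_workers=2,
      worker_concurrency=4,
      master_addr='localhost',
      master_port=11234,
      buffer_size='256MB',
      prefetch_size=4,
      worker_key='train',
    )
    # e.g. (assumed loader API): loader = DistNeighborLoader(
    #   ..., worker_options=options)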
"""

  def __init__(self,
server_rank: Optional[Union[int, List[int]]] = None,
num_workers: int = 1,
worker_devices: Optional[List[torch.device]] = None,
worker_concurrency: int = 4,
master_addr: Optional[str] = None,
master_port: Optional[Union[str, int]] = None,
num_rpc_threads: Optional[int] = None,
rpc_timeout: float = 180,
buffer_size: Optional[Union[int, str]] = None,
prefetch_size: int = 4,
               worker_key: Optional[str] = None,
glt_graph = None,
workload_type: Optional[Literal['train', 'validate', 'test']] = None,
use_all2all: bool = False):
# glt_graph is used in GraphScope side to get parameters
if glt_graph:
if not workload_type:
raise ValueError(f"'{self.__class__.__name__}': missing workload_type ")
master_addr = glt_graph.master_addr
if workload_type == 'train':
master_port = glt_graph.train_loader_master_port
elif workload_type == 'validate':
master_port = glt_graph.val_loader_master_port
elif workload_type == 'test':
master_port = glt_graph.test_loader_master_port
worker_key = str(master_port)
super().__init__(num_workers, worker_devices, worker_concurrency,
master_addr, master_port, num_rpc_threads, rpc_timeout)
if server_rank is not None:
self.server_rank = server_rank
else:
self.server_rank = assign_server_by_order()
self.buffer_capacity = self.num_workers * self.worker_concurrency
if buffer_size is None:
self.buffer_size = f'{self.num_workers * 64}MB'
else:
self.buffer_size = buffer_size
self.prefetch_size = prefetch_size
if self.prefetch_size > self.buffer_capacity:
raise ValueError(f"'{self.__class__.__name__}': the prefetch count "
f"{self.prefetch_size} exceeds the buffer capacity "
f"{self.buffer_capacity}")
self.worker_key = worker_key
self.use_all2all = use_all2all


AllDistSamplingWorkerOptions = Union[
CollocatedDistSamplingWorkerOptions,
MpDistSamplingWorkerOptions,
RemoteDistSamplingWorkerOptions
]