# graphlearn_torch/python/distributed/dist_server.py
# Copyright 2022 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import logging
import time
import threading
from typing import Dict, Optional, Union
import warnings
import torch
from ..partition import PartitionBook
from ..channel import ShmChannel, QueueTimeoutError
from ..sampler import NodeSamplerInput, EdgeSamplerInput, SamplingConfig, RemoteSamplerInput
from .dist_context import get_context, _set_server_context
from .dist_dataset import DistDataset
from .dist_options import RemoteDistSamplingWorkerOptions
from .dist_sampling_producer import DistMpSamplingProducer
from .rpc import barrier, init_rpc, shutdown_rpc
SERVER_EXIT_STATUS_CHECK_INTERVAL = 5.0
r""" Interval (in seconds) to check exit status of server.
"""
class DistServer(object):
r""" A server that supports launching remote sampling workers for
training clients.
  Note that this server is enabled only in the server-client distribution
  mode, where the graph and feature store are partitioned and managed
  across all server nodes.
Args:
    dataset (DistDataset): The ``DistDataset`` object of a partition of graph
      data and feature data, along with distributed partition books.
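
  Example (an illustrative sketch; ``get_server`` returns the instance
  created by ``init_server`` below)::

    >>> server = get_server()
    >>> num_partitions, partition_idx, node_types, edge_types = (
    ...     server.get_dataset_meta())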
"""
def __init__(self, dataset: DistDataset):
self.dataset = dataset
self._lock = threading.RLock()
self._exit = False
    # Auto-incrementing producer index (equals the number of producers created).
    self._cur_producer_idx = 0
# The mapping from the key in worker options (such as 'train', 'test')
# to producer id
self._worker_key2producer_id: Dict[str, int] = {}
self._producer_pool: Dict[int, DistMpSamplingProducer] = {}
self._msg_buffer_pool: Dict[int, ShmChannel] = {}
self._epoch: Dict[int, int] = {} # last epoch for the producer
def shutdown(self):
for producer_id in list(self._producer_pool.keys()):
self.destroy_sampling_producer(producer_id)
assert len(self._producer_pool) == 0
assert len(self._msg_buffer_pool) == 0
def wait_for_exit(self):
r""" Block until the exit flag been set to ``True``.
"""
while not self._exit:
time.sleep(SERVER_EXIT_STATUS_CHECK_INTERVAL)
def exit(self):
r""" Set the exit flag to ``True``.
"""
self._exit = True
return self._exit
def get_dataset_meta(self):
r""" Get the meta info of the distributed dataset managed by the current
server, including partition info and graph types.
"""
return self.dataset.num_partitions, self.dataset.partition_idx, \
self.dataset.get_node_types(), self.dataset.get_edge_types()
  def get_node_partition_id(self, node_type, index):
    if isinstance(self.dataset.node_pb, PartitionBook):
      return self.dataset.node_pb[index]
    if isinstance(self.dataset.node_pb, dict):
      return self.dataset.node_pb[node_type][index]
    return None
def get_node_feature(self, node_type, index):
feature = self.dataset.get_node_feature(node_type)
return feature[index].cpu()
def get_tensor_size(self, node_type):
feature = self.dataset.get_node_feature(node_type)
return feature.shape
def get_node_label(self, node_type, index):
label = self.dataset.get_node_label(node_type)
return label[index]
  def get_edge_index(self, edge_type, layout):
    graph = self.dataset.get_graph(edge_type)
    if layout == 'coo':
      row, col, _, _ = graph.topo.to_coo()
      return row, col
    raise ValueError(f"Invalid layout {layout}")
  def get_edge_size(self, edge_type, layout):
    graph = self.dataset.get_graph(edge_type)
    if layout == 'coo':
      return (graph.row_count, graph.col_count)
    raise ValueError(f"Invalid layout {layout}")
def create_sampling_producer(
self,
sampler_input: Union[NodeSamplerInput, EdgeSamplerInput, RemoteSamplerInput],
sampling_config: SamplingConfig,
worker_options: RemoteDistSamplingWorkerOptions,
) -> int:
r""" Create and initialize an instance of ``DistSamplingProducer`` with
a group of subprocesses for distributed sampling.
Args:
      sampler_input (NodeSamplerInput, EdgeSamplerInput or RemoteSamplerInput):
        The input data for sampling.
sampling_config (SamplingConfig): Configuration of sampling meta info.
worker_options (RemoteDistSamplingWorkerOptions): Options for launching
remote sampling workers by this server.
Returns:
      The unique id of the sampling producer created on this server.
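
    Example (an illustrative sketch; construction of ``sampler_input``,
    ``sampling_config`` and ``worker_options`` is elided)::

      >>> pid = server.create_sampling_producer(
      ...     sampler_input, sampling_config, worker_options)
      >>> server.start_new_epoch_sampling(pid, epoch=0)
      >>> msg, is_end = server.fetch_one_sampled_message(pid)
      >>> server.destroy_sampling_producer(pid)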
"""
if isinstance(sampler_input, RemoteSamplerInput):
sampler_input = sampler_input.to_local_sampler_input(dataset=self.dataset)
with self._lock:
producer_id = self._worker_key2producer_id.get(worker_options.worker_key)
if producer_id is None:
producer_id = self._cur_producer_idx
self._worker_key2producer_id[worker_options.worker_key] = producer_id
self._cur_producer_idx += 1
buffer = ShmChannel(
worker_options.buffer_capacity, worker_options.buffer_size
)
producer = DistMpSamplingProducer(
self.dataset, sampler_input, sampling_config, worker_options, buffer
)
producer.init()
self._producer_pool[producer_id] = producer
self._msg_buffer_pool[producer_id] = buffer
self._epoch[producer_id] = -1
return producer_id
def destroy_sampling_producer(self, producer_id: int):
r""" Shutdown and destroy a sampling producer managed by this server with
its producer id.
"""
with self._lock:
producer = self._producer_pool.get(producer_id, None)
if producer is not None:
producer.shutdown()
self._producer_pool.pop(producer_id)
self._msg_buffer_pool.pop(producer_id)
self._epoch.pop(producer_id)
def start_new_epoch_sampling(self, producer_id: int, epoch: int):
r""" Start a new epoch sampling tasks for a specific sampling producer
with its producer id.
"""
with self._lock:
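      # Clients sharing this producer may each request the same epoch; only
      # a request that advances the epoch triggers a new round of sampling,
      # so produce_all() runs at most once per epoch.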
cur_epoch = self._epoch[producer_id]
if cur_epoch < epoch:
self._epoch[producer_id] = epoch
producer = self._producer_pool.get(producer_id, None)
if producer is not None:
producer.produce_all()
def fetch_one_sampled_message(self, producer_id: int):
r""" Fetch a sampled message from the buffer of a specific sampling
producer with its producer id.
"""
producer = self._producer_pool.get(producer_id, None)
if producer is None:
warnings.warn('invalid producer_id {producer_id}')
return None, False
if producer.is_all_sampling_completed_and_consumed():
return None, True
buffer = self._msg_buffer_pool.get(producer_id, None)
while True:
try:
msg = buffer.recv(timeout_ms=500)
return msg, False
except QueueTimeoutError as e:
if producer.is_all_sampling_completed():
return None, True
_dist_server: Optional[DistServer] = None
r""" ``DistServer`` instance of the current process.
"""
def get_server() -> DistServer:
r""" Get the ``DistServer`` instance on the current process.
"""
return _dist_server
def init_server(num_servers: int, server_rank: int, dataset: DistDataset,
master_addr: str, master_port: int, num_clients: int = 0,
num_rpc_threads: int = 16, request_timeout: int = 180,
server_group_name: Optional[str] = None, is_dynamic: bool = False):
r""" Initialize the current process as a server and establish connections
with all other servers and clients. Note that this method should be called
only in the server-client distribution mode.
Args:
num_servers (int): Number of processes participating in the server group.
    server_rank (int): Rank of the current process within the server group (it
      should be a number between 0 and ``num_servers``-1).
    dataset (DistDataset): The ``DistDataset`` object of a partition of graph
      data and feature data, along with distributed partition book info.
    master_addr (str): The master TCP address for RPC connections between all
      servers and clients; this value must be the same for all servers and
      clients.
    master_port (int): The master TCP port for RPC connections between all
      servers and clients; this value must be the same for all servers and
      clients.
    num_clients (int): Number of processes participating in the client group.
      If ``is_dynamic`` is ``True``, this parameter will be ignored.
    num_rpc_threads (int): The number of RPC worker threads used by the
      current server to respond to remote requests. (Default: ``16``).
    request_timeout (int): The maximum number of seconds to wait for a remote
      request before an exception is raised. (Default: ``180``).
server_group_name (str): A unique name of the server group that current
process belongs to. If set to ``None``, a default name will be used.
(Default: ``None``).
is_dynamic (bool): Whether the world size is dynamic. (Default: ``False``).
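
  Example (a minimal sketch; the dataset, address and port are placeholders)::

    >>> dataset = ...  # a partitioned ``DistDataset``
    >>> init_server(num_servers=2, server_rank=0, dataset=dataset,
    ...             master_addr='localhost', master_port=11234,
    ...             num_clients=2)
    >>> wait_and_shutdown_server()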
"""
if server_group_name:
server_group_name = server_group_name.replace('-', '_')
_set_server_context(num_servers, server_rank, server_group_name, num_clients)
global _dist_server
_dist_server = DistServer(dataset=dataset)
init_rpc(master_addr, master_port, num_rpc_threads, request_timeout, is_dynamic=is_dynamic)
def wait_and_shutdown_server():
r""" Block until all client have been shutdowned, and further shutdown the
server on the current process and destroy all RPC connections.
"""
current_context = get_context()
if current_context is None:
logging.warning("'wait_and_shutdown_server': try to shutdown server when "
"the current process has not been initialized as a server.")
return
if not current_context.is_server():
raise RuntimeError(f"'wait_and_shutdown_server': role type of "
f"the current process context is not a server, "
f"got {current_context.role}.")
global _dist_server
_dist_server.wait_for_exit()
_dist_server.shutdown()
_dist_server = None
barrier()
shutdown_rpc()
def _call_func_on_server(func, *args, **kwargs):
r""" A callee entry for remote requests on the server side.
"""
if not callable(func):
logging.warning(f"'_call_func_on_server': receive a non-callable "
f"function target {func}")
return None
server = get_server()
if hasattr(server, func.__name__):
return func(server, *args, **kwargs)
return func(*args, **kwargs)