graphlearn_torch/python/distributed/dist_loader.py (328 lines of code) (raw):

# Copyright 2022 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from typing import List, Optional, Union
import concurrent
# `import concurrent` alone does not load the `futures` submodule; import it
# explicitly so `concurrent.futures.ThreadPoolExecutor` is always resolvable.
import concurrent.futures

import torch
from torch_geometric.data import Data, HeteroData

from ..channel import SampleMessage, ShmChannel, RemoteReceivingChannel
from ..loader import to_data, to_hetero_data
from ..sampler import (
  NodeSamplerInput, EdgeSamplerInput, RemoteSamplerInput,
  SamplerOutput, HeteroSamplerOutput,
  SamplingConfig, SamplingType
)
from ..typing import (NodeType, EdgeType, as_str, reverse_edge_type)
from ..utils import get_available_device, ensure_device, python_exit_status

from .dist_client import request_server
from .dist_context import get_context
from .dist_dataset import DistDataset
from .dist_options import (
  CollocatedDistSamplingWorkerOptions,
  MpDistSamplingWorkerOptions,
  RemoteDistSamplingWorkerOptions,
  AllDistSamplingWorkerOptions,
)
from .dist_sampling_producer import (
  DistMpSamplingProducer, DistCollocatedSamplingProducer
)
from .dist_server import DistServer
from .rpc import rpc_is_initialized


class DistLoader(object):
  r""" A generic data loader base that performs distributed sampling, which
  allows mini-batch training of GNNs on large-scale graphs when full-batch
  training is not feasible.

  This loader supports launching a collocated sampling worker on the current
  process, or launching separate sampling workers on the spawned subprocesses
  or remote server nodes. When using the separate sampling mode, a worker group
  including the information of separate sampling workers should be provided.

  Note that the separate sampling mode supports asynchronous and concurrent
  sampling on each separate worker, which will achieve better performance
  and is recommended to use. If you want to use a collocated sampling worker,
  all sampling for each seed batch will be blocking and synchronous.

  When launching a collocated sampling worker or some multiprocessing sampling
  workers (on spawned subprocesses), the distribution mode must be non-server
  and only contains a group of parallel worker processes, this means that the
  graph and feature store should be partitioned among all those parallel worker
  processes and managed by them, sampling and training tasks will run on each
  worker process at the same time.

  Otherwise, when launching some remote sampling workers, the distribution mode
  must be a server-client framework, which contains a group of server workers
  and a group of client workers, the graph and feature store will be partitioned
  and managed by all server workers. All client workers are responsible for
  training tasks and launch some workers on remote servers to perform sampling
  tasks, the sampled results will be consumed by client workers with a remote
  message channel.

  Args:
    data (DistDataset, optional): The ``DistDataset`` object of a partition of
      graph data and feature data, along with distributed partition books. The
      input dataset must be provided in non-server distribution mode.
    input_data (NodeSamplerInput or EdgeSamplerInput or RemoteSamplerInput):
      The input data for which neighbors or subgraphs are sampled to create
      mini-batches. In heterogeneous graphs, needs to be passed as a tuple that
      holds the node type and node indices. In remote mode a list of inputs
      may be given; it is paired element-wise with the server ranks of
      ``worker_options``.
    sampling_config (SamplingConfig): The Configuration info for sampling.
    to_device (torch.device, optional): The target device that the sampled
      results should be copied to. If set to ``None``, the current cuda device
      (got by ``torch.cuda.current_device``) will be used if available,
      otherwise, the cpu device will be used. (default: ``None``).
    worker_options (optional): The options for launching sampling workers.
      (1) If set to ``None`` or provided with a
      ``CollocatedDistSamplingWorkerOptions`` object, a single collocated
      sampler will be launched on the current process, while the separate
      sampling mode will be disabled.
      (2) If provided with a ``MpDistSamplingWorkerOptions`` object, the
      sampling workers will be launched on spawned subprocesses, and a
      share-memory based channel will be created for sample message passing
      from multiprocessing workers to the current loader.
      (3) If provided with a ``RemoteDistSamplingWorkerOptions`` object, the
      sampling workers will be launched on remote sampling server nodes, and a
      remote channel will be created for cross-machine message passing.
      (default: ``None``).

  Raises:
    RuntimeError: If the distributed context has not been initialized, or the
      role of the current context does not match the chosen worker options.
    ValueError: If ``data`` is missing in a non-remote mode, or the type of
      ``worker_options`` is not one of the supported option classes.
  """
  def __init__(
    self,
    data: Optional[DistDataset],
    input_data: Union[NodeSamplerInput, EdgeSamplerInput,
                      RemoteSamplerInput, List[RemoteSamplerInput]],
    sampling_config: SamplingConfig,
    to_device: Optional[torch.device] = None,
    worker_options: Optional[AllDistSamplingWorkerOptions] = None
  ):
    # Teardown state comes first: `__init__` can raise on several paths below
    # and `__del__` still runs on the half-built object. Until the worker mode
    # is known there is nothing to shut down, so the flag starts as True.
    self._shutdowned = True
    self._collocated_producer = None
    self._mp_producer = None
    self._server_rank_list = []
    self._producer_id_list = []

    self.data = data
    self.input_data = input_data
    self.sampling_type = sampling_config.sampling_type
    self.num_neighbors = sampling_config.num_neighbors
    self.batch_size = sampling_config.batch_size
    self.shuffle = sampling_config.shuffle
    self.drop_last = sampling_config.drop_last
    self.with_edge = sampling_config.with_edge
    self.with_weight = sampling_config.with_weight
    self.collect_features = sampling_config.collect_features
    self.edge_dir = sampling_config.edge_dir
    self.sampling_config = sampling_config
    self.to_device = get_available_device(to_device)
    self.worker_options = worker_options

    if self.worker_options is None:
      self.worker_options = CollocatedDistSamplingWorkerOptions()

    self._is_collocated_worker = isinstance(
      self.worker_options, CollocatedDistSamplingWorkerOptions
    )
    self._is_mp_worker = isinstance(
      self.worker_options, MpDistSamplingWorkerOptions
    )
    self._is_remote_worker = isinstance(
      self.worker_options, RemoteDistSamplingWorkerOptions
    )
    # From here on `shutdown()` can safely inspect the worker mode flags.
    self._shutdowned = False

    if self.data is not None:
      self.num_data_partitions = self.data.num_partitions
      self.data_partition_idx = self.data.partition_idx
      self._set_ntypes_and_etypes(
        self.data.get_node_types(), self.data.get_edge_types()
      )

    self._num_recv = 0
    self._epoch = 0

    current_ctx = get_context()
    if current_ctx is None:
      raise RuntimeError(
        f"'{self.__class__.__name__}': the distributed "
        f"context of has not been initialized."
      )

    if self._is_remote_worker:
      if not current_ctx.is_client():
        raise RuntimeError(
          f"'{self.__class__.__name__}': `DistNeighborLoader` "
          f"must be used on a client worker process."
        )
      # for remote worker, end of epoch is determined by server
      self._num_expected = float('inf')

      # Launch remote sampling workers
      self._with_channel = True
      server_rank = self.worker_options.server_rank
      self._server_rank_list = server_rank \
        if isinstance(server_rank, list) else [server_rank]
      self._input_data_list = self.input_data \
        if isinstance(self.input_data, list) else [self.input_data]
      self._input_type = self._input_data_list[0].input_type

      # Dataset meta is fetched from the first server only.
      self.num_data_partitions, self.data_partition_idx, ntypes, etypes = \
        request_server(self._server_rank_list[0], DistServer.get_dataset_meta)
      self._set_ntypes_and_etypes(ntypes, etypes)

      # Move concrete sampler inputs to cpu before sending them to the
      # servers. The converted objects (not the originals) must be the ones
      # submitted, so collect them into a new list.
      cpu_device = torch.device('cpu')
      request_inputs = [
        item if isinstance(item, RemoteSamplerInput) else item.to(cpu_device)
        for item in self._input_data_list
      ]

      # Create one sampling producer per (server, input) pair concurrently.
      with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = [
          executor.submit(
            request_server, rank, DistServer.create_sampling_producer,
            item, self.sampling_config, self.worker_options
          )
          for rank, item in zip(self._server_rank_list, request_inputs)
        ]

      # Collect ids in submission order so that they stay aligned with
      # `self._server_rank_list`.
      for future in futures:
        self._producer_id_list.append(future.result())

      self._channel = RemoteReceivingChannel(
        self._server_rank_list, self._producer_id_list,
        self.worker_options.prefetch_size
      )
    else:
      self._input_len = len(self.input_data)
      self._input_type = self.input_data.input_type
      # Number of batches per epoch; a trailing partial batch counts unless
      # `drop_last` is set.
      self._num_expected = self._input_len // self.batch_size
      if not self.drop_last and self._input_len % self.batch_size != 0:
        self._num_expected += 1

      if self._is_collocated_worker:
        if not current_ctx.is_worker():
          raise RuntimeError(
            f"'{self.__class__.__name__}': only supports "
            f"launching a collocated sampler with a non-server "
            f"distribution mode, current role of distributed "
            f"context is {current_ctx.role}."
          )
        if self.data is None:
          raise ValueError(
            f"'{self.__class__.__name__}': missing input dataset "
            f"when launching a collocated sampler."
          )

        # Launch collocated producer
        self._with_channel = False
        self._collocated_producer = DistCollocatedSamplingProducer(
          self.data, self.input_data, self.sampling_config,
          self.worker_options, self.to_device
        )
        self._collocated_producer.init()

      elif self._is_mp_worker:
        if not current_ctx.is_worker():
          raise RuntimeError(
            f"'{self.__class__.__name__}': only supports "
            f"launching multiprocessing sampling workers with "
            f"a non-server distribution mode, current role of "
            f"distributed context is {current_ctx.role}."
          )
        if self.data is None:
          raise ValueError(
            f"'{self.__class__.__name__}': missing input dataset "
            f"when launching multiprocessing sampling workers."
          )

        # Launch multiprocessing sampling workers
        self._with_channel = True
        self.worker_options._set_worker_ranks(current_ctx)

        self._channel = ShmChannel(
          self.worker_options.channel_capacity,
          self.worker_options.channel_size
        )
        if self.worker_options.pin_memory:
          self._channel.pin_memory()

        self._mp_producer = DistMpSamplingProducer(
          self.data, self.input_data, self.sampling_config,
          self.worker_options, self._channel
        )
        self._mp_producer.init()

      else:
        raise ValueError(
          f"'{self.__class__.__name__}': found invalid "
          f"worker options type '{type(worker_options)}'"
        )

  def __del__(self):
    if python_exit_status is True or python_exit_status is None:
      return
    # `_shutdowned` is missing when `__init__` never ran (e.g. a subclass
    # raised before calling it); treat that as "nothing to shut down".
    if getattr(self, '_shutdowned', True):
      return
    self.shutdown()

  def shutdown(self):
    r""" Release the sampling producers owned by this loader.

    Repeated calls are no-ops. Producers that were never created (because
    ``__init__`` failed part-way) are skipped.
    """
    if self._shutdowned:
      return
    if self._is_collocated_worker:
      if self._collocated_producer is not None:
        self._collocated_producer.shutdown()
    elif self._is_mp_worker:
      if self._mp_producer is not None:
        self._mp_producer.shutdown()
    else:
      if rpc_is_initialized() is True:
        # Both lists are empty when no remote producer was created.
        for server_rank, producer_id in zip(self._server_rank_list,
                                            self._producer_id_list):
          request_server(
            server_rank, DistServer.destroy_sampling_producer, producer_id
          )
    self._shutdowned = True

  def __next__(self):
    r""" Return the next collated mini-batch.

    Raises:
      StopIteration: When the number of received batches equals the expected
        number of batches for this epoch.
    """
    if self._num_recv == self._num_expected:
      raise StopIteration

    if self._with_channel:
      msg = self._channel.recv()
    else:
      msg = self._collocated_producer.sample()

    result = self._collate_fn(msg)
    self._num_recv += 1
    return result

  def __iter__(self):
    r""" Start a new epoch and return the loader itself as the iterator.
    """
    self._num_recv = 0
    if self._is_collocated_worker:
      self._collocated_producer.reset()
    elif self._is_mp_worker:
      self._mp_producer.produce_all()
    else:
      for server_rank, producer_id in zip(self._server_rank_list,
                                          self._producer_id_list):
        request_server(
          server_rank,
          DistServer.start_new_epoch_sampling,
          producer_id,
          self._epoch
        )
      self._channel.reset()
    self._epoch += 1
    return self

  def _set_ntypes_and_etypes(self, node_types: List[NodeType],
                             edge_types: List[EdgeType]):
    r""" Record node/edge types and derive the edge types used in outputs.

    For every edge type, ``self._etype_str_to_rev`` maps the string key to
    look up in a sample message to the edge type under which that message
    entry is stored in the collated output; ``self._reversed_edge_types``
    lists those output edge types. Both stay empty when ``edge_types`` is
    ``None`` or ``self.edge_dir`` is neither ``'out'`` nor ``'in'``.
    """
    self._node_types = node_types
    self._edge_types = edge_types
    self._reversed_edge_types = []
    self._etype_str_to_rev = {}
    if self._edge_types is not None:
      for etype in self._edge_types:
        rev_etype = reverse_edge_type(etype)
        if self.edge_dir == 'out':
          self._reversed_edge_types.append(rev_etype)
          self._etype_str_to_rev[as_str(etype)] = rev_etype
        elif self.edge_dir == 'in':
          self._reversed_edge_types.append(etype)
          self._etype_str_to_rev[as_str(rev_etype)] = etype

  def _collate_fn(
    self,
    msg: SampleMessage
  ) -> Union[Data, HeteroData]:
    r""" Collate sampled messages as PyG's Data/HeteroData

    Args:
      msg (SampleMessage): A sample message keyed by strings. The
        ``'#IS_HETERO'`` entry selects the heterogeneous or homogeneous
        layout and ``'#META.*'`` entries are forwarded as metadata.

    Returns:
      The result of ``to_hetero_data`` for heterogeneous messages, otherwise
      the result of ``to_data``.
    """
    ensure_device(self.to_device)
    is_hetero = bool(msg['#IS_HETERO'])

    # extract meta data
    metadata = {}
    for k in msg.keys():
      if k.startswith('#META.'):
        # Strip the 6-character '#META.' prefix.
        meta_key = str(k[6:])
        metadata[meta_key] = msg[k].to(self.to_device)
    if len(metadata) == 0:
      metadata = None

    # Heterogeneous sampling results
    if is_hetero:
      node_dict, row_dict, col_dict, edge_dict = {}, {}, {}, {}
      nfeat_dict, efeat_dict = {}, {}
      num_sampled_nodes_dict, num_sampled_edges_dict = {}, {}

      for ntype in self._node_types:
        ntype_str = as_str(ntype)
        ids_key = f'{ntype_str}.ids'
        if ids_key in msg:
          node_dict[ntype] = msg[ids_key].to(self.to_device)
        nfeat_key = f'{ntype_str}.nfeats'
        if nfeat_key in msg:
          nfeat_dict[ntype] = msg[nfeat_key].to(self.to_device)
        num_sampled_nodes_key = f'{ntype_str}.num_sampled_nodes'
        if num_sampled_nodes_key in msg:
          num_sampled_nodes_dict[ntype] = msg[num_sampled_nodes_key]

      for etype_str, rev_etype in self._etype_str_to_rev.items():
        rows_key = f'{etype_str}.rows'
        cols_key = f'{etype_str}.cols'
        if rows_key in msg:
          # The edge index should be reversed.
          row_dict[rev_etype] = msg[cols_key].to(self.to_device)
          col_dict[rev_etype] = msg[rows_key].to(self.to_device)
        eids_key = f'{etype_str}.eids'
        if eids_key in msg:
          edge_dict[rev_etype] = msg[eids_key].to(self.to_device)
        num_sampled_edges_key = f'{etype_str}.num_sampled_edges'
        if num_sampled_edges_key in msg:
          num_sampled_edges_dict[rev_etype] = msg[num_sampled_edges_key]
        efeat_key = f'{etype_str}.efeats'
        if efeat_key in msg:
          efeat_dict[rev_etype] = msg[efeat_key].to(self.to_device)

      if len(nfeat_dict) == 0:
        nfeat_dict = None
      if len(efeat_dict) == 0:
        efeat_dict = None

      if self.sampling_config.sampling_type in [SamplingType.NODE,
                                                SamplingType.SUBGRAPH]:
        batch_key = f'{self._input_type}.batch'
        if msg.get(batch_key) is not None:
          batch_dict = {
            self._input_type: msg[batch_key].to(self.to_device)
          }
        else:
          # No explicit batch entry: fall back to the first `batch_size`
          # node ids of the input type.
          batch_dict = {
            self._input_type: node_dict[self._input_type][:self.batch_size]
          }
        batch_labels_key = f'{self._input_type}.nlabels'
        if batch_labels_key in msg:
          batch_labels = msg[batch_labels_key].to(self.to_device)
        else:
          batch_labels = None
        batch_label_dict = {self._input_type: batch_labels}
      else:
        batch_dict = {}
        batch_label_dict = {}

      output = HeteroSamplerOutput(
        node_dict, row_dict, col_dict,
        edge_dict if len(edge_dict) else None,
        batch_dict,
        num_sampled_nodes=num_sampled_nodes_dict,
        num_sampled_edges=num_sampled_edges_dict,
        edge_types=self._reversed_edge_types,
        input_type=self._input_type,
        device=self.to_device,
        metadata=metadata
      )
      res_data = to_hetero_data(
        output, batch_label_dict, nfeat_dict, efeat_dict, self.edge_dir
      )

    # Homogeneous sampling results
    else:
      ids = msg['ids'].to(self.to_device)
      rows = msg['rows'].to(self.to_device)
      cols = msg['cols'].to(self.to_device)
      eids = msg['eids'].to(self.to_device) if 'eids' in msg else None
      num_sampled_nodes = msg['num_sampled_nodes'] \
        if 'num_sampled_nodes' in msg else None
      num_sampled_edges = msg['num_sampled_edges'] \
        if 'num_sampled_edges' in msg else None

      nfeats = msg['nfeats'].to(self.to_device) if 'nfeats' in msg else None
      efeats = msg['efeats'].to(self.to_device) if 'efeats' in msg else None

      if self.sampling_config.sampling_type in [SamplingType.NODE,
                                                SamplingType.SUBGRAPH]:
        if msg.get('batch') is not None:
          batch = msg['batch'].to(self.to_device)
        else:
          # No explicit batch entry: fall back to the first `batch_size` ids.
          batch = ids[:self.batch_size]
        batch_labels = msg['nlabels'].to(self.to_device) \
          if 'nlabels' in msg else None
      else:
        batch = None
        batch_labels = None

      # The edge index should be reversed.
      output = SamplerOutput(ids, cols, rows, eids, batch,
                             num_sampled_nodes, num_sampled_edges,
                             device=self.to_device, metadata=metadata)
      res_data = to_data(output, batch_labels, nfeats, efeats)

    return res_data