graphlearn_torch/python/distributed/dist_sampling_producer.py (285 lines of code) (raw):

# Copyright 2022 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import datetime
import queue
import time
from enum import Enum
from typing import Optional, Union

import torch
import torch.multiprocessing as mp
from torch._C import _set_worker_signal_handlers
from torch.utils.data.dataloader import DataLoader

from ..channel import ChannelBase
from ..sampler import (
  NodeSamplerInput, EdgeSamplerInput, SamplingType, SamplingConfig
)
from ..utils import ensure_device
from ..utils import seed_everything
from ..distributed.dist_context import get_context

from .dist_context import init_worker_group
from .dist_dataset import DistDataset
from .dist_neighbor_sampler import DistNeighborSampler
from .dist_options import _BasicDistSamplingWorkerOptions
from .rpc import init_rpc, shutdown_rpc


MP_STATUS_CHECK_INTERVAL = 5.0
r""" Interval (in seconds) to check status of processes to avoid hanging in
multiprocessing sampling.
"""


class MpCommand(Enum):
  r""" Enum class for multiprocessing sampling command.

  ``SAMPLE_ALL`` asks a worker to sample every seed assigned to it, ``STOP``
  asks the worker to leave its command loop.
  """
  SAMPLE_ALL = 0
  STOP = 1


def _sampling_worker_loop(rank,
                          data: DistDataset,
                          sampler_input: Union[NodeSamplerInput,
                                               EdgeSamplerInput],
                          unshuffled_index: Optional[torch.Tensor],
                          sampling_config: SamplingConfig,
                          worker_options: _BasicDistSamplingWorkerOptions,
                          channel: ChannelBase,
                          task_queue: mp.Queue,
                          sampling_completed_worker_count: mp.Value,
                          mp_barrier):
  r""" Subprocess work loop for sampling worker.

  Sets up the worker group, the optional gloo process group, the rpc agent
  and a ``DistNeighborSampler``, then waits on ``mp_barrier`` and serves
  commands from ``task_queue`` until ``MpCommand.STOP`` is received.

  Args:
    rank: Local index of this worker; used to look up its entry in
      ``worker_options.worker_ranks`` and ``worker_options.worker_devices``.
    data: The dataset handed to the sampler.
    sampler_input: The seeds container; indexed with each batch of indexes.
    unshuffled_index: Fixed indexes for this worker, or ``None``. When given,
      a loader over it is built once and reused for every ``SAMPLE_ALL``
      command that carries no indexes.
    sampling_config: Sampling configuration (batch size, sampling type, ...).
    worker_options: Options of the sampling worker group.
    channel: Channel object passed to the sampler.
    task_queue: Queue of ``(MpCommand, args)`` tuples for this worker.
    sampling_completed_worker_count: Shared counter incremented once per
      finished ``SAMPLE_ALL`` command.
    mp_barrier: Barrier waited on once initialization has finished.

  Raises:
    RuntimeError: If an unknown command is received.
  """
  dist_sampler = None
  try:
    init_worker_group(
      world_size=worker_options.worker_world_size,
      rank=worker_options.worker_ranks[rank],
      group_name='_sampling_worker_subprocess'
    )
    if worker_options.use_all2all:
      torch.distributed.init_process_group(
        backend='gloo',
        timeout=datetime.timedelta(seconds=worker_options.rpc_timeout),
        rank=worker_options.worker_ranks[rank],
        world_size=worker_options.worker_world_size,
        init_method='tcp://{}:{}'.format(worker_options.master_addr,
                                         worker_options.master_port)
      )

    if worker_options.num_rpc_threads is None:
      num_rpc_threads = min(data.num_partitions, 16)
    else:
      num_rpc_threads = worker_options.num_rpc_threads

    current_device = worker_options.worker_devices[rank]
    ensure_device(current_device)

    _set_worker_signal_handlers()
    torch.set_num_threads(num_rpc_threads + 1)

    init_rpc(
      master_addr=worker_options.master_addr,
      master_port=worker_options.master_port,
      num_rpc_threads=num_rpc_threads,
      rpc_timeout=worker_options.rpc_timeout
    )

    if sampling_config.seed is not None:
      seed_everything(sampling_config.seed)
    dist_sampler = DistNeighborSampler(
      data, sampling_config.num_neighbors, sampling_config.with_edge,
      sampling_config.with_neg, sampling_config.with_weight,
      sampling_config.edge_dir, sampling_config.collect_features, channel,
      worker_options.use_all2all, worker_options.worker_concurrency,
      current_device, seed=sampling_config.seed
    )
    dist_sampler.start_loop()

    if unshuffled_index is not None:
      unshuffled_index_loader = DataLoader(
        unshuffled_index, batch_size=sampling_config.batch_size,
        shuffle=False, drop_last=sampling_config.drop_last
      )
    else:
      unshuffled_index_loader = None

    # Initialization is done, meet the other parties of the barrier.
    mp_barrier.wait()

    keep_running = True
    while keep_running:
      try:
        command, args = task_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
      except queue.Empty:
        # Nothing to do yet, poll again.
        continue

      if command == MpCommand.SAMPLE_ALL:
        seeds_index = args
        if seeds_index is None:
          # No indexes were sent along: fall back to the fixed loader.
          loader = unshuffled_index_loader
        else:
          loader = DataLoader(
            seeds_index, batch_size=sampling_config.batch_size,
            shuffle=False, drop_last=sampling_config.drop_last
          )

        if sampling_config.sampling_type == SamplingType.NODE:
          for index in loader:
            dist_sampler.sample_from_nodes(sampler_input[index])
        elif sampling_config.sampling_type == SamplingType.LINK:
          for index in loader:
            dist_sampler.sample_from_edges(sampler_input[index])
        elif sampling_config.sampling_type == SamplingType.SUBGRAPH:
          for index in loader:
            dist_sampler.subgraph(sampler_input[index])

        dist_sampler.wait_all()

        with sampling_completed_worker_count.get_lock():
          sampling_completed_worker_count.value += 1 # non-atomic, lock is necessary

      elif command == MpCommand.STOP:
        keep_running = False
      else:
        raise RuntimeError("Unknown command type")
  except KeyboardInterrupt:
    # Main process will raise KeyboardInterrupt anyways.
    pass

  if dist_sampler is not None:
    dist_sampler.shutdown_loop()
  shutdown_rpc(graceful=False)


class DistMpSamplingProducer(object):
  r""" A subprocess group of distributed sampling workers.

  Note that this producer is only used for workload with separate sampling
  and training, all sampled results will be sent to the output channel.

  Args:
    data: The dataset handed to every sampling worker.
    sampler_input: The seeds container; ``share_memory()`` is called on it
      and the result is passed to every worker.
    sampling_config: Sampling configuration (batch size, shuffle, seed, ...).
    worker_options: Options of the sampling worker group; worker devices and
      worker ranks are assigned on it during construction.
    output_channel: Channel passed to every worker's sampler.
  """
  def __init__(self,
               data: DistDataset,
               sampler_input: Union[NodeSamplerInput, EdgeSamplerInput],
               sampling_config: SamplingConfig,
               worker_options: _BasicDistSamplingWorkerOptions,
               output_channel: ChannelBase):
    self.data = data
    self.sampler_input = sampler_input.share_memory()
    self.input_len = len(self.sampler_input)
    self.sampling_config = sampling_config
    self.worker_options = worker_options
    self.worker_options._assign_worker_devices()
    self.num_workers = self.worker_options.num_workers
    self.output_channel = output_channel
    # Unsigned int counter shared with the worker subprocesses.
    self.sampling_completed_worker_count = mp.Value('I', lock=True)
    current_ctx = get_context()
    self.worker_options._set_worker_ranks(current_ctx)

    self._task_queues = []
    self._workers = []
    self._barrier = None
    self._shutdown = False
    self._worker_seeds_ranges = self._get_worker_seeds_ranges()

  def init(self):
    r""" Create the subprocess pool. Init samplers and rpc server.

    Blocks until every worker has reached the shared barrier.
    """
    if self.sampling_config.seed is not None:
      seed_everything(self.sampling_config.seed)
    if not self.sampling_config.shuffle:
      # Without shuffling the per-worker indexes never change, so they are
      # handed to the workers once at start-up.
      unshuffled_indexes = self._get_seeds_indexes()
    else:
      unshuffled_indexes = [None] * self.num_workers

    mp_context = mp.get_context('spawn')
    # One party per worker plus this process.
    barrier = mp_context.Barrier(self.num_workers + 1)
    for rank in range(self.num_workers):
      task_queue = mp_context.Queue(
        self.num_workers * self.worker_options.worker_concurrency)
      self._task_queues.append(task_queue)
      w = mp_context.Process(
        target=_sampling_worker_loop,
        args=(rank, self.data, self.sampler_input, unshuffled_indexes[rank],
              self.sampling_config, self.worker_options, self.output_channel,
              task_queue, self.sampling_completed_worker_count, barrier)
      )
      w.daemon = True
      w.start()
      self._workers.append(w)
    barrier.wait()

  def shutdown(self):
    r""" Shutdown sampler event loop and rpc server. Join the subprocesses.

    Only the first call has an effect. Workers still alive after the join
    timeout are terminated.
    """
    if not self._shutdown:
      self._shutdown = True
      try:
        for q in self._task_queues:
          q.put((MpCommand.STOP, None))
        for w in self._workers:
          w.join(timeout=MP_STATUS_CHECK_INTERVAL)
        for q in self._task_queues:
          q.cancel_join_thread()
          q.close()
      finally:
        for w in self._workers:
          if w.is_alive():
            w.terminate()

  def produce_all(self):
    r""" Perform sampling for all input seeds.

    Resets the completed-worker counter and enqueues one ``SAMPLE_ALL``
    command per worker. With shuffling enabled a fresh permutation is drawn
    and each worker's slice is sent along with the command; otherwise
    ``None`` is sent.
    """
    if self.sampling_config.shuffle:
      seeds_indexes = self._get_seeds_indexes()
      for rank in range(self.num_workers):
        seeds_indexes[rank].share_memory_()
    else:
      seeds_indexes = [None] * self.num_workers

    self.sampling_completed_worker_count.value = 0
    for rank in range(self.num_workers):
      self._task_queues[rank].put((MpCommand.SAMPLE_ALL, seeds_indexes[rank]))
    time.sleep(0.1)

  def is_all_sampling_completed_and_consumed(self):
    r""" Return ``True`` only if the output channel is empty and every worker
    has reported completion, ``False`` otherwise.
    """
    if self.output_channel.empty():
      return self.is_all_sampling_completed()
    # Explicit result instead of an implicit ``None``.
    return False

  def is_all_sampling_completed(self):
    r""" Return whether the completed-worker counter equals the number of
    workers.
    """
    return self.sampling_completed_worker_count.value == self.num_workers

  def _get_worker_seeds_ranges(self):
    r""" Split ``[0, input_len)`` into one ``(start, end)`` range per worker.

    Complete batches are spread as evenly as possible, lower ranks taking
    the surplus ones. The last worker's range always ends at ``input_len``,
    so it also receives the trailing incomplete batch.
    """
    batch_size = self.sampling_config.batch_size
    num_total_complete_batches = self.input_len // batch_size
    base, surplus = divmod(num_total_complete_batches, self.num_workers)

    index_ranges = []
    start = 0
    for rank in range(self.num_workers):
      num_batches = base + (1 if rank < surplus else 0)
      end = start + num_batches * batch_size
      if rank == self.num_workers - 1:
        end = self.input_len
      index_ranges.append((start, end))
      start = end
    return index_ranges

  def _get_seeds_indexes(self):
    r""" Return one index tensor per worker, sliced by the worker's seed
    range from a random permutation (shuffle enabled) or from the identity
    ordering (shuffle disabled).
    """
    if self.sampling_config.shuffle:
      index = torch.randperm(self.input_len)
    else:
      index = torch.arange(self.input_len)

    return [index[start:end] for start, end in self._worker_seeds_ranges]


class DistCollocatedSamplingProducer(object):
  r""" A sampling producer with a collocated distributed sampler.

  Note that the sampled results will be returned directly and this producer
  will be blocking when processing each batch.

  Args:
    data: The dataset handed to the sampler.
    sampler_input: The seeds container; indexed with each batch of indexes.
    sampling_config: Sampling configuration (batch size, shuffle, ...).
    worker_options: Options providing the rpc settings and ``use_all2all``.
    device: Device passed to the sampler.
  """
  def __init__(self,
               data: DistDataset,
               sampler_input: Union[NodeSamplerInput, EdgeSamplerInput],
               sampling_config: SamplingConfig,
               worker_options: _BasicDistSamplingWorkerOptions,
               device: torch.device):
    self.data = data
    self.sampler_input = sampler_input
    self.sampling_config = sampling_config
    self.worker_options = worker_options
    self.device = device
    # Populated by ``init()``; defined here so that ``shutdown()`` is safe
    # to call even if ``init()`` never ran or failed part-way.
    self._index_loader = None
    self._index_iter = None
    self._collocated_sampler = None

  def init(self):
    r""" Build the index loader, init rpc and start the collocated sampler.
    """
    index = torch.arange(len(self.sampler_input))
    self._index_loader = DataLoader(
      index,
      batch_size=self.sampling_config.batch_size,
      shuffle=self.sampling_config.shuffle,
      drop_last=self.sampling_config.drop_last
    )
    # NOTE: relies on private ``DataLoader`` iterator APIs (``_get_iterator``,
    # ``_reset``, ``_next_data``).
    self._index_iter = self._index_loader._get_iterator()

    if self.worker_options.num_rpc_threads is None:
      num_rpc_threads = min(self.data.num_partitions, 16)
    else:
      num_rpc_threads = self.worker_options.num_rpc_threads

    init_rpc(
      master_addr=self.worker_options.master_addr,
      master_port=self.worker_options.master_port,
      num_rpc_threads=num_rpc_threads,
      rpc_timeout=self.worker_options.rpc_timeout
    )

    self._collocated_sampler = DistNeighborSampler(
      self.data, self.sampling_config.num_neighbors,
      self.sampling_config.with_edge, self.sampling_config.with_neg,
      self.sampling_config.with_weight, self.sampling_config.edge_dir,
      self.sampling_config.collect_features, channel=None,
      use_all2all=self.worker_options.use_all2all,
      concurrency=1, device=self.device,
      seed=self.sampling_config.seed
    )
    self._collocated_sampler.start_loop()

  def shutdown(self):
    r""" Shutdown the sampler event loop, if a sampler has been created.
    """
    if self._collocated_sampler is not None:
      self._collocated_sampler.shutdown_loop()

  def reset(self):
    r""" Reset the index iterator so that sampling restarts from the
    beginning of the loader.
    """
    self._index_iter._reset(self._index_loader)

  def sample(self):
    r""" Sample the next batch of seeds and return the sampler's result.

    Raises:
      NotImplementedError: If the configured sampling type is not one of
        ``NODE``, ``LINK`` or ``SUBGRAPH``.
    """
    index = self._index_iter._next_data()
    batch = self.sampler_input[index]
    if self.sampling_config.sampling_type == SamplingType.NODE:
      return self._collocated_sampler.sample_from_nodes(batch)
    if self.sampling_config.sampling_type == SamplingType.LINK:
      return self._collocated_sampler.sample_from_edges(batch)
    if self.sampling_config.sampling_type == SamplingType.SUBGRAPH:
      return self._collocated_sampler.subgraph(batch)
    raise NotImplementedError