# Copyright 2022 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import queue
import datetime
import time
from enum import Enum
from typing import Optional, Union
import torch
import torch.multiprocessing as mp
from torch._C import _set_worker_signal_handlers
from torch.utils.data.dataloader import DataLoader
from ..channel import ChannelBase
from ..sampler import (
NodeSamplerInput, EdgeSamplerInput, SamplingType, SamplingConfig
)
from ..utils import ensure_device
from ..utils import seed_everything
from ..distributed.dist_context import get_context
from .dist_context import init_worker_group
from .dist_dataset import DistDataset
from .dist_neighbor_sampler import DistNeighborSampler
from .dist_options import _BasicDistSamplingWorkerOptions
from .rpc import init_rpc, shutdown_rpc
MP_STATUS_CHECK_INTERVAL = 5.0
r""" Interval (in seconds) to check status of processes to avoid hanging in
multiprocessing sampling.
"""
class MpCommand(Enum):
r""" Enum class for multiprocessing sampling command
"""
SAMPLE_ALL = 0
STOP = 1
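
# Messages on a worker's task queue are ``(MpCommand, args)`` tuples:
#   * ``(MpCommand.SAMPLE_ALL, seeds_index)``, where ``seeds_index`` indexes
#     into the shared sampler input, or is ``None`` to fall back to the
#     worker's precomputed unshuffled index.
#   * ``(MpCommand.STOP, None)``, which asks the worker loop to exit.
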
def _sampling_worker_loop(rank,
data: DistDataset,
sampler_input: Union[NodeSamplerInput, EdgeSamplerInput],
unshuffled_index: Optional[torch.Tensor],
sampling_config: SamplingConfig,
worker_options: _BasicDistSamplingWorkerOptions,
channel: ChannelBase,
task_queue: mp.Queue,
sampling_completed_worker_count: mp.Value,
mp_barrier):
r""" Subprocess work loop for sampling worker.
"""
dist_sampler = None
try:
init_worker_group(
world_size=worker_options.worker_world_size,
rank=worker_options.worker_ranks[rank],
group_name='_sampling_worker_subprocess'
)
if worker_options.use_all2all:
torch.distributed.init_process_group(
backend='gloo',
timeout=datetime.timedelta(seconds=worker_options.rpc_timeout),
rank=worker_options.worker_ranks[rank],
world_size=worker_options.worker_world_size,
init_method='tcp://{}:{}'.format(worker_options.master_addr, worker_options.master_port)
)
if worker_options.num_rpc_threads is None:
num_rpc_threads = min(data.num_partitions, 16)
else:
num_rpc_threads = worker_options.num_rpc_threads
current_device = worker_options.worker_devices[rank]
ensure_device(current_device)
_set_worker_signal_handlers()
torch.set_num_threads(num_rpc_threads + 1)
init_rpc(
master_addr=worker_options.master_addr,
master_port=worker_options.master_port,
num_rpc_threads=num_rpc_threads,
rpc_timeout=worker_options.rpc_timeout
)
if sampling_config.seed is not None:
seed_everything(sampling_config.seed)
dist_sampler = DistNeighborSampler(
data, sampling_config.num_neighbors, sampling_config.with_edge,
sampling_config.with_neg, sampling_config.with_weight,
sampling_config.edge_dir, sampling_config.collect_features, channel,
worker_options.use_all2all, worker_options.worker_concurrency,
current_device, seed=sampling_config.seed
)
dist_sampler.start_loop()
if unshuffled_index is not None:
unshuffled_index_loader = DataLoader(
unshuffled_index, batch_size=sampling_config.batch_size,
shuffle=False, drop_last=sampling_config.drop_last
)
else:
unshuffled_index_loader = None
    mp_barrier.wait()  # rendezvous with the parent process once initialization is done
keep_running = True
while keep_running:
try:
command, args = task_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
except queue.Empty:
continue
if command == MpCommand.SAMPLE_ALL:
seeds_index = args
if seeds_index is None:
loader = unshuffled_index_loader
else:
loader = DataLoader(
seeds_index, batch_size=sampling_config.batch_size,
shuffle=False, drop_last=sampling_config.drop_last
)
if sampling_config.sampling_type == SamplingType.NODE:
for index in loader:
dist_sampler.sample_from_nodes(sampler_input[index])
elif sampling_config.sampling_type == SamplingType.LINK:
for index in loader:
dist_sampler.sample_from_edges(sampler_input[index])
elif sampling_config.sampling_type == SamplingType.SUBGRAPH:
for index in loader:
dist_sampler.subgraph(sampler_input[index])
dist_sampler.wait_all()
with sampling_completed_worker_count.get_lock():
sampling_completed_worker_count.value += 1 # non-atomic, lock is necessary
elif command == MpCommand.STOP:
keep_running = False
else:
raise RuntimeError("Unknown command type")
except KeyboardInterrupt:
    # The main process will raise KeyboardInterrupt anyway.
pass
if dist_sampler is not None:
dist_sampler.shutdown_loop()
shutdown_rpc(graceful=False)
class DistMpSamplingProducer(object):
r""" A subprocess group of distributed sampling workers.
Note that this producer is only used for workload with separate sampling
and training, all sampled results will be sent to the output channel.
"""
def __init__(self,
data: DistDataset,
sampler_input: Union[NodeSamplerInput, EdgeSamplerInput],
sampling_config: SamplingConfig,
worker_options: _BasicDistSamplingWorkerOptions,
output_channel: ChannelBase):
self.data = data
self.sampler_input = sampler_input.share_memory()
self.input_len = len(self.sampler_input)
self.sampling_config = sampling_config
self.worker_options = worker_options
self.worker_options._assign_worker_devices()
self.num_workers = self.worker_options.num_workers
self.output_channel = output_channel
self.sampling_completed_worker_count = mp.Value('I', lock=True)
current_ctx = get_context()
self.worker_options._set_worker_ranks(current_ctx)
self._task_queues = []
self._workers = []
self._barrier = None
self._shutdown = False
self._worker_seeds_ranges = self._get_worker_seeds_ranges()
def init(self):
r""" Create the subprocess pool. Init samplers and rpc server.
"""
if self.sampling_config.seed is not None:
seed_everything(self.sampling_config.seed)
if not self.sampling_config.shuffle:
unshuffled_indexes = self._get_seeds_indexes()
else:
unshuffled_indexes = [None] * self.num_workers
mp_context = mp.get_context('spawn')
barrier = mp_context.Barrier(self.num_workers + 1)
for rank in range(self.num_workers):
task_queue = mp_context.Queue(
self.num_workers * self.worker_options.worker_concurrency)
self._task_queues.append(task_queue)
w = mp_context.Process(
target=_sampling_worker_loop,
args=(rank, self.data, self.sampler_input, unshuffled_indexes[rank],
self.sampling_config, self.worker_options, self.output_channel,
task_queue, self.sampling_completed_worker_count, barrier)
)
w.daemon = True
w.start()
self._workers.append(w)
    barrier.wait()  # block until every worker subprocess has finished initialization
def shutdown(self):
r""" Shutdown sampler event loop and rpc server. Join the subprocesses.
"""
if not self._shutdown:
self._shutdown = True
try:
for q in self._task_queues:
q.put((MpCommand.STOP, None))
for w in self._workers:
w.join(timeout=MP_STATUS_CHECK_INTERVAL)
for q in self._task_queues:
q.cancel_join_thread()
q.close()
finally:
for w in self._workers:
if w.is_alive():
w.terminate()
def produce_all(self):
r""" Perform sampling for all input seeds.
"""
if self.sampling_config.shuffle:
seeds_indexes = self._get_seeds_indexes()
for rank in range(self.num_workers):
seeds_indexes[rank].share_memory_()
else:
seeds_indexes = [None] * self.num_workers
self.sampling_completed_worker_count.value = 0
for rank in range(self.num_workers):
self._task_queues[rank].put((MpCommand.SAMPLE_ALL, seeds_indexes[rank]))
time.sleep(0.1)
  def is_all_sampling_completed_and_consumed(self):
    # All workers have finished sampling and the output channel has been drained.
    if self.output_channel.empty():
      return self.is_all_sampling_completed()
    return False
def is_all_sampling_completed(self):
return self.sampling_completed_worker_count.value == self.num_workers
def _get_worker_seeds_ranges(self):
num_worker_batches = [0] * self.num_workers
num_total_complete_batches = (self.input_len // self.sampling_config.batch_size)
for rank in range(self.num_workers):
num_worker_batches[rank] += \
(num_total_complete_batches // self.num_workers)
for rank in range(num_total_complete_batches % self.num_workers):
num_worker_batches[rank] += 1
index_ranges = []
start = 0
for rank in range(self.num_workers):
end = start + num_worker_batches[rank] * self.sampling_config.batch_size
if rank == self.num_workers - 1:
end = self.input_len
index_ranges.append((start, end))
start = end
return index_ranges
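
  # Worked example for ``_get_worker_seeds_ranges`` above, with hypothetical
  # numbers: given ``input_len=10``, ``batch_size=3`` and ``num_workers=2``,
  # there are 3 complete batches; worker 0 receives 2 of them and worker 1
  # receives 1, so the ranges are ``[(0, 6), (6, 10)]``. The last worker
  # always absorbs the remainder up to ``input_len``.
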
def _get_seeds_indexes(self):
if self.sampling_config.shuffle:
index = torch.randperm(self.input_len)
else:
index = torch.arange(self.input_len)
seeds_indexes = []
for rank in range(self.num_workers):
start, end = self._worker_seeds_ranges[rank]
seeds_indexes.append(index[start:end])
return seeds_indexes
class DistCollocatedSamplingProducer(object):
r""" A sampling producer with a collocated distributed sampler.
Note that the sampled results will be returned directly and this producer
will be blocking when processing each batch.
"""
def __init__(self,
data: DistDataset,
sampler_input: Union[NodeSamplerInput, EdgeSamplerInput],
sampling_config: SamplingConfig,
worker_options: _BasicDistSamplingWorkerOptions,
device: torch.device):
self.data = data
self.sampler_input = sampler_input
self.sampling_config = sampling_config
self.worker_options = worker_options
    self.device = device
    self._collocated_sampler = None  # created in init(); checked by shutdown()
def init(self):
index = torch.arange(len(self.sampler_input))
self._index_loader = DataLoader(
index,
batch_size=self.sampling_config.batch_size,
shuffle=self.sampling_config.shuffle,
drop_last=self.sampling_config.drop_last
)
self._index_iter = self._index_loader._get_iterator()
if self.worker_options.num_rpc_threads is None:
num_rpc_threads = min(self.data.num_partitions, 16)
else:
num_rpc_threads = self.worker_options.num_rpc_threads
init_rpc(
master_addr=self.worker_options.master_addr,
master_port=self.worker_options.master_port,
num_rpc_threads=num_rpc_threads,
rpc_timeout=self.worker_options.rpc_timeout
)
self._collocated_sampler = DistNeighborSampler(
self.data, self.sampling_config.num_neighbors,
self.sampling_config.with_edge, self.sampling_config.with_neg,
self.sampling_config.with_weight,
self.sampling_config.edge_dir, self.sampling_config.collect_features,
channel=None, use_all2all=self.worker_options.use_all2all,
concurrency=1, device=self.device,
seed=self.sampling_config.seed
)
self._collocated_sampler.start_loop()
def shutdown(self):
if self._collocated_sampler is not None:
self._collocated_sampler.shutdown_loop()
def reset(self):
self._index_iter._reset(self._index_loader)
def sample(self):
index = self._index_iter._next_data()
batch = self.sampler_input[index]
if self.sampling_config.sampling_type == SamplingType.NODE:
return self._collocated_sampler.sample_from_nodes(batch)
if self.sampling_config.sampling_type == SamplingType.LINK:
return self._collocated_sampler.sample_from_edges(batch)
if self.sampling_config.sampling_type == SamplingType.SUBGRAPH:
return self._collocated_sampler.subgraph(batch)
raise NotImplementedError