in graphlearn_torch/python/distributed/dist_sampling_producer.py [0:0]
def init(self):
  r""" Create the subprocess pool. Initialize samplers and the rpc server.
  """
  if self.sampling_config.seed is not None:
    seed_everything(self.sampling_config.seed)
  # When shuffling is disabled, pre-split the seed indexes so each worker
  # receives a fixed, deterministic partition; otherwise pass None and let
  # each worker draw its own shuffled seeds.
  if not self.sampling_config.shuffle:
    unshuffled_indexes = self._get_seeds_indexes()
  else:
    unshuffled_indexes = [None] * self.num_workers
  # Use the 'spawn' start method and a barrier with one extra party so the
  # parent process can wait until every sampling worker finishes its init.
  mp_context = mp.get_context('spawn')
  barrier = mp_context.Barrier(self.num_workers + 1)
  for rank in range(self.num_workers):
    # Bounded per-worker queue used to dispatch sampling tasks.
    task_queue = mp_context.Queue(
      self.num_workers * self.worker_options.worker_concurrency)
    self._task_queues.append(task_queue)
    w = mp_context.Process(
      target=_sampling_worker_loop,
      args=(rank, self.data, self.sampler_input, unshuffled_indexes[rank],
            self.sampling_config, self.worker_options, self.output_channel,
            task_queue, self.sampling_completed_worker_count, barrier)
    )
    w.daemon = True
    w.start()
    self._workers.append(w)
  # Block until all spawned workers have completed initialization.
  barrier.wait()
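The method above relies on a spawn context, per-worker task queues, and a Barrier sized num_workers + 1 so the parent only returns from init() once every worker has started up. Below is a minimal, self-contained sketch of that same coordination pattern, under stated assumptions: _worker, the None shutdown sentinel, and the queue/barrier wiring are illustrative stand-ins, not graphlearn_torch's actual _sampling_worker_loop or its task protocol.

# Sketch of the spawn + Barrier startup pattern (assumed names, not library code).
import multiprocessing as mp

def _worker(rank, task_queue, barrier):
  # ... per-worker initialization would happen here ...
  barrier.wait()           # signal readiness to the parent
  while True:
    task = task_queue.get()
    if task is None:       # sentinel used here to stop the worker
      break
    # ... process one sampling task ...

if __name__ == '__main__':
  num_workers = 2
  ctx = mp.get_context('spawn')
  barrier = ctx.Barrier(num_workers + 1)   # workers + the parent
  queues, workers = [], []
  for rank in range(num_workers):
    q = ctx.Queue()
    p = ctx.Process(target=_worker, args=(rank, q, barrier), daemon=True)
    p.start()
    queues.append(q)
    workers.append(p)
  barrier.wait()           # returns only after all workers finished init
  for q in queues:
    q.put(None)            # shut the workers down
  for p in workers:
    p.join()

The extra barrier party for the parent is what makes barrier.wait() at the end of init() a reliable "all workers are ready" signal before any tasks are produced.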