in graphlearn_torch/python/distributed/dist_sampling_producer.py [0:0]
def __init__(self,
             data: DistDataset,
             sampler_input: Union[NodeSamplerInput, EdgeSamplerInput],
             sampling_config: SamplingConfig,
             worker_options: _BasicDistSamplingWorkerOptions,
             output_channel: ChannelBase):
  """Initialize producer state for distributed sampling workers.

  Args:
    data: Distributed dataset the workers will sample from.
    sampler_input: Seed nodes/edges; moved into shared memory so spawned
      worker processes can read the seeds without copying.
    sampling_config: Sampling configuration shared by all workers.
    worker_options: Worker launch options; devices and ranks are
      assigned/derived here before any worker is spawned.
    output_channel: Channel that workers push sampled messages into.
  """
  self.data = data
  self.sampling_config = sampling_config
  self.output_channel = output_channel

  # Put the seed input into shared memory and cache its length.
  self.sampler_input = sampler_input.share_memory()
  self.input_len = len(self.sampler_input)

  # Resolve per-worker devices up front; num_workers is read afterwards.
  self.worker_options = worker_options
  self.worker_options._assign_worker_devices()
  self.num_workers = self.worker_options.num_workers

  # Shared unsigned counter tracking how many workers finished sampling.
  self.sampling_completed_worker_count = mp.Value('I', lock=True)

  # Derive worker ranks from the current distributed context.
  self.worker_options._set_worker_ranks(get_context())

  # Per-worker bookkeeping, populated when workers are actually launched.
  self._task_queues = []
  self._workers = []
  self._barrier = None
  self._shutdown = False
  self._worker_seeds_ranges = self._get_worker_seeds_ranges()