in graphlearn_torch/python/distributed/dist_sampling_producer.py [0:0]
def init(self):
  # Build a loader over seed positions so that sampling is driven by
  # batches of indices into the sampler input.
  index = torch.arange(len(self.sampler_input))
  self._index_loader = DataLoader(
    index,
    batch_size=self.sampling_config.batch_size,
    shuffle=self.sampling_config.shuffle,
    drop_last=self.sampling_config.drop_last
  )
  # Keep a handle on the underlying iterator; `_get_iterator()` is a
  # private DataLoader method.
  self._index_iter = self._index_loader._get_iterator()

  # Default the RPC thread count to the number of data partitions, capped at 16.
  if self.worker_options.num_rpc_threads is None:
    num_rpc_threads = min(self.data.num_partitions, 16)
  else:
    num_rpc_threads = self.worker_options.num_rpc_threads

  # Initialize the RPC agent used for cross-partition sampling requests.
  init_rpc(
    master_addr=self.worker_options.master_addr,
    master_port=self.worker_options.master_port,
    num_rpc_threads=num_rpc_threads,
    rpc_timeout=self.worker_options.rpc_timeout
  )

  # Create a collocated neighbor sampler (no message channel, concurrency 1)
  # on the local device and start its sampling loop.
  self._collocated_sampler = DistNeighborSampler(
    self.data, self.sampling_config.num_neighbors,
    self.sampling_config.with_edge, self.sampling_config.with_neg,
    self.sampling_config.with_weight,
    self.sampling_config.edge_dir, self.sampling_config.collect_features,
    channel=None, use_all2all=self.worker_options.use_all2all,
    concurrency=1, device=self.device,
    seed=self.sampling_config.seed
  )
  self._collocated_sampler.start_loop()
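
The seed batching above is plain PyTorch: an index tensor wrapped in a DataLoader yields 1-D tensor batches of seed positions, which the producer then hands to its neighbor sampler. A minimal, self-contained sketch of that pattern (the sizes and the print loop are illustrative, not from the source):

import torch
from torch.utils.data import DataLoader

num_seeds, batch_size = 10, 4     # illustrative values, not from the source
index = torch.arange(num_seeds)   # seed positions into the sampler input
index_loader = DataLoader(index, batch_size=batch_size, shuffle=True, drop_last=False)

for seed_batch in index_loader:
  # Each batch is a 1-D tensor of indices; the producer would pass such a
  # batch to DistNeighborSampler to sample around those seeds.
  print(seed_batch)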