def init()

in graphlearn_torch/python/distributed/dist_sampling_producer.py
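Initializes the collocated sampling producer: it builds a batched index loader over the sampler input, brings up the RPC layer for cross-partition requests, and starts a DistNeighborSampler in the current process.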


  def init(self):
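    # Build a position index (0..N-1) over the sampler input; sampling is driven by batches of these indices.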
    index = torch.arange(len(self.sampler_input))

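    # Batch the indices with a standard DataLoader, honoring the configured batch size, shuffle, and drop_last settings.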
    self._index_loader = DataLoader(
      index,
      batch_size=self.sampling_config.batch_size,
      shuffle=self.sampling_config.shuffle,
      drop_last=self.sampling_config.drop_last
    )
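    # Pull batches on demand through the loader's internal iterator.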
    self._index_iter = self._index_loader._get_iterator()

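    # Default the RPC thread pool to one thread per data partition, capped at 16.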
    if self.worker_options.num_rpc_threads is None:
      num_rpc_threads = min(self.data.num_partitions, 16)
    else:
      num_rpc_threads = self.worker_options.num_rpc_threads

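    # Bring up the RPC layer used to exchange sampling requests between partitions.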
    init_rpc(
      master_addr=self.worker_options.master_addr,
      master_port=self.worker_options.master_port,
      num_rpc_threads=num_rpc_threads,
      rpc_timeout=self.worker_options.rpc_timeout
    )

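    # channel=None keeps sampling collocated in this process: results are returned directly rather than pushed to a message channel.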
    self._collocated_sampler = DistNeighborSampler(
      self.data, self.sampling_config.num_neighbors,
      self.sampling_config.with_edge, self.sampling_config.with_neg,
      self.sampling_config.with_weight,
      self.sampling_config.edge_dir, self.sampling_config.collect_features,
      channel=None, use_all2all=self.worker_options.use_all2all,
      concurrency=1, device=self.device,
      seed=self.sampling_config.seed
    )
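    # Start the sampler's event loop so it is ready to serve sampling calls.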
    self._collocated_sampler.start_loop()
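
For reference, the index-batching pattern used above can be reproduced in isolation. The following is a minimal, self-contained sketch; the input length and batch size are made up for illustration and do not come from graphlearn-torch:

  import torch
  from torch.utils.data import DataLoader

  # Stand-in for len(self.sampler_input): 10 seed positions, batches of 4.
  index = torch.arange(10)
  loader = DataLoader(index, batch_size=4, shuffle=True, drop_last=False)

  for batch in loader:
    # Each batch is a 1-D LongTensor of shuffled seed positions, e.g. tensor([7, 2, 9, 0]).
    print(batch)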