std::vector<SubgraphStructVec> ParallelSampler::parallel_sampler_ensemble(configs_samplers, configs_aug)

in para_graph_sampler/graph_engine/backend/ParallelSampler.cpp [662:704]


/**
 * Run every sampler of an ensemble on one shared batch of roots.
 *
 * The roots are drawn once via `_get_roots_p(num_roots)`. Each sampler config
 * (entry i of `configs_samplers`, paired with entry i of `configs_aug`) is then
 * applied to every root set in an OpenMP parallel loop. The result for root set
 * `p` is stored into `subgraphs_ensemble[i]` via `add_one_subgraph_vec(..., p)`,
 * and `num_valid_subg_cur_batch` of that member is incremented once per root set.
 *
 * @param configs_samplers one string->string config per ensemble member. Each
 *        must contain a "method" key ("khop", "ppr", "ppr_st" or "nodeIID") and
 *        a "num_roots" key. The "num_roots" values are expected to agree; this is
 *        checked with assert only. When "return_target_only" (read through
 *        _extract_bool_config, default false) is true, dummy_sampler is used
 *        instead of the configured method.
 * @param configs_aug one augmentation set per ensemble member, forwarded to the
 *        selected sampler. Must have the same length as configs_samplers.
 * @return a copy of `subgraphs_ensemble` after this batch has been added.
 *
 * A missing "method" key or an unrecognised method name is reported on stderr
 * and terminates the process with exit(1). This happens before any roots are
 * drawn or any subgraph is sampled.
 */
std::vector<SubgraphStructVec> ParallelSampler::parallel_sampler_ensemble(
     std::vector<std::unordered_map<std::string, std::string>> configs_samplers,
     std::vector<std::set<std::string>> configs_aug) {
  // Sampling routine chosen for one ensemble member. It is resolved once per
  // config, so the parallel loop does not repeat string comparisons for every
  // root set.
  enum class SamplerKind { khop, ppr, ppr_stochastic, node_iid, target_only };

  cleanup_history_subgraphs_ensemble();
  assert(configs_samplers.size() == subgraphs_ensemble.size());
  // configs_aug is indexed in lockstep with configs_samplers below.
  assert(configs_aug.size() == configs_samplers.size());

  // All ensemble members share one set of roots. Take num_roots from the first
  // config and require every other config to agree (debug builds only).
  const int num_roots = configs_samplers.empty()
      ? 0
      : std::stoi(configs_samplers.front().at("num_roots"));
  for (std::size_t i = 1; i < configs_samplers.size(); i++) {
    assert(num_roots == std::stoi(configs_samplers[i].at("num_roots")));
  }

  // Validate every config up front so that a bad entry aborts before any
  // sampling work is done.
  std::vector<SamplerKind> kinds;
  kinds.reserve(configs_samplers.size());
  for (auto& cfg : configs_samplers) {
    auto it_method = cfg.find("method");
    if (it_method == cfg.end()) {
      std::cerr << "[C++ parallel sampler]: need to have the 'method' key in the config" << std::endl;
      exit(1);
    }
    const std::string& method = it_method->second;
    if (_extract_bool_config(cfg, "return_target_only", false)) {
      // The method value is irrelevant when only the targets are returned.
      kinds.push_back(SamplerKind::target_only);
    } else if (method == "khop") {
      kinds.push_back(SamplerKind::khop);
    } else if (method == "ppr") {
      kinds.push_back(SamplerKind::ppr);
    } else if (method == "ppr_st") {
      kinds.push_back(SamplerKind::ppr_stochastic);
    } else if (method == "nodeIID") {
      kinds.push_back(SamplerKind::node_iid);
    } else {
      // Without this check the member would silently get empty subgraphs.
      std::cerr << "[C++ parallel sampler]: unknown sampler method '" << method << "'" << std::endl;
      exit(1);
    }
  }

  std::vector<std::vector<NodeType>> targets = _get_roots_p(num_roots);
  const int num_sampler_cur_batch = static_cast<int>(targets.size());
  for (std::size_t e = 0; e < configs_samplers.size(); e++) {
    auto& cfg = configs_samplers[e];
    auto& aug = configs_aug[e];
    auto& subgraphs_out = subgraphs_ensemble[e];
    const SamplerKind kind = kinds[e];
    #pragma omp parallel for
    for (int p = 0; p < num_sampler_cur_batch; p++) {
      SubgraphStruct subgraph_new;
      switch (kind) {
        case SamplerKind::khop:
          subgraph_new = khop(targets[p], cfg, aug);
          break;
        case SamplerKind::ppr:
          subgraph_new = ppr(targets[p], cfg, aug);
          break;
        case SamplerKind::ppr_stochastic:
          subgraph_new = ppr_stochastic(targets[p], cfg, aug);
          break;
        case SamplerKind::node_iid:
          subgraph_new = nodeIID(targets[p], aug);
          break;
        case SamplerKind::target_only:
          subgraph_new = dummy_sampler(targets[p]);
          break;
      }
      subgraphs_out.add_one_subgraph_vec(subgraph_new, p);
      // Shared counter across OpenMP threads.
      #pragma omp atomic
      subgraphs_out.num_valid_subg_cur_batch += 1;
    }
  }
  return subgraphs_ensemble;
}