SubgraphStruct ParallelSampler::ppr_stochastic()

in para_graph_sampler/graph_engine/backend/ParallelSampler.cpp [603:650]


/**
 * Build a node-induced subgraph around `targets` by randomly drawing, for each
 * target, a subset of its pre-computed PPR neighbours with probability
 * governed by their PPR scores.
 *
 * Step 1 decides how many neighbours to draw per target: at most `k`, cut off
 *        right after the first neighbour whose score relative to
 *        `top_ppr_scores[t][1]` falls below `threshold`.
 * Step 2 performs weighted sampling without replacement over *all* entries of
 *        `top_ppr_scores[t]` (Efraimidis-Spirakis keys u^(1/w), keep the
 *        largest keys).
 * Step 3 hands the union of the drawn nodes to `_node_induced_subgraph`.
 *
 * Assumes `top_ppr_neighs[t]` and `top_ppr_scores[t]` are parallel sequences of
 * the same length -- TODO confirm against where they are populated.
 *
 * @param targets     target (root) nodes of the subgraph.
 * @param config      must contain "k" (integer) and "threshold" (floating
 *                    point); "add_self_edge" and "include_target_conn" are
 *                    read through `_extract_bool_config` with `false` passed
 *                    as the last argument (presumably the default).
 * @param config_aug  forwarded unchanged to `_node_induced_subgraph`.
 * @return whatever `_node_induced_subgraph` produces for the drawn nodes.
 * @throws std::out_of_range / std::invalid_argument from `config.at` /
 *         `std::stoi` / `std::stod` when a required key is missing or malformed.
 */
SubgraphStruct ParallelSampler::ppr_stochastic(
  std::vector<NodeType>& targets,
  std::unordered_map<std::string, std::string>& config,
  std::set<std::string>& config_aug
) {
  const int k = std::stoi(config.at("k"));
  // The narrowing to float is kept (it matches the original comparison precision) but made explicit.
  const float threshold = static_cast<float>(std::stod(config.at("threshold")));
  // determine the sample size (based on the threshold value)
  std::vector<int> sample_size;
  sample_size.reserve(targets.size());
  for (auto t : targets) {
    const auto size_neigh = std::min(static_cast<int64_t>(k), static_cast<int64_t>(top_ppr_neighs[t].size()));
    // option 1: keep the subg sizes consistent among TRAIN, VALID, TEST
    // The reference score is entry [1], not [0] -- presumably entry [0] is the
    // target itself; NOTE(review): confirm.
    PPRType max_ppr = size_neigh <= 1 ? 0 : top_ppr_scores[t][1];
    int cnt_target = 0;
    for (int64_t i = 0; i < size_neigh; i++) {
      // Count first, then test: the first neighbour below the threshold is
      // still included, and at least one node is kept whenever size_neigh >= 1.
      cnt_target ++;
      if (max_ppr == 0 || top_ppr_scores[t][i] / max_ppr < threshold) {
        break;
      }
    }
    sample_size.push_back(cnt_target);
    // option 2: this might lead to larger training graph sizes.
    // sample_size.push_back(size_neigh);    // we sample by dist anyways, so probably don't need the cut-off on threshold
  }
  // sample by the distribution defined by top_ppr_scores
  std::unordered_map<NodeType, PPRType> nodes_touched;
  int cnt_t = 0;
  for (auto t : targets) {
    // generate sample_size number of nodes according to the weight defined by top_ppr_scores[t]
    const auto& scores = top_ppr_scores[t];
    std::vector<std::pair<double, int>> sample_w_idx;    // (negated random key, element index)
    sample_w_idx.reserve(scores.size());
    int idx_sample = 0;
    for (auto s : scores) {
      // Uniform draw in the OPEN interval (0, 1).
      //  * The division must happen in floating point: `rand() / RAND_MAX` on
      //    ints is 0 for almost every draw, which made all keys equal and the
      //    "random" selection collapse to the first sample_size indices.
      //  * Excluding 1 keeps a zero-weight entry (1/s == inf) at key 0, i.e.
      //    it is picked last instead of first.
      // NOTE(review): rand() is not guaranteed to be thread-safe; if this
      // method runs concurrently, consider a per-thread random engine.
      const double u = (static_cast<double>(rand()) + 1.0) / (static_cast<double>(RAND_MAX) + 2.0);
      // Key is u^(1/w); negated so that the smallest pairs are the winners.
      sample_w_idx.push_back(std::make_pair(-std::pow(u, 1.0 / static_cast<double>(s)), idx_sample));
      idx_sample ++;
    }
    // Partition so the sample_size[cnt_t] smallest (i.e. best) keys come first; order inside is irrelevant.
    std::nth_element(sample_w_idx.begin(), sample_w_idx.begin() + sample_size[cnt_t], sample_w_idx.end());
    for (int i_sel = 0; i_sel < sample_size[cnt_t]; i_sel ++) {
      const auto& sel = sample_w_idx[i_sel];
      // A node reached from several targets keeps the score of the last target that drew it.
      nodes_touched[top_ppr_neighs[t][sel.second]] = scores[sel.second];
    }
    cnt_t ++;
  }
  auto add_self_edge = _extract_bool_config(config, "add_self_edge", false);
  auto include_target_conn = _extract_bool_config(config, "include_target_conn", false);
  auto ret = _node_induced_subgraph(nodes_touched, targets, add_self_edge, include_target_conn, config_aug);
  return ret;
}