SubgraphStruct ParallelSampler::_node_induced_subgraph()

in para_graph_sampler/graph_engine/backend/ParallelSampler.cpp [350:453]


/**
 * Extract the subgraph of `graph_full` induced by the nodes in `nodes_touched`.
 *
 * Subgraph nodes are numbered 0..N-1 in ascending order of their original node
 * id. The result is returned as a CSR structure (`indptr`, `indices`, `data`,
 * every edge weight set to 1) together with back-references into the full
 * graph (`origNodeID`, `origEdgeID`), the per-node `ppr` value and, when the
 * member flag `fix_target` is set, the subgraph ids of `targets`.
 *
 * @param nodes_touched        original node id -> PPR value for every node to keep.
 * @param targets              original ids of the target nodes. A target that is
 *                             missing from `nodes_touched` is silently mapped to
 *                             subgraph id 0 (default-inserted by `operator[]`).
 * @param include_self_conn    if true, a self-connection is added for every node
 *                             that does not already have one in `graph_full`;
 *                             such synthetic edges get `origEdgeID == -1`.
 * @param include_target_conn  if false, edges whose two endpoints are both in
 *                             `targets` are dropped. Forced to true when there
 *                             is exactly one target.
 * @param config_aug           augmentation names. "hops" and "drnls" are
 *                             recognised and asserted to be mutually exclusive.
 * @return the populated SubgraphStruct.
 *
 * NOTE: the self-connection lookup uses a binary search, so each adjacency
 * list of `graph_full` is assumed to be sorted in ascending order.
 */
SubgraphStruct ParallelSampler::_node_induced_subgraph(std::unordered_map<NodeType, PPRType>& nodes_touched,
    std::vector<NodeType>& targets, bool include_self_conn, bool include_target_conn, std::set<std::string>& config_aug) {
  SubgraphStruct ret_subg_info;
  if (targets.size() == 1) {
    // With a single target there is no target-target edge to filter out.
    include_target_conn = true;
  }
  // first traversal to sort orig ID (potentially makes the python indexing faster)
  // NOTE: PPR val for the node will be useless if we have more than 1 target
  std::vector<std::pair<NodeType, PPRType>> temp_origNodeID;
  temp_origNodeID.reserve(nodes_touched.size());
  for (const auto& vp : nodes_touched) {
    temp_origNodeID.emplace_back(vp.first, vp.second);
  }
  std::sort(temp_origNodeID.begin(), temp_origNodeID.end());
  // second traversal to record the sorted nodes and the orig ID -> subgraph ID mapping
  std::unordered_map<NodeType, NodeType> orig2subID;
  orig2subID.reserve(temp_origNodeID.size());
  NodeType cnt_subg_nodes = 0;
  for (const auto& vp : temp_origNodeID) {
    ret_subg_info.origNodeID.push_back(vp.first);
    ret_subg_info.ppr.push_back(vp.second);
    orig2subID[vp.first] = cnt_subg_nodes;
    cnt_subg_nodes ++;
  }
  if (fix_target) {
    for (auto t : targets) {
      ret_subg_info.target.push_back(orig2subID[t]);
    }
  }
  // third traversal to build the neighbor list
  // indptr[i+1] first holds the edge count of row i; it is turned into offsets afterwards.
  cnt_subg_nodes = 0;
  ret_subg_info.indptr.push_back(0);
  for (auto v : ret_subg_info.origNodeID) {
    const auto idx_start = graph_full.indptr[v];
    const auto idx_end = graph_full.indptr[v+1];
    const NodeType e_begin = idx_start;
    const NodeType e_end = idx_end;
    ret_subg_info.indptr.push_back(0);

    // Locate where a synthetic self-connection has to go to keep the row sorted.
    // An explicit flag is used instead of a negative sentinel so the logic does
    // not depend on the signedness of NodeType.
    bool insert_self = false;
    NodeType idx_insert = 0;
    if (include_self_conn) {
      const auto row_begin = graph_full.indices.begin() + idx_start;
      const auto row_end = graph_full.indices.begin() + idx_end;
      const auto it_self = std::lower_bound(row_begin, row_end, v);
      if (it_self == row_end || *it_self != v) {
        // v is not among its own neighbors yet
        insert_self = true;
        idx_insert = it_self - graph_full.indices.begin();
      }
    }
    auto emit_self_conn = [&]() {
      ret_subg_info.indices.push_back(orig2subID[v]);
      ret_subg_info.indptr[cnt_subg_nodes+1] ++;
      ret_subg_info.origEdgeID.push_back(-1);    // no counterpart in the full graph
      ret_subg_info.data.push_back(1.);
    };

    // Loop-invariant part of the target-target filter: if v is not a target
    // (or the filter is disabled), none of its edges can be dropped by it.
    const bool keep_regardless_of_neigh = include_target_conn ||
      std::find(targets.begin(), targets.end(), v) == targets.end();

    for (NodeType e = e_begin; e < e_end; e++) {
      if (insert_self && e == idx_insert) {
        emit_self_conn();
      }
      const auto neigh = graph_full.indices[e];
      if (nodes_touched.find(neigh) == nodes_touched.end()) {
        continue;     // neighbor is outside of the induced subgraph
      }
      if (!keep_regardless_of_neigh &&
          std::find(targets.begin(), targets.end(), neigh) != targets.end()) {
        continue;     // both endpoints are targets and such edges are excluded
      }
      ret_subg_info.indices.push_back(orig2subID[neigh]);
      ret_subg_info.indptr[cnt_subg_nodes+1] ++;
      ret_subg_info.origEdgeID.push_back(e);
      ret_subg_info.data.push_back(1.);
    }
    // The self-connection belongs after the last real neighbor (or the row is
    // empty). It is emitted here, outside of the loop, so that indices[e_end]
    // is never read: that slot belongs to the next row, or lies past the end
    // of the array for the last row.
    if (insert_self && idx_insert == e_end) {
      emit_self_conn();
    }
    cnt_subg_nodes ++;
  }
  // fix indptr for a valid CSR (per-row counts -> cumulative offsets)
  for (NodeType i = 0; i < cnt_subg_nodes; i++) {
    ret_subg_info.indptr[i+1] += ret_subg_info.indptr[i];
  }
  // augmentation
  if (config_aug.find("hops") != config_aug.end()) {
    assert(config_aug.find("drnls") == config_aug.end());
    // NOTE(review): -1 presumably means "w.r.t. all targets" -- confirm in compute_hops
    ret_subg_info.compute_hops(-1);
  }
  // compute drnl
  if (config_aug.find("drnls") != config_aug.end()) {
    assert(config_aug.find("hops") == config_aug.end());
    // Run compute_hops with arguments 0 and 1 (presumably the indices of the two
    // targets -- confirm), moving each result out of `hop` before the next call.
    std::vector<NodeType> dx;
    std::vector<NodeType> dy;
    ret_subg_info.compute_hops(0);
    dx.swap(ret_subg_info.hop);
    ret_subg_info.compute_hops(1);
    dy.swap(ret_subg_info.hop);
    auto num_subg_node = ret_subg_info.indptr.size() - 1;
    ret_subg_info.drnl.resize(num_subg_node);
    for (std::size_t i = 0; i < dx.size(); i++) {
      ret_subg_info.drnl[i] = ret_subg_info.compute_drnl_single(dx[i], dy[i]);
    }
  }
  return ret_subg_info;
}