in para_graph_sampler/graph_engine/backend/ParallelSampler.cpp [350:453]
/**
 * Build the node-induced subgraph of `graph_full` over the keys of `nodes_touched`.
 *
 * Subgraph node IDs are assigned in ascending order of original node ID. The result
 * holds the subgraph in CSR form (`indptr` / `indices` / `data`, all edge weights 1),
 * the original node IDs and their PPR values, the original edge ID of every subgraph
 * edge (-1 for an inserted self-loop) and, when `fix_target` is set, the subgraph IDs
 * of `targets`.
 *
 * Assumes every row of `graph_full.indices` is sorted ascending (required by the
 * binary search used to place self-loops).
 *
 * @param nodes_touched       subgraph nodes mapped to their PPR value. The PPR value is
 *                            only meaningful when there is a single target.
 * @param targets             original IDs of the target nodes; each must be a key of
 *                            `nodes_touched`.
 * @param include_self_conn   if true, insert a self-loop (origEdgeID = -1) for every node
 *                            that has none in the full graph, at the position that keeps
 *                            its neighbor list sorted.
 * @param include_target_conn if false, drop edges whose two endpoints are both targets.
 *                            Forced to true when there is exactly one target.
 *                            NOTE(review): a self-loop already present in the full graph
 *                            on a target node is treated like any other edge, so it is
 *                            dropped when this flag is false.
 * @param config_aug          optional augmentations. "hops" calls compute_hops(-1).
 *                            "drnls" fills `drnl` from compute_hops(0) and compute_hops(1)
 *                            (presumably hop distances to the first and second target;
 *                            TODO confirm). The two options are mutually exclusive.
 * @return the populated SubgraphStruct.
 * @throws std::out_of_range if `fix_target` is set and a target is not in `nodes_touched`.
 */
SubgraphStruct ParallelSampler::_node_induced_subgraph(std::unordered_map<NodeType, PPRType>& nodes_touched,
        std::vector<NodeType>& targets, bool include_self_conn, bool include_target_conn, std::set<std::string>& config_aug) {
    SubgraphStruct ret_subg_info;
    // With a single target there is no target-target edge that could be dropped.
    if (targets.size() == 1) {
        include_target_conn = true;
    }

    // 1. Collect the touched nodes sorted by original ID (potentially makes the python
    //    indexing faster). NOTE: the PPR value is useless if there is more than 1 target.
    std::vector<std::pair<NodeType, PPRType>> sorted_nodes(nodes_touched.begin(), nodes_touched.end());
    std::sort(sorted_nodes.begin(), sorted_nodes.end());
    for (const auto& vp : sorted_nodes) {
        ret_subg_info.origNodeID.push_back(vp.first);
        ret_subg_info.ppr.push_back(vp.second);
    }

    // 2. Map original node ID -> subgraph node ID (its rank in the sorted order).
    //    The key set is exactly the key set of nodes_touched.
    std::unordered_map<NodeType, NodeType> orig2subID;
    orig2subID.reserve(sorted_nodes.size());
    NodeType next_sub_id = 0;
    for (const auto& vp : sorted_nodes) {
        orig2subID.emplace(vp.first, next_sub_id++);
    }
    if (fix_target) {
        for (auto t : targets) {
            // at() instead of operator[]: a target outside the subgraph is a caller bug and
            // must not be silently mapped to subgraph node 0 (nor inserted into the map).
            ret_subg_info.target.push_back(orig2subID.at(t));
        }
    }

    // 3. Build the CSR neighbor lists. Each row end is the previous row end plus the
    //    number of edges kept for that node.
    ret_subg_info.indptr.push_back(0);
    NodeType v_sub = 0;  // subgraph ID of v; equals orig2subID[v] since origNodeID is in rank order
    for (auto v : ret_subg_info.origNodeID) {
        const auto idx_start = graph_full.indptr[v];
        const auto idx_end = graph_full.indptr[v + 1];

        // Locate where a missing self-loop must go so that the row stays sorted.
        bool insert_self = false;
        auto idx_self = idx_end;
        if (include_self_conn) {
            const auto row_begin = graph_full.indices.begin() + idx_start;
            const auto row_end = graph_full.indices.begin() + idx_end;
            const auto self_range = std::equal_range(row_begin, row_end, v);
            if (self_range.first == self_range.second) {
                insert_self = true;
                idx_self = idx_start + (self_range.first - row_begin);
            }
        }

        // Loop-invariant: whether v itself is a target.
        const bool v_is_target = std::find(targets.begin(), targets.end(), v) != targets.end();

        NodeType row_nnz = 0;
        // e == idx_end is visited only so that a self-loop belonging at the end of the row
        // can be emitted; no neighbor is read at that position (it may be out of bounds).
        for (auto e = idx_start; e <= idx_end; e++) {
            if (insert_self && e == idx_self) {
                ret_subg_info.indices.push_back(v_sub);
                ret_subg_info.origEdgeID.push_back(-1);
                ret_subg_info.data.push_back(1.);
                row_nnz++;
            }
            if (e == idx_end) {
                break;
            }
            const auto neigh = graph_full.indices[e];
            const auto it_neigh = orig2subID.find(neigh);
            if (it_neigh == orig2subID.end()) {
                continue;  // neighbor is not part of the subgraph
            }
            if (!include_target_conn && v_is_target &&
                    std::find(targets.begin(), targets.end(), neigh) != targets.end()) {
                continue;  // drop target-target edges
            }
            ret_subg_info.indices.push_back(it_neigh->second);
            ret_subg_info.origEdgeID.push_back(e);
            ret_subg_info.data.push_back(1.);
            row_nnz++;
        }
        ret_subg_info.indptr.push_back(ret_subg_info.indptr.back() + row_nnz);
        v_sub++;
    }

    // 4. Augmentation: hop counts.
    if (config_aug.find("hops") != config_aug.end()) {
        assert(config_aug.find("drnls") == config_aug.end());
        ret_subg_info.compute_hops(-1);
    }
    // 5. Augmentation: DRNL labels, built from two hop-count passes.
    if (config_aug.find("drnls") != config_aug.end()) {
        assert(config_aug.find("hops") == config_aug.end());
        std::vector<NodeType> dx;
        std::vector<NodeType> dy;
        ret_subg_info.compute_hops(0);
        dx.swap(ret_subg_info.hop);
        ret_subg_info.compute_hops(1);
        dy.swap(ret_subg_info.hop);
        auto num_subg_node = ret_subg_info.indptr.size() - 1;
        ret_subg_info.drnl.resize(num_subg_node);
        for (std::vector<NodeType>::size_type i = 0; i < dx.size(); i++) {
            ret_subg_info.drnl[i] = ret_subg_info.compute_drnl_single(dx[i], dy[i]);
        }
    }
    return ret_subg_info;
}