in para_graph_sampler/graph_engine/frontend/samplers_ensemble.py [0:0]
def _extract_subgraph_return(self, ret_subg_struct, config_sampler, config_aug, validate_subg):
    """
    Convert the raw result of the C++ sampling backend into ``Subgraph`` objects.

    Only used with the C++ backend.

    Parameters
    ----------
    ret_subg_struct
        Backend result object. It must expose ``get_num_valid_subg()`` and
        one ``get_subgraph_<field>()`` getter for each field name in
        ``Subgraph.names_data_fields`` and in
        ``EntityEncoding.names_data_fields``.
    config_sampler : dict
        Sampler config. ``config_sampler['method']`` is read. When it is
        ``'ppr'``, the optional entry ``'k'`` is used to derive the subgraph
        node capacity.
    config_aug
        Container of requested augmentations. An ``EntityEncoding`` field
        ``n`` is fetched from the backend only if ``f"{n}s"`` is in
        ``config_aug``. Otherwise every subgraph gets an empty array for it.
    validate_subg : bool
        Forwarded as ``validate`` to the ``EntityEncoding`` and ``Subgraph``
        constructors.

    Returns
    -------
    list of Subgraph
        One ``Subgraph`` (with its ``entity_enc`` attached) per valid
        subgraph reported by the backend. Empty if there are none.

    Raises
    ------
    ValueError
        For PPR with ``k``, if the valid subgraphs do not all have the same
        number of target nodes, so no single node capacity can be derived.
    """
    num_valid = ret_subg_struct.get_num_valid_subg()
    if num_valid == 0:
        # Nothing to convert. Returning early also avoids deriving the PPR
        # capacity from an empty set of target sizes.
        return []

    def _fetch(field):
        # Call the backend getter for ``field``. Keep only its first
        # ``num_valid`` entries, each converted to a numpy array.
        raw = getattr(ret_subg_struct, f'get_subgraph_{field}')()
        return [np.asarray(d) for d in raw[:num_valid]]

    # One list per Subgraph field, each holding ``num_valid`` arrays.
    info = [_fetch(n) for n in Subgraph.names_data_fields]
    # Encoding fields are pulled from the backend only when the matching
    # augmentation was requested. Otherwise use per-subgraph empty placeholders.
    info_enc = [
        _fetch(n) if f"{n}s" in config_aug else [np.array([]) for _ in range(num_valid)]
        for n in EntityEncoding.names_data_fields
    ]

    if config_sampler['method'] == 'ppr' and 'k' in config_sampler:
        # Node capacity is k times the number of targets per subgraph
        # (presumably PPR keeps at most k nodes per target -- TODO confirm).
        # This requires every subgraph to have the same number of targets.
        targets = info[Subgraph.names_data_fields.index('target')]
        num_targets = {tnp.size for tnp in targets}
        if len(num_targets) != 1:
            raise ValueError(
                f"PPR subgraphs must all have the same number of targets, "
                f"got sizes {sorted(num_targets)}"
            )
        cap_node_subg = int(config_sampler['k']) * num_targets.pop()
    else:
        cap_node_subg = self.num_nodes_full
    # A subgraph cannot hold more edges than the full graph, nor more than
    # cap_node_subg**2 (every ordered pair of its nodes).
    cap_edge_subg = min(self.num_edges_full, cap_node_subg ** 2)

    # zip(*...) transposes "field -> per-subgraph values" into
    # "subgraph -> field values".
    enc_batch = [
        EntityEncoding(
            cap_node_subg=cap_node_subg,
            cap_edge_subg=cap_edge_subg,
            validate=validate_subg,
            **dict(zip(EntityEncoding.names_data_fields, enc_values))
        )
        for enc_values in zip(*info_enc)
    ]
    return [
        Subgraph(
            cap_node_full=self.num_nodes_full,
            cap_edge_full=self.num_edges_full,
            cap_node_subg=cap_node_subg,
            cap_edge_subg=cap_edge_subg,
            validate=validate_subg,
            entity_enc=enc_batch[i],
            **dict(zip(Subgraph.names_data_fields, subg_values))
        )
        for i, subg_values in enumerate(zip(*info))
    ]