def _extract_subgraph_return()

in para_graph_sampler/graph_engine/frontend/samplers_ensemble.py [0:0]


    def _extract_subgraph_return(self, ret_subg_struct, config_sampler, config_aug, validate_subg):
        """
        Turn the struct returned by the sampler into a list of `Subgraph` objects.

        only for C++ backend

        Arguments:
            ret_subg_struct     object exposing `get_num_valid_subg()` plus one
                                `get_subgraph_<name>()` accessor for every name in
                                `Subgraph.names_data_fields` and
                                `EntityEncoding.names_data_fields`.
            config_sampler      mapping with a 'method' entry and, optionally, 'k'.
            config_aug          container tested for membership of "<name>s"; an
                                encoding whose plural name is absent is replaced by
                                empty arrays.
            validate_subg       forwarded as `validate` to both `EntityEncoding`
                                and `Subgraph`.

        Returns:
            list of `Subgraph`, one per tuple obtained by zipping the per-field
            lists, each paired with the `EntityEncoding` at the same position.
        """
        num_valid = ret_subg_struct.get_num_valid_subg()

        def _raw(field):
            # Call the `get_subgraph_<field>` accessor of the returned struct.
            return getattr(ret_subg_struct, f'get_subgraph_{field}')()

        def _as_arrays(raw):
            # Keep the first `num_valid` entries, each converted via np.asarray.
            return [np.asarray(item) for item in raw[:num_valid]]

        # One list of arrays per Subgraph data field, in declaration order.
        subg_columns = [_as_arrays(_raw(field)) for field in Subgraph.names_data_fields]

        # Same for the entity encodings. The accessor is always invoked first;
        # its result is only used when the encoding is listed in `config_aug`.
        enc_columns = []
        for field in EntityEncoding.names_data_fields:
            raw = _raw(field)
            column = (
                _as_arrays(raw)
                if f"{field}s" in config_aug
                else [np.array([]) for _ in range(num_valid)]
            )
            enc_columns.append(column)

        # Upper bound on subgraph size: for PPR with an explicit 'k' it is
        # k * (number of targets); otherwise it falls back to the full graph size.
        if config_sampler['method'] == 'ppr' and 'k' in config_sampler:
            per_target_budget = int(config_sampler['k'])
            target_column = subg_columns[Subgraph.names_data_fields.index('target')]
            num_targets = {targets.size for targets in target_column}
            # All returned subgraphs must have the same number of targets.
            assert len(num_targets) == 1
            cap_node_subg = per_target_budget * num_targets.pop()
        else:
            cap_node_subg = self.num_nodes_full
        cap_edge_subg = min(self.num_edges_full, cap_node_subg**2)

        # Regroup the per-field lists into one record per subgraph.
        enc_batch = []
        for enc_row in zip(*enc_columns):
            enc_fields = dict(zip(EntityEncoding.names_data_fields, enc_row))
            enc_batch.append(
                EntityEncoding(
                    cap_node_subg=cap_node_subg,
                    cap_edge_subg=cap_edge_subg,
                    validate=validate_subg,
                    **enc_fields
                )
            )

        subgraphs = []
        for idx, subg_row in enumerate(zip(*subg_columns)):
            subg_fields = dict(zip(Subgraph.names_data_fields, subg_row))
            subgraphs.append(
                Subgraph(
                    cap_node_full=self.num_nodes_full,
                    cap_edge_full=self.num_edges_full,
                    cap_node_subg=cap_node_subg,
                    cap_edge_subg=cap_edge_subg,
                    validate=validate_subg,
                    entity_enc=enc_batch[idx],
                    **subg_fields
                )
            )
        return subgraphs