def sample_subgraph()

in hugegraph-ml/src/hugegraph_ml/models/seal.py [0:0]


    def sample_subgraph(self, target_nodes):
        """
        Build the labeled enclosing subgraph around a pair of target nodes.

        Starting from the target pair, the frontier is expanded ``self.hop``
        times through ``self.graph.out_edges``. A node subgraph is then built
        over every collected node, any edges linking the two targets (in
        either direction) are removed, and the output of
        ``drnl_node_labeling`` is stored under ``ndata["z"]``.

        Args:
            target_nodes(Tensor): Tensor of two target nodes
        Returns:
            subgraph(DGLGraph): subgraph
        """
        # Expand hop by hop, remembering every frontier (targets included).
        frontier = target_nodes
        visited = [target_nodes]
        for _ in range(self.hop):
            frontier = torch.unique(self.graph.out_edges(frontier)[1])
            visited.append(frontier)

        node_ids = torch.unique(torch.cat(visited))
        subgraph = dgl.node_subgraph(self.graph, node_ids)

        # The subgraph re-indexes its nodes; ndata[NID] holds the ids they had
        # in the parent graph, so look the targets up there to get local ids.
        parent_ids = subgraph.ndata[NID]

        def to_local_id(parent_id):
            # Expects exactly one match, which int() turns into a plain int.
            matches = torch.nonzero(parent_ids == int(parent_id), as_tuple=False)
            return int(matches)

        u_id = to_local_id(target_nodes[0])
        v_id = to_local_id(target_nodes[1])

        # remove link between target nodes in positive subgraphs
        # (u -> v first, then v -> u).
        for src, dst in ((u_id, v_id), (v_id, u_id)):
            if subgraph.has_edges_between(src, dst):
                edge_ids = subgraph.edge_ids(src, dst, return_uv=True)[2]
                subgraph.remove_edges(edge_ids)

        subgraph.ndata["z"] = drnl_node_labeling(subgraph, u_id, v_id)
        return subgraph