in hugegraph-ml/src/hugegraph_ml/models/seal.py [0:0]
def sample_subgraph(self, target_nodes):
    """Extract the enclosing subgraph around a pair of target nodes.

    Collects every node reachable from either target by following out-edges
    of ``self.graph`` for at most ``self.hop`` steps. It then induces the
    subgraph on that node set and deletes any edges joining the two targets
    (in both directions). Finally it stores the result of
    ``drnl_node_labeling`` as the ``"z"`` node feature.

    Args:
        target_nodes(Tensor): Tensor of two target node IDs in
            ``self.graph`` (read as ``target_nodes[0]`` and ``target_nodes[1]``).
    Returns:
        subgraph(DGLGraph): induced subgraph with the target-pair edges
            removed and ``ndata["z"]`` populated.
    """
    # Grow the neighbourhood one hop at a time: each hop follows out-edges
    # from the previous hop's deduplicated destination nodes.
    hop_nodes = [target_nodes]
    current = target_nodes
    for _ in range(self.hop):
        _src, dst = self.graph.out_edges(current)
        current = torch.unique(dst)
        hop_nodes.append(current)
    node_pool = torch.unique(torch.cat(hop_nodes))
    subgraph = dgl.node_subgraph(self.graph, node_pool)

    # node_subgraph gives nodes new local IDs; find each target's local ID by
    # matching its original ID against ndata[NID] (presumably dgl.NID, the
    # parent-graph ID of each local node -- confirm against the imports).
    parent_ids = subgraph.ndata[NID]
    u_id, v_id = (
        int(torch.nonzero(parent_ids == int(node), as_tuple=False))
        for node in (target_nodes[0], target_nodes[1])
    )

    # Drop every edge that directly links the two targets, u->v first and
    # then v->u. Such edges exist in positive samples.
    for head, tail in ((u_id, v_id), (v_id, u_id)):
        if subgraph.has_edges_between(head, tail):
            subgraph.remove_edges(
                subgraph.edge_ids(head, tail, return_uv=True)[2]
            )

    subgraph.ndata["z"] = drnl_node_labeling(subgraph, u_id, v_id)
    return subgraph