in seal_link_pred.py [0:0]
def process(self):
    """Extract and save the enclosing subgraphs for this split's target links.

    Builds a sparse adjacency matrix from ``self.data``. It then extracts
    enclosing subgraphs around the positive and negative target edges
    returned by ``get_pos_neg_edges``. Finally it collates all subgraphs and
    writes the result to ``self.processed_paths[0]``.

    Side effects:
        When ``self.use_coalesce`` is set, ``self.data.edge_index`` and
        ``self.data.edge_weight`` are replaced by their coalesced versions.
    """
    pos_edge, neg_edge = get_pos_neg_edges(self.split, self.split_edge,
                                           self.data.edge_index,
                                           self.data.num_nodes,
                                           self.percent)
    num_nodes = self.data.num_nodes
    # getattr with a default works whether or not the graph carries weights.
    edge_weight = getattr(self.data, 'edge_weight', None)

    if self.use_coalesce:
        # Compress multi-edges into single edges whose weight is the summed
        # weight of the duplicates. With value=None, coalesce would only
        # de-duplicate and silently drop the multiplicity. The uncoalesced
        # path keeps that multiplicity, because csr_matrix sums duplicate
        # entries. Starting from unit weights lets coalesce accumulate counts.
        if edge_weight is None:
            edge_weight = torch.ones(self.data.edge_index.size(1), dtype=int)
        edge_index, edge_weight = coalesce(
            self.data.edge_index, edge_weight, num_nodes, num_nodes)
        self.data.edge_index, self.data.edge_weight = edge_index, edge_weight

    if edge_weight is None:
        # Unweighted graph: every edge counts once.
        edge_weight = torch.ones(self.data.edge_index.size(1), dtype=int)
    edge_weight = edge_weight.view(-1)

    # scipy needs host arrays; .cpu() is a no-op for tensors already on CPU.
    edge_index = self.data.edge_index.cpu()
    A = ssp.csr_matrix(
        (edge_weight.cpu().numpy(),
         (edge_index[0].numpy(), edge_index[1].numpy())),
        shape=(num_nodes, num_nodes),
    )
    # A column-compressed copy is only built for directed graphs.
    # Presumably extract_enclosing_subgraphs uses it for in-neighbour
    # lookups (confirm).
    A_csc = A.tocsc() if self.directed else None

    # Extract enclosing subgraphs for the positive and negative target edges.
    # The 4th argument (1 / 0) is presumably the link label; confirm in
    # extract_enclosing_subgraphs.
    pos_list = extract_enclosing_subgraphs(
        pos_edge, A, self.data.x, 1, self.num_hops, self.node_label,
        self.ratio_per_hop, self.max_nodes_per_hop, self.directed, A_csc)
    neg_list = extract_enclosing_subgraphs(
        neg_edge, A, self.data.x, 0, self.num_hops, self.node_label,
        self.ratio_per_hop, self.max_nodes_per_hop, self.directed, A_csc)
    torch.save(self.collate(pos_list + neg_list), self.processed_paths[0])
    # Release the (potentially large) subgraph lists before returning.
    del pos_list, neg_list