in utils.py [0:0]
def construct_pyg_graph(node_ids, adj, dists, node_features, y, node_label='drnl'):
    """Construct a pytorch_geometric graph from a scipy sparse adjacency matrix.

    Args:
        node_ids: Integer ids of the subgraph's nodes; stored on the result as
            ``node_id``.
        adj: scipy sparse adjacency matrix. ``adj.shape[0]`` is used as the
            node count, and every stored entry ``adj[u, v] = r`` becomes a
            directed edge ``(u, v)`` with weight ``r``.
        dists: Per-node distance values. Used directly by the ``'hop'`` and
            ``'zo'`` labelings, and for its length by the fallback branch.
        node_features: Node feature matrix, passed through as the first
            positional argument of ``Data``.
        y: Graph-level target; wrapped into a one-element tensor.
        node_label: Which node-labeling trick fills ``z``: ``'drnl'``,
            ``'hop'``, ``'zo'``, ``'de'``, ``'de+'`` or ``'degree'``. Any
            other value yields all-zero ``torch.long`` labels.

    Returns:
        A ``Data`` object with ``edge_index``, ``edge_weight``, ``y``, ``z``,
        ``node_id`` and ``num_nodes`` set.
    """
    u, v, r = ssp.find(adj)
    num_nodes = adj.shape[0]

    node_ids = torch.LongTensor(node_ids)
    edge_index = torch.stack(
        [torch.as_tensor(u, dtype=torch.long), torch.as_tensor(v, dtype=torch.long)], 0
    )
    # Convert the stored values straight to float. Going through a LongTensor
    # first would truncate fractional edge weights (e.g. 0.5 -> 0.0).
    edge_weight = torch.as_tensor(r, dtype=torch.float)
    y = torch.tensor([y])

    # Nodes 0 and 1 are passed to the labeling helpers; presumably they are
    # the source/target pair of the subgraph -- confirm against the caller.
    if node_label == 'drnl':  # DRNL
        z = drnl_node_labeling(adj, 0, 1)
    elif node_label == 'hop':  # minimum distance to src and dst
        z = torch.tensor(dists)
    elif node_label == 'zo':  # zero-one labeling trick
        z = (torch.tensor(dists) == 0).to(torch.long)
    elif node_label == 'de':  # distance encoding
        z = de_node_labeling(adj, 0, 1)
    elif node_label == 'de+':
        z = de_plus_node_labeling(adj, 0, 1)
    elif node_label == 'degree':  # this is technically not a valid labeling trick
        # Column sums of adj. For a scipy sparse *matrix* the sum is (1, N);
        # reshape(-1) flattens it to (N,) and also handles a 1-D result.
        z = torch.tensor(adj.sum(axis=0)).reshape(-1)
        # Limit the maximum label to 100, and cast so the labels are integer
        # like the 'zo' and fallback branches even when adj holds floats.
        z = z.clamp(max=100).to(torch.long)
    else:
        z = torch.zeros(len(dists), dtype=torch.long)

    data = Data(node_features, edge_index, edge_weight=edge_weight, y=y, z=z,
                node_id=node_ids, num_nodes=num_nodes)
    return data