in graphlog/dataset.py [0:0]
def load_data(self, rule_world: str) -> None:
    """
    Load every split's graphs (and the optional meta graph) as PyTorch Geometric data.

    For each split key ``mode`` in ``self.graphs``, reads ``<rule_world>/<mode>.jsonl``
    (one JSON graph per line) and stores the parsed dicts in ``self.json_graphs[mode]``.
    It then appends one ``GeometricData`` per graph to ``self.graphs[mode]``.
    Each graph's query node-id pair goes to ``self.queries[mode]`` and its query
    label id to ``self.labels[mode]``. The label id is also added to ``self.label_set``.

    If ``<rule_world>/meta_graph.jsonl`` exists, it is loaded once (after all splits)
    into ``self.json_meta_graph`` and ``self.world_graph``.

    Finally, every list in ``self.queries`` and ``self.labels`` is converted to a
    numpy array.

    :param rule_world: directory containing the ``<mode>.jsonl`` files.
    :return: None; all results are stored on ``self``.
    :raises KeyError: if a graph's query endpoint does not appear in its edges.
    """
    for mode in self.graphs:
        graph_file = os.path.join(rule_world, f"{mode}.jsonl")
        with open(graph_file, "r", encoding="utf-8") as fp:
            graphs = [json.loads(line) for line in fp]
        self.json_graphs[mode] = graphs
        for gs in graphs:
            self.graphs[mode].append(self._graph_to_geometric(gs, mode))
    # The meta graph does not depend on the split, so it is loaded a single time.
    meta_graph_file = os.path.join(rule_world, "meta_graph.jsonl")
    if os.path.exists(meta_graph_file):
        self._load_meta_graph(meta_graph_file)
    for key in self.queries:
        self.queries[key] = np.asarray(self.queries[key])
    for key in self.labels:
        self.labels[key] = np.asarray(self.labels[key])


def _graph_to_geometric(self, gs: dict, mode) -> "GeometricData":
    """
    Convert one parsed JSON graph into ``GeometricData`` and record its query.

    Node names are mapped to consecutive integer ids in first-seen order over
    ``gs["edges"]``. Node features ``x`` are those ids as an ``N x 1`` tensor.
    ``edge_attr`` holds the ``get_label2id`` id of each edge's relation. ``y`` holds
    the query relation's id.

    Side effects: appends ``(src_id, dst_id)`` of the query to ``self.queries[mode]``,
    appends its label id to ``self.labels[mode]``, and adds that id to
    ``self.label_set``.

    :param gs: graph dict with ``"edges"`` (iterable of ``(src, dst, rel)``) and
        ``"query"`` (a single ``(src, dst, rel)``).
    :param mode: split key used to index ``self.queries`` / ``self.labels``.
    :return: the graph as ``GeometricData``.
    :raises KeyError: if a query endpoint never appears in ``gs["edges"]``.
    """
    node2id: Dict[str, int] = {}
    edges = []
    edge_attr = []
    for src, dst, rel in gs["edges"]:
        # setdefault hands out the next free id only the first time a node is seen.
        src_id = node2id.setdefault(src, len(node2id))
        dst_id = node2id.setdefault(dst, len(node2id))
        edges.append([src_id, dst_id])
        edge_attr.append(self.get_label2id(rel))
    src, dst, rel = gs["query"]
    self.queries[mode].append((node2id[src], node2id[dst]))
    target = self.get_label2id(rel)
    self.labels[mode].append(target)
    self.label_set.add(target)
    # view(-1, 2).t() always yields a 2 x num_edges tensor, even for zero edges.
    edge_index = torch.tensor(edges, dtype=torch.long).view(-1, 2).t().contiguous()
    return GeometricData(
        x=torch.arange(len(node2id)).unsqueeze(1),
        edge_index=edge_index,
        edge_attr=torch.tensor(edge_attr),
        y=torch.tensor([target]),
    )


def _load_meta_graph(self, meta_graph_file: str) -> None:
    """
    Load the meta graph into ``self.json_meta_graph`` and ``self.world_graph``.

    Every meta edge ``(u, v, rel)`` becomes an extra "edge node" ``e`` wired as
    ``u -> e -> v``. The world graph therefore has ``num_nodes + num_meta_edges``
    node slots, and ``x`` is their ids as a column tensor.

    - ``edge_indicator`` is 0 for ordinary nodes and ``label_id + 1`` for edge nodes.
    - ``edge_mapping`` is a ``1 x len(self.label2id) x num_slots`` long tensor with a 1
      at ``[0, label_id, edge_node]`` for each meta edge.

    :param meta_graph_file: path to a file holding a single JSON object with an
        ``"edges"`` list.
    """
    with open(meta_graph_file, "r", encoding="utf-8") as fp:
        meta_graph = json.loads(fp.read())
    self.json_meta_graph = meta_graph
    elem_edges = meta_graph["edges"]
    # Fresh mapping: the meta graph must not inherit node ids from the last
    # per-split graph processed by load_data.
    node2id: Dict[str, int] = {}
    for elem in elem_edges:
        node2id.setdefault(elem[0], len(node2id))
        node2id.setdefault(elem[1], len(node2id))
    num_nodes = len(node2id)
    num_slots = num_nodes + len(elem_edges)
    # Resolve relation ids before sizing edge_mapping, in case get_label2id registers
    # previously unseen labels in self.label2id (NOTE(review): confirm it does).
    rel_ids = [self.get_label2id(elem[2]) for elem in elem_edges]
    edge_mapping = torch.zeros((len(self.label2id), num_slots), dtype=torch.long)
    edges = []
    for ei, (elem, rel_id) in enumerate(zip(elem_edges, rel_ids)):
        edge_node = num_nodes + ei
        edges.append([node2id[elem[0]], edge_node])
        edges.append([edge_node, node2id[elem[1]]])
        edge_mapping[rel_id][edge_node] = 1
    # Labels are offset by 1 so that indicator 0 stays reserved for ordinary nodes.
    edge_indicator = [0] * num_nodes + [rel_id + 1 for rel_id in rel_ids]
    edge_index = torch.tensor(edges, dtype=torch.long).view(-1, 2).t().contiguous()
    self.world_graph = GeometricData(
        x=torch.arange(num_slots).unsqueeze(1),
        edge_index=edge_index,
        edge_indicator=torch.tensor(edge_indicator, dtype=torch.long),
        edge_mapping=edge_mapping.unsqueeze(0),  # 1 x num_labels x num_slots
    )