def load_data()

in graphlog/dataset.py [0:0]


    def load_data(self, rule_world: str) -> None:
        """
        Load the graph data of one rule world as Pytorch Geometric objects.

        For every mode that is a key of ``self.graphs`` the file
        ``<rule_world>/<mode>.jsonl`` is read (one JSON record per line). The
        raw records are stored in ``self.json_graphs[mode]``. For each record
        a ``GeometricData`` object is appended to ``self.graphs[mode]``, the
        query node pair to ``self.queries[mode]`` and the query label id to
        ``self.labels[mode]`` (the label id is also added to
        ``self.label_set``).

        If ``<rule_world>/meta_graph.jsonl`` exists, it is stored in
        ``self.json_meta_graph`` and converted into ``self.world_graph``. In
        that graph every meta-graph edge becomes an extra node placed between
        its two end points.

        Finally the per-mode query and label lists are converted to numpy
        arrays.

        :param rule_world: directory that contains the ``*.jsonl`` files
        :return: None
        :raises AssertionError: if a built edge index is not 2-dimensional
            (e.g. a graph without edges)
        """

        for mode in self.graphs:
            graph_file = os.path.join(rule_world, "{}.jsonl".format(mode))
            with open(graph_file, "r") as fp:
                graphs = [json.loads(line) for line in fp]
            self.json_graphs[mode] = graphs
            for gs in graphs:
                # Graph with Edge attributes; node names are re-indexed from 0
                # for every graph, in order of first appearance in the edges.
                node2id: Dict[str, int] = {}
                edges = []
                edge_attr = []
                for (src, dst, rel) in gs["edges"]:
                    if src not in node2id:
                        node2id[src] = len(node2id)
                    if dst not in node2id:
                        node2id[dst] = len(node2id)
                    edges.append([node2id[src], node2id[dst]])
                    edge_attr.append(self.get_label2id(rel))

                # The query end points must already occur in the edge list,
                # otherwise the lookups below raise KeyError.
                (src, dst, rel) = gs["query"]
                self.queries[mode].append((node2id[src], node2id[dst]))
                target = self.get_label2id(rel)
                self.labels[mode].append(target)
                self.label_set.add(target)
                x = torch.arange(len(node2id)).unsqueeze(1)

                # Transpose the list of [src, dst] pairs into 2 x num_edges.
                edge_index = torch.LongTensor(list(zip(*edges)))  # type: ignore
                # Raise explicitly (same exception type as a failing assert)
                # so that the check also runs under ``python -O``.
                if edge_index.dim() != 2:  # type: ignore
                    raise AssertionError("edge index dimension should be 2")
                geo_data = GeometricData(
                    x=x,
                    edge_index=edge_index,
                    edge_attr=torch.tensor(edge_attr),
                    y=torch.tensor([target]),
                )
                self.graphs[mode].append(geo_data)

        # load the meta graph
        meta_graph_file = os.path.join(rule_world, "meta_graph.jsonl")
        if os.path.exists(meta_graph_file):
            with open(meta_graph_file, "r") as fp:
                meta_graph = json.loads(fp.read())
            self.json_meta_graph = meta_graph
            elem_edges = meta_graph["edges"]

            # Use a mapping of its own for the meta graph. Reusing the mapping
            # of the last sample graph would add that graph's nodes to the
            # world graph (and fail with NameError if no sample was loaded).
            meta_node2id: Dict[str, int] = {}
            for elem in elem_edges:
                if elem[0] not in meta_node2id:
                    meta_node2id[elem[0]] = len(meta_node2id)
                if elem[1] not in meta_node2id:
                    meta_node2id[elem[1]] = len(meta_node2id)

            # Resolve the relation ids before sizing ``edge_mapping``.
            # NOTE(review): if get_label2id registers unseen labels, sizing
            # the mapping first could leave it too small - confirm.
            rel_ids = [self.get_label2id(elem[2]) for elem in elem_edges]

            num_nodes = len(meta_node2id)
            # Every meta-graph edge gets a node of its own after the real ones.
            num_elems = num_nodes + len(elem_edges)
            edge_mapping = torch.zeros(
                (len(self.label2id), num_elems), dtype=torch.long
            )
            edges = []
            edge_indicator = [0] * num_nodes
            for ei, (elem, rel_id) in enumerate(zip(elem_edges, rel_ids)):
                edge_node = num_nodes + ei
                edges.append([meta_node2id[elem[0]], edge_node])
                edges.append([edge_node, meta_node2id[elem[1]]])
                edge_mapping[rel_id][edge_node] = 1
                # NOTE: We are adding 1 to the edge indicator to keep the first position common for nodes
                edge_indicator.append(rel_id + 1)

            x = torch.arange(num_elems).unsqueeze(1)
            # 2 x (2 * number of meta-graph edges)
            edge_index = torch.LongTensor(list(zip(*edges)))  # type: ignore
            if edge_index.dim() != 2:  # type: ignore
                raise AssertionError("edge index dimension should be 2")
            # 1 x len(self.label2id) x (num_nodes + number of meta-graph edges)
            edge_mapping = edge_mapping.unsqueeze(0)
            self.world_graph = GeometricData(
                x=x,
                edge_index=edge_index,
                edge_indicator=torch.tensor(edge_indicator),
                edge_mapping=edge_mapping,
            )

        # Only the values are replaced, so iterating over the dicts is safe.
        for key in self.queries:
            self.queries[key] = np.asarray(self.queries[key])

        for key in self.labels:
            self.labels[key] = np.asarray(self.labels[key])