def load_data_pyg()

in experiments/codes/utils/data.py [0:0]


    def load_data_pyg(self, rule_world):
        """
        Load every split's graphs, plus the optional world graph, as
        PyTorch Geometric ``GeometricData`` objects.

        For each key ``mode`` of ``self.graphs``, reads
        ``<rule_world>/<mode>.jsonl``. That file holds one JSON graph record per
        line, with ``"edges"`` as ``[src, dst, rel]`` triples, a ``"query"``
        triple and a ``"rules"`` entry (only its length is used). Each record is
        converted and appended to ``self.graphs[mode]``. The query, label and
        rule count are recorded in ``self.queries[mode]``, ``self.labels[mode]``
        and ``self.path_len[mode]``.

        If ``<rule_world>/meta_graph.jsonl`` exists, ``self.world_graph`` is
        built from it. Finally, the per-mode query and label lists are
        converted to numpy arrays.

        :param rule_world: directory holding the ``<mode>.jsonl`` files and,
            optionally, ``meta_graph.jsonl``.
        :return: None; results are stored on ``self``.
        """
        for mode in self.graphs:
            graph_file = os.path.join(rule_world, "{}.jsonl".format(mode))
            self.meta_info[graph_file] = []
            with open(graph_file, "r", encoding="utf-8") as fp:
                # Skip blank lines (e.g. a trailing empty line), which are not valid JSON.
                records = [json.loads(line) for line in fp if line.strip()]
            for record in records:
                self.graphs[mode].append(self._build_pyg_graph(record, mode))

        self._load_world_graph_pyg(rule_world)

        for key in self.queries:
            self.queries[key] = np.asarray(self.queries[key])
        for key in self.labels:
            self.labels[key] = np.asarray(self.labels[key])

    @staticmethod
    def _edge_index_from_pairs(pairs):
        """Stack ``[src, dst]`` pairs into a contiguous ``2 x num_edges`` LongTensor.

        An empty ``pairs`` list yields a ``2 x 0`` tensor rather than a 1-D one.
        """
        return torch.tensor(pairs, dtype=torch.long).reshape(-1, 2).t().contiguous()

    def _build_pyg_graph(self, record, mode):
        """Convert one graph record into a ``GeometricData`` object.

        Nodes are re-indexed ``0..n-1`` in order of first appearance in
        ``record["edges"]``. The result's fields are:

        * ``x``: those indices as an ``n x 1`` tensor.
        * ``edge_attr``: the label id of each edge relation.
        * ``y``: the label id of the query relation.

        Side effects: appends the query (as re-indexed node ids), its label
        and ``len(record["rules"])`` to ``self.queries[mode]``,
        ``self.labels[mode]`` and ``self.path_len[mode]``, and adds the label
        to ``self.label_set``.
        """
        node2id = {}
        edges = []
        edge_attr = []
        for src, dst, rel in record["edges"]:
            # setdefault hands out the next free id only on a node's first appearance.
            src_id = node2id.setdefault(src, len(node2id))
            dst_id = node2id.setdefault(dst, len(node2id))
            edges.append([src_id, dst_id])
            edge_attr.append(self.get_label2id(rel))

        src, dst, rel = record["query"]
        self.queries[mode].append((node2id[src], node2id[dst]))
        target = self.get_label2id(rel)
        self.labels[mode].append(target)
        self.label_set.add(target)
        self.path_len[mode].append(len(record["rules"]))

        return GeometricData(
            x=torch.arange(len(node2id)).unsqueeze(1),
            edge_index=self._edge_index_from_pairs(edges),
            edge_attr=torch.tensor(edge_attr),
            y=torch.tensor([target]),
        )

    def _load_world_graph_pyg(self, rule_world):
        """Build ``self.world_graph`` from ``<rule_world>/meta_graph.jsonl``, if present.

        The file holds a single JSON object whose ``"edges"`` are
        ``[src, dst, rel]`` entries. Each edge becomes an extra node wired
        ``src -> edge_node -> dst``, so the graph has ``num_nodes + num_edges``
        nodes. The stored fields are:

        * ``edge_indicator``: 0 for original nodes and ``label_id + 1`` for
          edge nodes.
        * ``edge_mapping``: a ``1 x num_labels x (num_nodes + num_edges)``
          one-hot tensor marking each edge node's relation.

        ``self.world_graph`` is left untouched when the file does not exist.
        """
        meta_graph_file = os.path.join(rule_world, "meta_graph.jsonl")
        if not os.path.exists(meta_graph_file):
            return
        with open(meta_graph_file, "r", encoding="utf-8") as fp:
            meta_graph = json.load(fp)
        elem_edges = meta_graph["edges"]

        # The world graph gets its own node index. It must not reuse the
        # node2id of the last per-split graph: that would offset every id here
        # and fail outright when no split graph was loaded.
        node2id = {}
        for elem in elem_edges:
            node2id.setdefault(elem[0], len(node2id))
            node2id.setdefault(elem[1], len(node2id))
        num_nodes = len(node2id)
        total_nodes = num_nodes + len(elem_edges)

        # Resolve relation ids before sizing edge_mapping, so its row count
        # covers every id returned here.
        label_ids = [self.get_label2id(elem[2]) for elem in elem_edges]
        edge_mapping = torch.zeros((len(self.label2id), total_nodes), dtype=torch.long)

        edges = []
        edge_indicator = [0] * num_nodes
        for ei, (elem, label_id) in enumerate(zip(elem_edges, label_ids)):
            edge_node = num_nodes + ei
            edges.append([node2id[elem[0]], edge_node])
            edges.append([edge_node, node2id[elem[1]]])
            edge_mapping[label_id][edge_node] = 1
            # +1 keeps indicator value 0 reserved for the original (non-edge) nodes.
            edge_indicator.append(label_id + 1)

        self.world_graph = GeometricData(
            x=torch.arange(total_nodes).unsqueeze(1),
            edge_index=self._edge_index_from_pairs(edges),
            edge_indicator=torch.tensor(edge_indicator, dtype=torch.long),
            edge_mapping=edge_mapping.unsqueeze(0),  # 1 x num_labels x total_nodes
        )