def load_data_dgl()

in experiments/codes/utils/data.py [0:0]


    def load_data_dgl(self, rule_world):
        """
        Load every graph of a rule world as DGL graphs.

        For each mode key in ``self.graphs``, reads ``<rule_world>/<mode>/*.txt``.
        Each file whose name does not contain ``_query`` is a graph file
        ``<graph_id>.txt`` with one space-separated ``<src> <dst> <rel>`` triple
        per line. Its companion ``<graph_id>_query.txt`` holds the query
        ``<src> <dst> <label>`` on its first non-blank line.

        Results are appended to ``self.graphs[mode]`` (DGL graphs),
        ``self.queries[mode]`` ((src_id, dst_id) tuples of local node ids) and
        ``self.labels[mode]`` (int labels). Relation and query-label strings are
        added to ``self.label_set``. Afterwards, every per-mode list in those
        three dicts is converted to a numpy array.

        Node feature ``"q"``: float tensor of shape (num_nodes, 1). It holds
        1 for the query source, 2 for the query target and 0 otherwise.
        Edge feature ``"rel"``: long tensor of shape (num_edges, 1) holding
        the integer relation id. Both are placed on
        ``self.config.general.device``.

        :param rule_world: path of the rule-world directory
        :return: None (results are stored on ``self``)
        :raises FileNotFoundError: if a graph file has no companion query file
        :raises KeyError: if a query node does not occur in its graph file
        :raises ValueError: if a query file contains no non-blank line
        """
        print("Loading data")
        device = self.config.general.device

        def read_rows(path):
            # Return the space-split tokens of every non-blank line. Blank lines
            # (e.g. a trailing empty line) would otherwise raise IndexError.
            with open(path, "r") as fp:
                stripped = (raw.rstrip() for raw in fp)
                return [line.split(" ") for line in stripped if line]

        for mode in self.graphs:
            mode_folder = os.path.join(rule_world, mode)
            # Sort so the graph order (and hence dataset indices) is
            # reproducible; glob returns files in filesystem-dependent order.
            txt_files = sorted(glob.glob(os.path.join(mode_folder, "*.txt")))
            # Classify on the file name only. Testing the full path would drop
            # every file whenever a directory name contains "_query".
            graph_files = [
                path for path in txt_files if "_query" not in os.path.basename(path)
            ]
            for graph_file in graph_files:
                # Strip only the trailing ".txt". basename is separator-agnostic,
                # unlike splitting the path on "/".
                graph_id = os.path.basename(graph_file)[: -len(".txt")]

                # Node names are mapped to dense ids in order of first appearance.
                node2id = {}
                src_ids, dst_ids, rel_ids = [], [], []
                for row in read_rows(graph_file):
                    src_ids.append(node2id.setdefault(row[0], len(node2id)))
                    dst_ids.append(node2id.setdefault(row[1], len(node2id)))
                    rel_ids.append(int(row[2]))
                    self.label_set.add(row[2])

                query_file = os.path.join(mode_folder, "{}_query.txt".format(graph_id))
                query_rows = read_rows(query_file)
                if not query_rows:
                    raise ValueError("empty query file: {}".format(query_file))
                query = query_rows[0]
                src_id, dst_id = node2id[query[0]], node2id[query[1]]
                self.queries[mode].append((src_id, dst_id))
                self.labels[mode].append(int(query[2]))
                # add(), not update(): update() on a string would insert each
                # character ("10" -> {"1", "0"}) instead of the label itself.
                self.label_set.add(query[2])

                # 1 marks the query source and 2 the query target. If both ends
                # are the same node, the target flag wins.
                node_query_flags = torch.zeros(len(node2id))
                node_query_flags[src_id] = 1
                node_query_flags[dst_id] = 2

                g = dgl.DGLGraph()
                # Add all nodes and all edges in one call each (rather than one
                # call per element). The stored features are still a
                # (num_nodes, 1) float "q" and a (num_edges, 1) long "rel".
                g.add_nodes(
                    len(node2id),
                    data={"q": node_query_flags.view(-1, 1).to(device)},
                )
                rel = torch.tensor(rel_ids, dtype=torch.long, device=device).view(-1, 1)
                g.add_edges(src_ids, dst_ids, data={"rel": rel})
                self.graphs[mode].append(g)
            print("{} Data loaded : {} graphs".format(mode, len(graph_files)))

        # Freeze each per-mode list into a numpy array.
        for store in (self.graphs, self.queries, self.labels):
            for key in store:
                store[key] = np.asarray(store[key])