in experiments/codes/utils/data.py [0:0]
def load_data_dgl(self, rule_world):
    """
    Load the graphs of one rule world into DGL graphs.

    For every mode key already present in ``self.graphs`` the folder
    ``<rule_world>/<mode>`` is scanned for ``*.txt`` files. Files whose
    name contains ``_query`` are treated as query files; every other file
    is a graph file whose lines are space separated
    ``<node> <node> <relation>`` triples. The matching query file
    ``<graph_id>_query.txt`` is read for its first line only, which is
    parsed as ``<node> <node> <label>``.

    Side effects (nothing is returned):

    * ``self.graphs[mode]`` receives one DGL graph per graph file. Each
      node carries a ``q`` feature (1 for the first query node, 2 for the
      second query node, 0 otherwise) and each edge a ``rel`` feature
      holding the integer relation id.
    * ``self.queries[mode]`` receives the ``(node_id, node_id)`` query pair.
    * ``self.labels[mode]`` receives the integer query label.
    * ``self.label_set`` receives the query label and every edge relation,
      both kept as the strings read from the files.

    Once all modes are loaded, every value of ``self.graphs``,
    ``self.queries`` and ``self.labels`` is converted with ``np.asarray``.

    :param rule_world: path of the rule world directory (one sub-folder
        per mode)
    :return: None
    """
    print("Loading data")
    device = self.config.general.device
    for mode in self.graphs:
        mode_folder = os.path.join(rule_world, mode)
        # Test only the file name: a "_query" fragment somewhere in the
        # parent directories must not hide every graph file.
        graph_files = [
            path
            for path in glob.glob(mode_folder + "/*.txt")
            if "_query" not in os.path.basename(path)
        ]
        for graph_file in graph_files:
            # basename copes with whichever separator glob returned.
            graph_id = os.path.basename(graph_file).split(".txt")[0]
            g = dgl.DGLGraph()
            node2id = {}
            edges = []
            with open(graph_file, "r") as fp:
                for line in fp:
                    if not line.strip():
                        # A blank (e.g. trailing) line carries no edge.
                        continue
                    elem = line.rstrip().split(" ")
                    # Node ids are handed out in order of first appearance.
                    src = node2id.setdefault(elem[0], len(node2id))
                    dst = node2id.setdefault(elem[1], len(node2id))
                    edges.append([src, dst, elem[2]])

            node_query_flags = torch.zeros(len(node2id))
            query_file = os.path.join(
                mode_folder, "{}_query.txt".format(graph_id)
            )
            with open(query_file, "r") as fp:
                lines = fp.readlines()
            # Only the first line of the query file is used.
            elem = lines[0].rstrip().split(" ")
            query_src = node2id[elem[0]]
            query_dst = node2id[elem[1]]
            self.queries[mode].append((query_src, query_dst))
            node_query_flags[query_src] = 1
            # Written second, so it wins when both query nodes coincide.
            node_query_flags[query_dst] = 2
            self.labels[mode].append(int(elem[2]))
            # add(), not update(): update() on a string would insert each
            # character separately ("12" -> "1" and "2").
            self.label_set.add(elem[2])

            # Nodes are added one by one so node i gets flag i.
            for nf in node_query_flags:
                qr = torch.zeros(1, 1, requires_grad=False, device=device)
                qr[0][0] = nf
                g.add_nodes(1, data={"q": qr})
            for src, dst, relation in edges:
                self.label_set.add(relation)
                # Built directly as an integer tensor, so the id never
                # passes through float32.
                rel = torch.tensor(
                    [[int(relation)]], dtype=torch.long, device=device
                )
                g.add_edge(src, dst, data={"rel": rel})
            self.graphs[mode].append(g)
        print("{} Data loaded : {} graphs".format(mode, len(graph_files)))

    # Only values are replaced, so iterating the dicts meanwhile is safe.
    for key in self.graphs:
        self.graphs[key] = np.asarray(self.graphs[key])
    for key in self.queries:
        self.queries[key] = np.asarray(self.queries[key])
    for key in self.labels:
        self.labels[key] = np.asarray(self.labels[key])