in experiments/codes/utils/data.py [0:0]
def load_data_pyg(self, rule_world):
    """
    Load the graphs of one rule world as PyTorch Geometric data objects.

    For every mode key in ``self.graphs``, reads ``<rule_world>/<mode>.jsonl``
    (one JSON object per non-blank line, with ``"edges"``, ``"query"`` and
    ``"rules"`` entries) and, for each record:

    * appends a ``GeometricData`` graph to ``self.graphs[mode]`` with
      ``x`` = dense node ids (``num_nodes x 1``), ``edge_index`` (``2 x E``),
      ``edge_attr`` = relation ids from ``self.get_label2id`` and
      ``y`` = the relation id of the query edge;
    * appends the query's (src, dst) node ids to ``self.queries[mode]``,
      the query relation id to ``self.labels[mode]`` / ``self.label_set``,
      and ``len(record["rules"])`` to ``self.path_len[mode]``.

    If ``<rule_world>/meta_graph.jsonl`` exists (parsed as a single JSON
    document), it is turned into ``self.world_graph``: every meta edge
    ``(src, dst, rel)`` becomes an extra node wired ``src -> edge_node -> dst``.

    Finally ``self.queries`` and ``self.labels`` are converted to numpy arrays.

    :param rule_world: directory holding the ``<mode>.jsonl`` files and,
        optionally, ``meta_graph.jsonl``.
    :return: None; all results are stored on ``self``.
    """

    def _to_edge_index(pairs):
        # (2, num_edges) LongTensor; view(-1, 2) keeps it 2-D even when
        # ``pairs`` is empty, where it yields shape (2, 0).
        return torch.tensor(pairs, dtype=torch.long).view(-1, 2).t().contiguous()

    for mode in self.graphs:
        graph_file = os.path.join(rule_world, "{}.jsonl".format(mode))
        self.meta_info[graph_file] = []
        with open(graph_file, "r", encoding="utf-8") as fp:
            # Skip blank lines (e.g. a trailing newline) instead of crashing.
            records = [json.loads(line) for line in fp if line.strip()]
        for record in records:
            # Each query graph gets its own dense node-id space.
            node2id = {}
            edges = []
            edge_attr = []
            for src, dst, rel in record["edges"]:
                # setdefault hands out the next free id on first sight.
                src_id = node2id.setdefault(src, len(node2id))
                dst_id = node2id.setdefault(dst, len(node2id))
                edges.append([src_id, dst_id])
                edge_attr.append(self.get_label2id(rel))
            q_src, q_dst, q_rel = record["query"]
            self.queries[mode].append((node2id[q_src], node2id[q_dst]))
            query_label = self.get_label2id(q_rel)
            self.labels[mode].append(query_label)
            self.label_set.add(query_label)
            self.path_len[mode].append(len(record["rules"]))
            self.graphs[mode].append(
                GeometricData(
                    x=torch.arange(len(node2id)).unsqueeze(1),
                    edge_index=_to_edge_index(edges),
                    edge_attr=torch.tensor(edge_attr, dtype=torch.long),
                    y=torch.tensor([query_label]),
                )
            )

    # Optional world (meta) graph.
    meta_graph_file = os.path.join(rule_world, "meta_graph.jsonl")
    if os.path.exists(meta_graph_file):
        with open(meta_graph_file, "r", encoding="utf-8") as fp:
            meta_graph = json.loads(fp.read())
        elem_edges = meta_graph["edges"]
        # Fresh id space: do not inherit the map of the last query graph.
        node2id = {}
        for elem in elem_edges:
            node2id.setdefault(elem[0], len(node2id))
            node2id.setdefault(elem[1], len(node2id))
        num_nodes = len(node2id)
        num_total = num_nodes + len(elem_edges)
        # edge_mapping[rel_id][edge_node] == 1 marks the relation of each
        # edge node; shape (len(self.label2id), num_total).
        edge_mapping = torch.zeros(
            (len(self.label2id), num_total), dtype=torch.long
        )
        edges = []
        # 0 is reserved for ordinary nodes; edge nodes store rel_id + 1.
        edge_indicator = [0] * num_nodes
        for ei, elem in enumerate(elem_edges):
            edge_node = num_nodes + ei
            rel_id = self.get_label2id(elem[2])
            edges.append([node2id[elem[0]], edge_node])
            edges.append([edge_node, node2id[elem[1]]])
            edge_mapping[rel_id][edge_node] = 1
            edge_indicator.append(rel_id + 1)
        self.world_graph = GeometricData(
            x=torch.arange(num_total).unsqueeze(1),
            edge_index=_to_edge_index(edges),
            edge_indicator=torch.tensor(edge_indicator, dtype=torch.long),
            # 1 x len(self.label2id) x num_total
            edge_mapping=edge_mapping.unsqueeze(0),
        )

    for key in self.queries:
        self.queries[key] = np.asarray(self.queries[key])
    for key in self.labels:
        self.labels[key] = np.asarray(self.labels[key])