in research/a2n/graph.py [0:0]
def read_graph(self, mode="train"):
    """Read the knowledge graph from ``self._raw_kg_file`` into this object.

    The file is parsed as tab-separated rows whose first three fields are
    ``e1``, ``relation`` and ``e2``; each field is stripped of surrounding
    whitespace before use.

    When ``mode == "train"`` every unseen entity or relation is assigned the
    next free id in ``self.entity_vocab`` / ``self.relation_vocab``. In any
    other mode a row that mentions an unseen entity or relation is skipped
    and counted; the count is logged once the file has been read.

    For every row that is kept this method updates:
      * ``self.kg_data[e1][e2]`` (list of relation ids) and
        ``self.next_edges[e1]`` (set of ``(relation, e2)`` pairs);
      * the same two structures in the ``e2 -> e1`` direction with the
        inverse relation id, when ``self.add_inverse_edge`` is set;
      * ``self.reverse_kg_data`` / ``self.reverse_next_edges``, when
        ``self.add_reverse_graph`` is set;
      * ``self._num_edges``, once for the row itself and once more when the
        inverse edge is added (reverse-graph entries are not counted).

    After reading, the pad tokens are added to the vocabularies (train mode
    only, and only if absent), ``self.ent_pad`` / ``self.rel_pad`` are set to
    their ids, and ``self.max_kg_relations`` is filled in with the largest
    number of outgoing relations of any node if it was not already set to a
    truthy value.

    Args:
      mode: "train" to grow the vocabularies while reading; any other value
        treats the vocabularies as fixed and skips rows with unknown symbols.

    Raises:
      IndexError: if a non-blank row has fewer fields than are needed at the
        point where the missing field is read.
    """
    logging.debug("Reading graph from %s", self._raw_kg_file)
    is_train = mode == "train"
    skipped = 0
    with open(self._raw_kg_file, "r") as f:
        for line in csv.reader(f, delimiter="\t"):
            if not line:
                # csv.reader yields an empty list for a blank line; there is
                # no tuple to read, so ignore it instead of failing on
                # line[0].
                continue

            # Head entity.
            e1 = line[0].strip()
            if e1 not in self.entity_vocab:
                if not is_train:
                    skipped += 1
                    continue
                self.entity_vocab[e1] = self.ent_vocab_size
                self.ent_vocab_size += 1
            e1 = self.entity_vocab[e1]

            # Relation.
            r = line[1].strip()
            if r not in self.relation_vocab:
                if not is_train:
                    skipped += 1
                    continue
                self.relation_vocab[r] = self.rel_vocab_size
                self.rel_vocab_size += 1
            if self.add_inverse_edge:
                # The inverse name must be built while `r` is still the
                # relation string, i.e. before it is replaced by its id.
                # NOTE(review): the inverse relation is added to the
                # vocabulary even when mode != "train" -- confirm intended.
                inv_r = self.inverse_relation_prefix + r
                if inv_r not in self.relation_vocab:
                    self.relation_vocab[inv_r] = self.rel_vocab_size
                    self.rel_vocab_size += 1
                inv_r = self.relation_vocab[inv_r]
            r = self.relation_vocab[r]

            # Tail entity.
            e2 = line[2].strip()
            if e2 not in self.entity_vocab:
                if not is_train:
                    skipped += 1
                    continue
                self.entity_vocab[e2] = self.ent_vocab_size
                self.ent_vocab_size += 1
            e2 = self.entity_vocab[e2]

            # Forward edge e1 --r--> e2.
            self.kg_data[e1].setdefault(e2, []).append(r)
            self.next_edges[e1].add((r, e2))

            if self.add_inverse_edge:
                # Inverse edge e2 --inv_r--> e1, counted as an extra edge.
                self.kg_data[e2].setdefault(e1, []).append(inv_r)
                self.next_edges[e2].add((inv_r, e1))
                self._num_edges += 1

            if self.add_reverse_graph:
                # Reverse index: from e2 back to e1 under the original
                # relation id.
                self.reverse_kg_data[e2].setdefault(e1, []).append(r)
                self.reverse_next_edges[e2].add((r, e1))

            self._num_edges += 1
    logging.info("Skipped %d tuples in mode %s", skipped, mode)

    # Pad tokens are only ever added in train mode; in other modes they are
    # just looked up below.
    if is_train and self.entity_pad_token not in self.entity_vocab:
        self.entity_vocab[self.entity_pad_token] = self.ent_vocab_size
        self.ent_vocab_size += 1
    if is_train and self.relation_pad_token not in self.relation_vocab:
        self.relation_vocab[self.relation_pad_token] = self.rel_vocab_size
        self.rel_vocab_size += 1
    self.ent_pad = self.entity_vocab[self.entity_pad_token]
    self.rel_pad = self.relation_vocab[self.relation_pad_token]

    if not self.max_kg_relations:
        # Largest total number of outgoing relation entries over all nodes.
        max_out = 0
        for neighbors in self.kg_data.values():
            n_out = sum(len(rels) for rels in neighbors.values())
            max_out = max(max_out, n_out)
        logging.info("Max outgoing rels kg: %d", max_out)
        self.max_kg_relations = max_out