in Models/exprsynth/nagdecoder.py [0:0]
def finalise_metadata(self, raw_metadata_list: List[Dict[str, Any]], final_metadata: Dict[str, Any]) -> None:
    """Fold the raw metadata records into the final decoder vocabularies.

    Combines the token counters, the per-kind literal counters and the
    grammar productions found in ``raw_metadata_list`` and stores the
    results in ``final_metadata`` under the keys ``eg_token_vocab``,
    ``eg_literal_vocabs``, ``eg_production_vocab``, ``eg_edge_label_vocab``
    and ``eg_production_num``. The resulting grammar and the literal
    vocabularies are then reported through ``self.train_log``.

    Args:
        raw_metadata_list: Records that each provide ``eg_token_counter``,
            ``eg_literal_counters`` (indexed by every entry of
            ``LITERAL_NONTERMINALS``) and ``eg_production_vocab`` (a mapping
            from a left-hand side to a collection of right-hand sides).
            Right-hand sides are presumably tuples of symbol strings, since
            they are hashed, sorted, enumerated and joined -- TODO confirm.
        final_metadata: Dictionary that is filled in place; nothing is
            returned.
    """
    # Step 1: accumulate the statistics of all raw metadata records.
    token_counts = Counter()
    literal_counts = {kind: Counter() for kind in LITERAL_NONTERMINALS}
    productions_by_lhs = defaultdict(set)
    for record in raw_metadata_list:
        token_counts += record['eg_token_counter']
        for kind in LITERAL_NONTERMINALS:
            literal_counts[kind] += record['eg_literal_counters'][kind]
        for lhs, rhs_options in record['eg_production_vocab'].items():
            productions_by_lhs[lhs].update(rhs_options)

    # Step 2: token vocabulary, capped by the configured size.
    token_vocab = Vocabulary.create_vocabulary(
        token_counts,
        max_size=self.hyperparameters['eg_token_vocab_size'])
    final_metadata['eg_token_vocab'] = token_vocab

    # Step 3: one vocabulary per literal kind, built with count_threshold=0.
    literal_vocabs = {}
    final_metadata["eg_literal_vocabs"] = literal_vocabs
    for kind in LITERAL_NONTERMINALS:
        literal_vocabs[kind] = Vocabulary.create_vocabulary(
            literal_counts[kind],
            count_threshold=0,
            max_size=self.hyperparameters['eg_literal_vocab_size'])

    # Step 4: number productions and edge labels. Both the left-hand sides
    # and their right-hand sides are visited in sorted order, so the ids do
    # not depend on the iteration order of the merged dict/sets.
    production_count = 0
    edge_label_count = 0
    eg_production_vocab = defaultdict(dict)
    # NOTE(review): only ints are stored here, keyed by
    # (production_id, symbol index); the ``dict`` default factory is kept
    # unchanged -- confirm whether a plain dict was intended.
    eg_edge_label_vocab = defaultdict(dict)
    for lhs in sorted(productions_by_lhs):
        token_vocab.add_or_get_id(lhs)
        for rhs in sorted(productions_by_lhs[lhs]):
            if rhs in eg_production_vocab[lhs]:
                production_id = eg_production_vocab[lhs][rhs]
            else:
                production_id = production_count
                eg_production_vocab[lhs][rhs] = production_id
                production_count += 1
            # Every symbol of the right-hand side is added to the token
            # vocabulary and gets its own edge label id.
            for symbol_position, symbol in enumerate(rhs):
                token_vocab.add_or_get_id(symbol)
                eg_edge_label_vocab[(production_id, symbol_position)] = edge_label_count
                edge_label_count += 1

    final_metadata["eg_production_vocab"] = eg_production_vocab
    final_metadata["eg_edge_label_vocab"] = eg_edge_label_vocab
    final_metadata['eg_production_num'] = production_count

    # Step 5: log the grammar (productions ordered by id within each
    # left-hand side) and the tokens of every literal vocabulary.
    self.train_log("Imputed grammar:")
    for lhs, productions in eg_production_vocab.items():
        for rhs in sorted(productions, key=productions.get):
            grammar_line = " %s -[%02i]-> %s" % (str(lhs), productions[rhs], " ".join(rhs))
            self.train_log(grammar_line)
    self.train_log("Known literals:")
    for kind in LITERAL_NONTERMINALS:
        known_tokens = sorted(literal_vocabs[kind].token_to_id.keys())
        self.train_log(" %s: %s" % (kind, known_tokens))