def finalise_metadata()

in Models/exprsynth/nagdecoder.py [0:0]


    def finalise_metadata(self, raw_metadata_list: List[Dict[str, Any]], final_metadata: Dict[str, Any]) -> None:
        """Merge per-chunk raw metadata and build the final grammar/token vocabularies.

        Writes these keys into ``final_metadata`` (mutated in place):

        * ``'eg_token_vocab'``: ``Vocabulary`` built from the merged token counts with
          ``max_size=self.hyperparameters['eg_token_vocab_size']``. Afterwards every
          production lhs and rhs symbol is passed to ``add_or_get_id`` on it
          (presumably adding unseen symbols, so the vocab may exceed ``max_size`` --
          confirm against ``Vocabulary``).
        * ``'eg_literal_vocabs'``: one ``Vocabulary`` per kind in ``LITERAL_NONTERMINALS``,
          built with ``count_threshold=0`` and
          ``max_size=self.hyperparameters['eg_literal_vocab_size']``.
        * ``'eg_production_vocab'``: ``defaultdict`` mapping ``lhs -> {rhs: production_id}``.
          Ids are dense from 0 and assigned in sorted lhs order, then sorted rhs order.
        * ``'eg_edge_label_vocab'``: dict mapping ``(production_id, rhs_symbol_index)``
          to a dense edge-label id from 0, assigned in the same order.
        * ``'eg_production_num'``: total number of productions.

        Finally, the imputed grammar and the known literals are written via
        ``self.train_log``.

        Args:
            raw_metadata_list: Per-chunk metadata dicts. Each must provide
                ``'eg_token_counter'`` (a ``Counter``), ``'eg_literal_counters'``
                (a ``Counter`` per literal kind), and ``'eg_production_vocab'``
                (lhs -> iterable of rhs). Each rhs is a hashable sequence of symbol
                strings.
            final_metadata: Dict that receives the results listed above.
        """
        # First, merge all needed information across the raw metadata chunks.
        merged_token_counter = Counter()
        merged_literal_counters = {literal_kind: Counter() for literal_kind in LITERAL_NONTERMINALS}
        merged_production_vocab = defaultdict(set)
        for raw_metadata in raw_metadata_list:
            # Counter ``+=`` adds counts in place and drops non-positive results.
            merged_token_counter += raw_metadata['eg_token_counter']
            for literal_kind in LITERAL_NONTERMINALS:
                merged_literal_counters[literal_kind] += raw_metadata['eg_literal_counters'][literal_kind]
            for lhs, rhs_options in raw_metadata['eg_production_vocab'].items():
                merged_production_vocab[lhs].update(rhs_options)

        token_vocab = Vocabulary.create_vocabulary(merged_token_counter,
                                                   max_size=self.hyperparameters['eg_token_vocab_size'])
        final_metadata['eg_token_vocab'] = token_vocab

        final_metadata["eg_literal_vocabs"] = {
            literal_kind: Vocabulary.create_vocabulary(merged_literal_counters[literal_kind],
                                                       count_threshold=0,
                                                       max_size=self.hyperparameters['eg_literal_vocab_size'])
            for literal_kind in LITERAL_NONTERMINALS
        }

        # Assign dense ids in sorted order, so they do not depend on the order in
        # which the raw metadata chunks were produced. Every (lhs, rhs) pair is
        # unique here: lhs values are dict keys and rhs options are a set per lhs.
        # So each pair simply takes the next free production id.
        eg_production_vocab = defaultdict(dict)
        # A plain dict, not a defaultdict: values are ints, so a lookup of an
        # unknown key should raise instead of silently inserting an empty dict.
        eg_edge_label_vocab = {}
        next_production_id = 0
        next_edge_label_id = 0
        for lhs in sorted(merged_production_vocab):
            token_vocab.add_or_get_id(lhs)
            for rhs in sorted(merged_production_vocab[lhs]):
                production_id = next_production_id
                eg_production_vocab[lhs][rhs] = production_id
                next_production_id += 1
                # One edge label per (production, child position) pair.
                for rhs_symbol_index, symbol in enumerate(rhs):
                    token_vocab.add_or_get_id(symbol)
                    eg_edge_label_vocab[(production_id, rhs_symbol_index)] = next_edge_label_id
                    next_edge_label_id += 1

        final_metadata["eg_production_vocab"] = eg_production_vocab
        final_metadata["eg_edge_label_vocab"] = eg_edge_label_vocab
        final_metadata['eg_production_num'] = next_production_id

        self.train_log("Imputed grammar:")
        for lhs, rhs_options in eg_production_vocab.items():
            for rhs, idx in sorted(rhs_options.items(), key=lambda v: v[1]):
                self.train_log("  %s  -[%02i]->  %s" % (str(lhs), idx, " ".join(rhs)))
        self.train_log("Known literals:")
        for literal_kind in LITERAL_NONTERMINALS:
            self.train_log("  %s: %s" % (literal_kind, sorted(final_metadata['eg_literal_vocabs'][literal_kind].token_to_id.keys())))