def prefix_allowed_tokens_fn()

in genre/entity_linking.py [0:0]


    def prefix_allowed_tokens_fn(batch_id, sent):
        """Pick the allowed-token lookup that matches the current decoding status.

        The generated prefix is converted to a plain list, classified by
        ``get_status``, and then handed (together with the original input
        for this batch element) to the matching ``get_trie_*`` helper.

        Args:
            batch_id: Index/key into ``sent_origs`` selecting the original
                input that corresponds to this hypothesis.
            sent: Generated prefix so far; must expose ``.tolist()``
                (presumably a tensor of token ids -- confirm with caller).

        Returns:
            Whatever the selected ``get_trie_*`` helper returns for the
            current prefix.

        Raises:
            RuntimeError: If ``get_status`` yields anything other than
                ``"o"``, ``"m"`` or ``"e"``.
        """
        tokens = sent.tolist()
        state = get_status(tokens)
        # Looked up before dispatching so a bad batch_id fails here,
        # regardless of which status was detected.
        source = sent_origs[batch_id]

        if state == "o":
            return get_trie_outside(tokens, source)

        if state == "m":
            return get_trie_mention(tokens, source)

        if state == "e":
            allowed = get_trie_entity(tokens, source)
            # When the entity lookup yields exactly the EOS code, fall back
            # to the "outside" lookup instead.
            if allowed == codes["EOS"]:
                allowed = get_trie_outside(tokens, source)
            return allowed

        raise RuntimeError