in genre/entity_linking.py [0:0]
def prefix_allowed_tokens_fn(batch_id, sent):
    """Return the continuation candidates for the decoded prefix ``sent``.

    The prefix is classified by ``get_status`` and routed to the matching
    ``get_trie_*`` helper. Each helper receives both the prefix and the
    source sentence selected by ``batch_id``.

    Args:
        batch_id: Index into ``sent_origs`` selecting the source sentence
            that belongs to this hypothesis.
        sent: Decoded prefix so far. It must expose ``tolist()``
            (presumably a 1-D token-id tensor; TODO confirm with caller).

    Returns:
        Whatever the selected ``get_trie_*`` helper returns (presumably the
        allowed next-token ids; verify against the helpers).

    Raises:
        RuntimeError: If ``get_status`` reports a status other than
            ``"o"``, ``"m"`` or ``"e"``.
    """
    sent = sent.tolist()
    status = get_status(sent)
    sent_orig = sent_origs[batch_id]

    # NOTE(review): the status letters presumably mean "o" = outside a
    # mention, "m" = inside a mention, "e" = inside an entity; confirm
    # against get_status.
    if status == "o":
        trie_out = get_trie_outside(sent, sent_orig)
    elif status == "m":
        trie_out = get_trie_mention(sent, sent_orig)
    elif status == "e":
        trie_out = get_trie_entity(sent, sent_orig)
        # When the entity helper hands back exactly the EOS code, fall back
        # to the outside-mention helper instead of returning EOS directly.
        if trie_out == codes["EOS"]:
            trie_out = get_trie_outside(sent, sent_orig)
    else:
        # Include the offending status and batch so decoding failures are
        # diagnosable; the exception type is unchanged for callers.
        raise RuntimeError(
            "Unexpected prefix status {!r} (batch_id={})".format(status, batch_id)
        )

    return trie_out