in output_formats.py [0:0]
def run_inference(self, example: InputExample, output_sentence: str,
                  entity_types: Dict[str, EntityType] = None) -> tuple:
    """
    Extract the intent and the typed entities from a generated sentence.

    The intent is the text found between ``self.BEGIN_INTENT_TOKEN`` and
    ``self.END_INTENT_TOKEN``; the text after the end marker is then handed
    to ``self.parse_output_sentence`` to obtain the entity spans.

    Args:
        example: input example, forwarded unchanged to
            ``self.parse_output_sentence``.
        output_sentence: sentence to parse.
        entity_types: mapping whose values expose a ``natural`` attribute
            naming a valid entity type. ``None`` is treated as an empty
            mapping, so every predicted entity is then a label error.

    Returns:
        A 5-tuple ``(intent, predicted_entities, wrong_reconstruction,
        label_error, format_error)``:

        * ``intent``: stripped text between the intent markers, or ``None``
          when the sentence does not contain both markers.
        * ``predicted_entities``: set of ``(entity_type, start, end)``
          tuples whose type is one of the known natural names.
        * ``wrong_reconstruction``: second value returned by
          ``self.parse_output_sentence``, passed through as is.
        * ``label_error``: ``True`` if an entity carried an unknown type.
        * ``format_error``: ``True`` if the entity brackets are unbalanced
          or an entity does not carry exactly one single-valued tag first.
    """
    # Natural-language names of the entity types accepted as valid labels.
    known_types = {
        entity_type.natural
        for entity_type in (entity_types or {}).values()
    }

    # Pad the intent markers with spaces so that they show up as standalone
    # tokens even when glued to neighbouring text. str.replace returns a new
    # string, so the result has to be kept. The padding is done on a scratch
    # copy: it is only needed to detect the markers, and the text handed to
    # the entity parser below stays free of extra whitespace.
    padded_sentence = output_sentence
    for special_token in (self.BEGIN_INTENT_TOKEN, self.END_INTENT_TOKEN):
        padded_sentence = padded_sentence.replace(
            special_token, ' ' + special_token + ' '
        )
    output_sentence_tokens = padded_sentence.split()

    # Stays None when the sentence carries no (complete) intent annotation.
    intent = None
    if self.BEGIN_INTENT_TOKEN in output_sentence_tokens and \
            self.END_INTENT_TOKEN in output_sentence_tokens:
        intent = output_sentence.split(self.BEGIN_INTENT_TOKEN)[1] \
            .split(self.END_INTENT_TOKEN)[0].strip()
        # Remove the intent from the sentence. maxsplit=1 keeps everything
        # after the first end marker, even if the marker occurs again.
        output_sentence = output_sentence.split(self.END_INTENT_TOKEN, 1)[1]

    label_error = False   # at least one entity has a non-existing type
    format_error = False  # the augmented language format is invalid

    if output_sentence.count(self.BEGIN_ENTITY_TOKEN) != \
            output_sentence.count(self.END_ENTITY_TOKEN):
        # the entity brackets do not match
        format_error = True

    # parse the remaining sentence into raw entity spans
    raw_predicted_entities, wrong_reconstruction = \
        self.parse_output_sentence(example, output_sentence)

    # process and filter entities
    predicted_entities = set()
    for _entity_name, tags, start, end in raw_predicted_entities:
        if len(tags) == 0 or len(tags[0]) > 1:
            # we do not have a (single) tag for the entity type
            format_error = True
            continue

        entity_type = tags[0][0]
        if entity_type in known_types:
            predicted_entities.add((entity_type, start, end))
        else:
            label_error = True

    return intent, predicted_entities, wrong_reconstruction, label_error, format_error