in output_formats.py [0:0]
def run_inference(self, example: InputExample, output_sentence: str,
                  entity_types: Dict[str, EntityType] = None, relation_types: Dict[str, RelationType] = None) \
        -> Tuple[set, set, bool, bool, bool, bool]:
    """
    Extract predicted entities and relations from an output sentence.

    Only entities/relations whose type name (the ``natural`` attribute of the given
    entity/relation types) is known are kept. A type dictionary that is ``None`` is
    treated as "no valid types", so every predicted label of that kind is a label error.

    Returns a 6-tuple:
      - predicted entities: set of ``(entity_type, start, end)`` tuples, with ``start``/``end``
        taken as-is from ``self.parse_output_sentence``;
      - predicted relations: set of ``(relation_type, entity_tuple, related_entity_tuple)``,
        where ``entity_tuple`` is the entity carrying the relation tag;
      - wrong_reconstruction: flag returned by ``self.parse_output_sentence``;
      - label_error: at least one predicted entity or relation type does not exist;
      - entity_error: at least one relation refers to an entity name that was not predicted;
      - format_error: the augmented language format is invalid (unbalanced entity delimiters,
        missing/malformed entity type tag, or a relation tag that is not a pair).
    """
    label_error = False     # at least one non-existing entity or relation type was predicted
    entity_error = False    # at least one relation points to an entity that was not predicted

    # unbalanced entity delimiters mean that the augmented language format is invalid
    format_error = (
        output_sentence.count(self.BEGIN_ENTITY_TOKEN) != output_sentence.count(self.END_ENTITY_TOKEN)
    )

    # natural-language names of the valid types; both arguments default to None
    allowed_entity_types = set() if entity_types is None \
        else {entity_type.natural for entity_type in entity_types.values()}
    allowed_relation_types = set() if relation_types is None \
        else {relation_type.natural for relation_type in relation_types.values()}

    # parse output sentence
    raw_predicted_entities, wrong_reconstruction = self.parse_output_sentence(example, output_sentence)

    predicted_entities_by_name = defaultdict(list)
    predicted_entities = set()
    raw_predicted_relations = []

    # process and filter entities
    for entity_name, tags, start, end in raw_predicted_entities:
        # the first tag must hold exactly one element: the entity type
        if len(tags) == 0 or len(tags[0]) != 1:
            format_error = True
            continue

        entity_type = tags[0][0]

        if entity_type not in allowed_entity_types:
            # the predicted entity type does not exist
            label_error = True
            continue

        entity_tuple = (entity_type, start, end)
        predicted_entities.add(entity_tuple)
        predicted_entities_by_name[entity_name].append(entity_tuple)

        # the remaining tags describe relations as (relation type, related entity name)
        for tag in tags[1:]:
            if len(tag) != 2:
                # the relation tag has the wrong length
                format_error = True
                continue

            relation_type, related_entity = tag
            if relation_type in allowed_relation_types:
                raw_predicted_relations.append((relation_type, entity_tuple, related_entity))
            else:
                label_error = True

    # resolve related entity names only now, so that a relation may refer to an entity
    # that appears later in the sentence
    predicted_relations = set()

    for relation_type, entity_tuple, related_entity in raw_predicted_relations:
        if related_entity not in predicted_entities_by_name:
            # cannot find the related entity in the sentence
            entity_error = True
            continue

        # look for the closest instance of the related entity (there could be many of them)
        _, head_start, head_end = entity_tuple
        candidates = sorted(
            predicted_entities_by_name[related_entity],
            key=lambda x: min(abs(x[1] - head_end), abs(head_start - x[2])),
        )

        # take the closest candidate that does not duplicate an existing relation;
        # if every candidate is already used, nothing is added
        for candidate in candidates:
            relation = (relation_type, entity_tuple, candidate)
            if relation not in predicted_relations:
                predicted_relations.add(relation)
                break

    return predicted_entities, predicted_relations, wrong_reconstruction, label_error, entity_error, format_error