in datasets.py
def evaluate_example(self, example: InputExample, output_sentence: str, model=None, tokenizer=None) -> Counter:
    """
    Evaluate an output sentence on a single example of this dataset.

    Returns a Counter with per-example counts (ground-truth, predicted, and correct
    entities/relations, overall and per type) that can be summed across examples to
    compute micro- and macro-averaged scores.
    """
    # extract entities and relations from output sentence
    res = self.output_format.run_inference(
        example,
        output_sentence,
        entity_types=self.entity_types,
        relation_types=self.relation_types,
    )
    predicted_entities, predicted_relations = res[:2]

    if len(res) == 6:
        # the output format provides information about errors
        wrong_reconstruction, label_error, entity_error, format_error = res[2:]
    else:
        # in case the output format does not provide information about errors
        wrong_reconstruction = label_error = entity_error = format_error = False
    # untyped versions of the entity tuples: drop the type (first element) and keep only the span
    predicted_entities_no_type = {entity[1:] for entity in predicted_entities}

    # load ground truth entities
    gt_entities = {entity.to_tuple() for entity in example.entities}
    gt_entities_no_type = {entity[1:] for entity in gt_entities}

    # compute correct entities
    correct_entities = predicted_entities & gt_entities
    correct_entities_no_type = predicted_entities_no_type & gt_entities_no_type
    # load ground truth relations
    gt_relations = {relation.to_tuple() for relation in example.relations}

    # compute correct relations
    correct_relations = predicted_relations & gt_relations

    # sanity checks: the correct sets must be subsets of both the predicted and the ground-truth sets
    assert len(correct_entities) <= len(predicted_entities)
    assert len(correct_entities) <= len(gt_entities)
    assert len(correct_entities_no_type) <= len(predicted_entities_no_type)
    assert len(correct_entities_no_type) <= len(gt_entities_no_type)
    assert len(correct_relations) <= len(predicted_relations)
    assert len(correct_relations) <= len(gt_relations)
    res = Counter({
        'num_sentences': 1,
        'wrong_reconstructions': 1 if wrong_reconstruction else 0,
        'label_error': 1 if label_error else 0,
        'entity_error': 1 if entity_error else 0,
        'format_error': 1 if format_error else 0,
        'gt_entities': len(gt_entities),
        'predicted_entities': len(predicted_entities),
        'correct_entities': len(correct_entities),
        'gt_entities_no_type': len(gt_entities_no_type),
        'predicted_entities_no_type': len(predicted_entities_no_type),
        'correct_entities_no_type': len(correct_entities_no_type),
        'gt_relations': len(gt_relations),
        'predicted_relations': len(predicted_relations),
        'correct_relations': len(correct_relations),
    })
    # add information about each entity/relation type so that we can compute the macro-F1 scores
    if self.entity_types is not None:
        for entity_type in self.entity_types.values():
            predicted = {entity for entity in predicted_entities if entity[0] == entity_type.natural}
            gt = {entity for entity in gt_entities if entity[0] == entity_type.natural}
            correct = predicted & gt
            res['predicted_entities', entity_type.natural] = len(predicted)
            res['gt_entities', entity_type.natural] = len(gt)
            res['correct_entities', entity_type.natural] = len(correct)

    if self.relation_types is not None:
        for relation_type in self.relation_types.values():
            predicted = {relation for relation in predicted_relations if relation[0] == relation_type.natural}
            gt = {relation for relation in gt_relations if relation[0] == relation_type.natural}
            correct = predicted & gt
            res['predicted_relations', relation_type.natural] = len(predicted)
            res['gt_relations', relation_type.natural] = len(gt)
            res['correct_relations', relation_type.natural] = len(correct)

    return res
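
# ----------------------------------------------------------------------------
# Usage sketch (hypothetical, not part of datasets.py): the Counter returned by
# evaluate_example is additive, so per-example results can be summed over a
# dataset and converted into precision/recall/F1. The key names follow the ones
# used above; the helper and variable names below are assumptions made for
# illustration only.
from collections import Counter
from typing import Iterable, Tuple


def micro_scores(results: Counter, correct_key, predicted_key, gt_key) -> Tuple[float, float, float]:
    # micro precision/recall/F1 from aggregated counts stored under the given keys
    correct, predicted, gt = results[correct_key], results[predicted_key], results[gt_key]
    precision = correct / predicted if predicted else 0.0
    recall = correct / gt if gt else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


def macro_f1(results: Counter, type_names: Iterable[str], kind: str = 'entities') -> float:
    # macro-F1: average the per-type F1 scores collected under the tuple keys above
    f1s = [
        micro_scores(results, ('correct_' + kind, t), ('predicted_' + kind, t), ('gt_' + kind, t))[2]
        for t in type_names
    ]
    return sum(f1s) / len(f1s) if f1s else 0.0


# Example aggregation over a dataset (hypothetical names):
#   total = sum((dataset.evaluate_example(ex, out) for ex, out in predictions), Counter())
#   entity_p, entity_r, entity_f1 = micro_scores(total, 'correct_entities', 'predicted_entities', 'gt_entities')
#   relation_macro = macro_f1(total, [rt.natural for rt in dataset.relation_types.values()], kind='relations')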