in datasets.py [0:0]
def evaluate_dataset(self, data_args: DataTrainingArguments, model, device, batch_size: int, macro: bool = False) \
        -> Dict[str, float]:
    """
    Evaluate model on this dataset with a two-phase (trigger, then argument) pipeline.

    Phase 1 decodes the trigger output sentence of every example into predicted triggers.
    Phase 2 runs ``model.generate`` once per predicted trigger. Each run uses a copy of the
    example in which that trigger is the only trigger. The predicted arguments are then
    scored with ``self.evaluate_argument``. Counts are summed over all examples before
    precision/recall/F1 are computed (micro averaging).

    Args:
        data_args: supplies ``max_seq_length`` (argument input padding/truncation),
            ``max_output_seq_length_eval`` and ``num_beams`` (argument generation).
        model: model whose ``generate`` produces the argument output sequence.
        device: device the argument input ids are moved to before generation.
        batch_size: forwarded to ``self.generate_output_sentences`` for phase 1.
        macro: not used by this method; kept for interface compatibility.

    Returns:
        Dict with ``relation_{precision,recall,f1}``, ``trigger_{precision,recall,f1}``
        and their ``*_notype`` variants. The ``*_notype`` variants compare tuples with
        the type field replaced by a constant.
    """

    def strip_type(tuples):
        # Position 0 of each tuple holds the type (trigger tuples are (type, start, end));
        # replacing it with a constant makes comparison ignore the type.
        return {('TYPE',) + tuple(item[1:]) for item in tuples}

    trigger_output_format = self.output_format
    # Both format objects are built without arguments, so one instance of each serves
    # every trigger of every example.
    argument_input_format = INPUT_FORMATS[self.argument_input_format]()
    argument_output_format = OUTPUT_FORMATS[self.argument_output_format]()

    # Map natural-language type name -> key of self.entity_types. setdefault keeps the
    # first key when two entity types share a natural name (first-match semantics).
    natural_to_type_key = {}
    for type_key in self.entity_types:
        natural_to_type_key.setdefault(self.entity_types[type_key].natural, type_key)

    results = Counter()
    for example, trigger_output_sentence in self.generate_output_sentences(data_args, model, device, batch_size):
        # phase 1: trigger prediction
        predicted_triggers = trigger_output_format.run_inference(
            example,
            trigger_output_sentence,
            entity_types=self.entity_types,
            relation_types=self.relation_types,
        )[0]
        gt_triggers = set(trigger.to_tuple() for trigger in example.triggers)
        correct_triggers = predicted_triggers & gt_triggers
        correct_triggers_notype = strip_type(predicted_triggers) & strip_type(gt_triggers)

        # phase 2: argument classification, one generation per predicted trigger
        all_gt_relations, all_predicted_relations, all_correct_relations = set(), set(), set()
        # NOTE(review): gold relations are only collected from evaluate_argument(), which is
        # called once per predicted trigger. If no trigger is predicted for a sentence, its
        # gold arguments never reach 'gt_relations', which would overstate relation recall.
        # Confirm what evaluate_argument returns as gt_relations before relying on recall.
        for trigger in predicted_triggers:
            if trigger[0] not in natural_to_type_key:
                # A type with no matching entity type cannot be encoded as a trigger.
                # (A linear scan without a match would silently reuse the last iterated
                # key.) Skip argument prediction for it; the trigger itself is still
                # counted in the trigger metrics above.
                continue
            example_argument_single_trigger = copy.deepcopy(example)
            example_argument_single_trigger.triggers = [
                Entity(type=self.entity_types[natural_to_type_key[trigger[0]]], start=trigger[1], end=trigger[2])]

            example_input = argument_input_format.format_input(example_argument_single_trigger, multitask=True,
                                                               task_descriptor=ACE2005EventArgumentDataset.name)
            example_input_ids = self.tokenizer.batch_encode_plus(
                [example_input],
                max_length=data_args.max_seq_length,
                return_tensors='pt',
                padding='max_length',
                truncation=True
            )
            argument_output = model.generate(
                example_input_ids['input_ids'].to(device),
                max_length=data_args.max_output_seq_length_eval,
                num_beams=data_args.num_beams,
            )[0]  # only one sample
            argument_output_sentence = self.tokenizer.decode(argument_output, skip_special_tokens=True,
                                                             clean_up_tokenization_spaces=False)
            gt_relations, predicted_relations, correct_relations = \
                self.evaluate_argument(argument_output_format, example_argument_single_trigger, example,
                                       argument_output_sentence)
            all_gt_relations |= gt_relations
            all_predicted_relations |= predicted_relations
            all_correct_relations |= correct_relations

        all_correct_relations_notype = strip_type(all_predicted_relations) & strip_type(all_gt_relations)

        results += Counter({
            'num_sentences': 1,
            'gt_triggers': len(gt_triggers),
            'predicted_triggers': len(predicted_triggers),
            'correct_triggers': len(correct_triggers),
            'correct_triggers_notype': len(correct_triggers_notype),
            'predicted_relations': len(all_predicted_relations),
            'gt_relations': len(all_gt_relations),
            'correct_relations': len(all_correct_relations),
            'correct_relations_notype': len(all_correct_relations_notype)
        })

    # The *_notype scores reuse the typed predicted/gt counts as denominators.
    trigger_precision, trigger_recall, trigger_f1 = get_precision_recall_f1(
        num_correct=results['correct_triggers'],
        num_predicted=results['predicted_triggers'],
        num_gt=results['gt_triggers'],
    )
    trigger_precision_notype, trigger_recall_notype, trigger_f1_notype = get_precision_recall_f1(
        num_correct=results['correct_triggers_notype'],
        num_predicted=results['predicted_triggers'],
        num_gt=results['gt_triggers'],
    )
    relation_precision, relation_recall, relation_f1 = get_precision_recall_f1(
        num_correct=results['correct_relations'],
        num_predicted=results['predicted_relations'],
        num_gt=results['gt_relations'],
    )
    relation_precision_notype, relation_recall_notype, relation_f1_notype = get_precision_recall_f1(
        num_correct=results['correct_relations_notype'],
        num_predicted=results['predicted_relations'],
        num_gt=results['gt_relations'],
    )

    return {
        'relation_precision': relation_precision,
        'relation_recall': relation_recall,
        'relation_f1': relation_f1,
        'relation_precision_notype': relation_precision_notype,
        'relation_recall_notype': relation_recall_notype,
        'relation_f1_notype': relation_f1_notype,
        'trigger_precision': trigger_precision,
        'trigger_recall': trigger_recall,
        'trigger_f1': trigger_f1,
        'trigger_precision_notype': trigger_precision_notype,
        'trigger_recall_notype': trigger_recall_notype,
        'trigger_f1_notype': trigger_f1_notype,
    }