def evaluate_dataset()

in datasets.py [0:0]


    def evaluate_dataset(self, data_args: DataTrainingArguments, model, device, batch_size: int, macro: bool = False) \
            -> Dict[str, float]:
        """
        Evaluate model on this dataset.

        For every (example, output sentence) pair yielded by ``self.generate_output_sentences`` the evaluation
        runs in two phases:

        1. Trigger prediction: the output sentence is parsed with ``self.output_format.run_inference`` and the
           predicted triggers are compared with the example's ground-truth triggers.
        2. Argument classification: for each *predicted* trigger, a deep copy of the example holding only that
           trigger is formatted as a new input, passed through ``model.generate``, and the decoded sentence is
           scored with ``self.evaluate_argument``.

        Both phases are scored twice: once on the full tuples and once with the first tuple element (the type)
        replaced by a constant, giving the ``*_notype`` metrics.

        Args:
            data_args: provides ``max_seq_length``, ``max_output_seq_length_eval`` and ``num_beams``.
            model: object exposing ``generate``; used for the argument phase.
            device: device the tokenized argument input is moved to before generation.
            batch_size: forwarded to ``self.generate_output_sentences``.
            macro: accepted for interface compatibility; not used by this method.

        Returns:
            Dictionary with precision, recall and F1 for relations and triggers, each in a typed and a
            ``_notype`` variant.
        """
        import logging  # function-scope import: keeps this method self-contained

        logger = logging.getLogger(__name__)

        def without_type(items):
            # Overwrite the first element (the type) with a constant so that set comparisons ignore it.
            return {('TYPE',) + tuple(item)[1:] for item in items}

        results = Counter()

        for example, trigger_output_sentence in self.generate_output_sentences(data_args, model, device, batch_size):
            # phase 1: trigger prediction
            predicted_triggers = self.output_format.run_inference(
                example,
                trigger_output_sentence,
                entity_types=self.entity_types,
                relation_types=self.relation_types,
            )[0]
            gt_triggers = set(trigger.to_tuple() for trigger in example.triggers)
            correct_triggers = predicted_triggers & gt_triggers

            # trigger tuple format: (type, start, end) -- resetting all types to the same as 'TYPE'
            correct_triggers_notype = without_type(predicted_triggers) & without_type(gt_triggers)

            # phase 2: argument classification
            all_gt_relations, all_predicted_relations, all_correct_relations = set(), set(), set()
            for trigger in predicted_triggers:
                # Find the entity-type key whose natural name matches the predicted trigger type (first match
                # wins). Without an explicit "not found" case the search would silently fall through to the
                # last key and evaluate arguments against the wrong trigger type.
                trigger_type = next(
                    (key for key in self.entity_types if self.entity_types[key].natural == trigger[0]),
                    None,
                )
                if trigger_type is None:
                    logger.warning(
                        'Skipping argument evaluation for trigger with unknown type %r', trigger[0]
                    )
                    continue

                # The copy keeps the original example intact while restricting it to a single trigger.
                example_argument_single_trigger = copy.deepcopy(example)
                example_argument_single_trigger.triggers = [
                    Entity(type=self.entity_types[trigger_type], start=trigger[1], end=trigger[2])]

                argument_input_format = INPUT_FORMATS[self.argument_input_format]()
                argument_output_format = OUTPUT_FORMATS[self.argument_output_format]()
                example_input = argument_input_format.format_input(example_argument_single_trigger, multitask=True,
                                                                   task_descriptor=ACE2005EventArgumentDataset.name)
                example_input_ids = self.tokenizer.batch_encode_plus(
                    [example_input],
                    max_length=data_args.max_seq_length,
                    return_tensors='pt',
                    padding='max_length',
                    truncation=True
                )
                argument_output = model.generate(
                    example_input_ids['input_ids'].to(device),
                    max_length=data_args.max_output_seq_length_eval,
                    num_beams=data_args.num_beams,
                )[0]  # only one sample
                argument_output_sentence = self.tokenizer.decode(argument_output, skip_special_tokens=True,
                                                                 clean_up_tokenization_spaces=False)

                gt_relations, predicted_relations, correct_relations = \
                    self.evaluate_argument(argument_output_format, example_argument_single_trigger, example,
                                           argument_output_sentence)
                all_gt_relations.update(gt_relations)
                all_predicted_relations.update(predicted_relations)
                all_correct_relations.update(correct_relations)

            all_correct_relations_notype = without_type(all_predicted_relations) & without_type(all_gt_relations)

            results += Counter({
                'num_sentences': 1,
                'gt_triggers': len(gt_triggers),
                'predicted_triggers': len(predicted_triggers),
                'correct_triggers': len(correct_triggers),
                'correct_triggers_notype': len(correct_triggers_notype),
                'predicted_relations': len(all_predicted_relations),
                'gt_relations': len(all_gt_relations),
                'correct_relations': len(all_correct_relations),
                'correct_relations_notype': len(all_correct_relations_notype)
            })

        # NOTE: the "notype" metrics reuse the typed predicted / ground-truth counts as denominators; only the
        # number of correct items is recomputed on the type-stripped tuples.
        trigger_precision, trigger_recall, trigger_f1 = get_precision_recall_f1(
            num_correct=results['correct_triggers'],
            num_predicted=results['predicted_triggers'],
            num_gt=results['gt_triggers'],
        )
        trigger_precision_notype, trigger_recall_notype, trigger_f1_notype = get_precision_recall_f1(
            num_correct=results['correct_triggers_notype'],
            num_predicted=results['predicted_triggers'],
            num_gt=results['gt_triggers'],
        )
        relation_precision, relation_recall, relation_f1 = get_precision_recall_f1(
            num_correct=results['correct_relations'],
            num_predicted=results['predicted_relations'],
            num_gt=results['gt_relations'],
        )
        relation_precision_notype, relation_recall_notype, relation_f1_notype = get_precision_recall_f1(
            num_correct=results['correct_relations_notype'],
            num_predicted=results['predicted_relations'],
            num_gt=results['gt_relations'],
        )

        return {
            'relation_precision': relation_precision,
            'relation_recall': relation_recall,
            'relation_f1': relation_f1,
            'relation_precision_notype': relation_precision_notype,
            'relation_recall_notype': relation_recall_notype,
            'relation_f1_notype': relation_f1_notype,
            'trigger_precision': trigger_precision,
            'trigger_recall': trigger_recall,
            'trigger_f1': trigger_f1,
            'trigger_precision_notype': trigger_precision_notype,
            'trigger_recall_notype': trigger_recall_notype,
            'trigger_f1_notype': trigger_f1_notype,
        }