def validate()

in train_extractive_reader.py
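
Runs the reader over the dev set, selects the best span prediction for each top-docs threshold, scores those predictions with exact match against the gold answers, optionally writes all predictions to a file, and returns the EM for the largest threshold.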


    def validate(self):
        logger.info("Validation ...")
        cfg = self.cfg
        self.reader.eval()  # inference mode: disables dropout, etc.
        data_iterator = self.get_data_iterator(cfg.dev_files, cfg.train.dev_batch_size, False, shuffle=False)

        log_result_step = cfg.train.log_batch_step
        all_results = []

        eval_top_docs = cfg.eval_top_docs
        for i, samples_batch in enumerate(data_iterator.iterate_ds_data()):
            # Build a padded reader batch; no passage/span shuffling at evaluation time
            input = create_reader_input(
                self.tensorizer.get_pad_id(),
                samples_batch,
                cfg.passages_per_question_predict,
                cfg.encoder.sequence_length,
                cfg.max_n_answers,
                is_train=False,
                shuffle=False,
            )

            # Move the batch tensors to the target device and derive the attention mask from input_ids
            input = ReaderBatch(**move_to_device(input._asdict(), cfg.device))
            attn_mask = self.tensorizer.get_attn_mask(input.input_ids)

            # Forward pass without gradient tracking: per-token start/end logits plus a per-passage relevance score
            with torch.no_grad():
                start_logits, end_logits, relevance_logits = self.reader(input.input_ids, attn_mask)

            # Pick the best answer span for each top-docs threshold in eval_top_docs
            batch_predictions = self._get_best_prediction(
                start_logits,
                end_logits,
                relevance_logits,
                samples_batch,
                passage_thresholds=eval_top_docs,
            )

            all_results.extend(batch_predictions)

            if (i + 1) % log_result_step == 0:
                logger.info("Eval step: %d", i)

        # Exact-match hits per top-docs threshold: {threshold n -> [0/1 per question]}
        ems = defaultdict(list)

        for q_predictions in all_results:
            gold_answers = q_predictions.gold_answers
            span_predictions = q_predictions.predictions  # {top docs threshold -> SpanPrediction()}
            for (n, span_prediction) in span_predictions.items():
                # 1 if the predicted span exactly matches any gold answer, else 0
                em_hit = max([exact_match_score(span_prediction.prediction_text, ga) for ga in gold_answers])
                ems[n].append(em_hit)
        # Log EM per top-docs threshold; `em` ends up holding the score for the largest n
        em = 0
        for n in sorted(ems.keys()):
            em = np.mean(ems[n])
            logger.info("n=%d\tEM %.2f", n, em * 100)

        if cfg.prediction_results_file:
            self._save_predictions(cfg.prediction_results_file, all_results)

        return em
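
The per-threshold exact-match bookkeeping at the end of validate() can be checked in isolation. The sketch below mirrors only the defaultdict/np.mean aggregation; the exact_match() helper and the sample predictions are hypothetical stand-ins for DPR's exact_match_score and real reader output.

    from collections import defaultdict
    import numpy as np

    def exact_match(prediction: str, gold: str) -> bool:
        # Hypothetical stand-in for DPR's exact_match_score
        return prediction.strip().lower() == gold.strip().lower()

    # Hypothetical per-question results: gold answers plus the best span text
    # predicted for each top-docs threshold n
    results = [
        {"gold": ["Paris"], "spans": {10: "Paris", 50: "Paris"}},
        {"gold": ["1969", "July 1969"], "spans": {10: "1968", 50: "July 1969"}},
    ]

    ems = defaultdict(list)
    for q in results:
        for n, span_text in q["spans"].items():
            # a hit if the span matches any gold answer
            ems[n].append(max(exact_match(span_text, ga) for ga in q["gold"]))

    em = 0
    for n in sorted(ems):
        em = np.mean(ems[n])  # fraction of questions with an exact match at threshold n
        print("n=%d\tEM %.2f" % (n, em * 100))
    # `em` now holds the score for the largest threshold, matching validate()'s return value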