# from train_extractive_reader.py
def validate(self):
    """Evaluate the extractive reader on the dev set.

    Iterates over dev batches, collects span predictions at each configured
    top-docs threshold, logs the exact-match (EM) score per threshold, and
    optionally saves all predictions to ``cfg.prediction_results_file``.

    Returns:
        EM score (fraction in [0, 1]) at the largest top-docs threshold
        evaluated, or 0 if no results were collected.
    """
    logger.info("Validation ...")
    cfg = self.cfg
    self.reader.eval()
    data_iterator = self.get_data_iterator(cfg.dev_files, cfg.train.dev_batch_size, False, shuffle=False)
    log_result_step = cfg.train.log_batch_step
    all_results = []
    eval_top_docs = cfg.eval_top_docs
    for i, samples_batch in enumerate(data_iterator.iterate_ds_data()):
        input = create_reader_input(
            self.tensorizer.get_pad_id(),
            samples_batch,
            cfg.passages_per_question_predict,
            cfg.encoder.sequence_length,
            cfg.max_n_answers,
            is_train=False,
            shuffle=False,
        )
        input = ReaderBatch(**move_to_device(input._asdict(), cfg.device))
        attn_mask = self.tensorizer.get_attn_mask(input.input_ids)
        # Inference only: no gradient tracking needed.
        with torch.no_grad():
            start_logits, end_logits, relevance_logits = self.reader(input.input_ids, attn_mask)
        batch_predictions = self._get_best_prediction(
            start_logits,
            end_logits,
            relevance_logits,
            samples_batch,
            passage_thresholds=eval_top_docs,
        )
        all_results.extend(batch_predictions)
        if (i + 1) % log_result_step == 0:
            # Log the 1-based step so the printed value matches the
            # modulo condition above (previously logged `i`, which was
            # always one less than the configured interval multiple).
            logger.info("Eval step: %d ", i + 1)

    # Group per-question EM hits by top-docs threshold n.
    ems = defaultdict(list)
    for q_predictions in all_results:
        gold_answers = q_predictions.gold_answers
        span_predictions = q_predictions.predictions  # {top docs threshold -> SpanPrediction()}
        for (n, span_prediction) in span_predictions.items():
            # A question is a hit if the predicted span exactly matches any gold answer.
            em_hit = max(exact_match_score(span_prediction.prediction_text, ga) for ga in gold_answers)
            ems[n].append(em_hit)

    em = 0
    for n in sorted(ems.keys()):
        em = np.mean(ems[n])
        # Lazy %-style logger args: the string is only formatted if the
        # record actually passes the logger's level/filters.
        logger.info("n=%d\tEM %.2f", n, em * 100)

    if cfg.prediction_results_file:
        self._save_predictions(cfg.prediction_results_file, all_results)

    # EM at the largest threshold — last iteration of the sorted loop above.
    return em