# scripts/reader/train.py
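# NOTE (assumed context, not shown in this excerpt): in the full file,
# `utils` refers to DrQA's drqa.reader.utils module and `logger` is a
# module-level logging.Logger.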
def validate_official(args, data_loader, model, global_stats,
                      offsets, texts, answers):
    """Run one full official validation. Uses exact spans and the same
    exact match/F1 score computation as in the SQuAD script.

    Extra arguments:
        offsets: The character start/end indices for the tokens in each context.
        texts: Map of qid --> raw text of each example's context (matches offsets).
        answers: Map of qid --> list of accepted answers.
    """
    eval_time = utils.Timer()
    f1 = utils.AverageMeter()
    exact_match = utils.AverageMeter()

    # Run through examples
    examples = 0
    for ex in data_loader:
        ex_id, batch_size = ex[-1], ex[0].size(0)
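        # predict() returns the top-scoring start/end token indices (plus a
        # span score, discarded here) for each example; [i][0] below selects
        # each example's single best candidate span.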
        pred_s, pred_e, _ = model.predict(ex)

        for i in range(batch_size):
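            # Map the predicted token span back to a character span in the
            # raw context, then slice out the predicted answer text.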
            s_offset = offsets[ex_id[i]][pred_s[i][0]][0]
            e_offset = offsets[ex_id[i]][pred_e[i][0]][1]
            prediction = texts[ex_id[i]][s_offset:e_offset]

            # Compute metrics
            ground_truths = answers[ex_id[i]]
            exact_match.update(utils.metric_max_over_ground_truths(
                utils.exact_match_score, prediction, ground_truths))
            f1.update(utils.metric_max_over_ground_truths(
                utils.f1_score, prediction, ground_truths))

        examples += batch_size
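
    # exact_match.avg / f1.avg are per-example means, reported as percentages.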
    logger.info('dev valid official: Epoch = %d | EM = %.2f | ' %
                (global_stats['epoch'], exact_match.avg * 100) +
                'F1 = %.2f | examples = %d | valid time = %.2f (s)' %
                (f1.avg * 100, examples, eval_time.time()))

    return {'exact_match': exact_match.avg * 100, 'f1': f1.avg * 100}
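

# For reference, the scoring helpers invoked above follow the official SQuAD
# v1.1 evaluation script. This is a minimal sketch of what the utils module is
# assumed to provide, shown only to make the EM/F1 computation concrete; it is
# not necessarily the repo's exact implementation.
import re
import string
from collections import Counter


def normalize_answer(s):
    """Lowercase, strip punctuation and articles, and collapse whitespace."""
    s = s.lower()
    s = ''.join(ch for ch in s if ch not in set(string.punctuation))
    s = re.sub(r'\b(a|an|the)\b', ' ', s)
    return ' '.join(s.split())


def exact_match_score(prediction, ground_truth):
    """1.0 if the normalized strings match exactly, else 0.0."""
    return float(normalize_answer(prediction) == normalize_answer(ground_truth))


def f1_score(prediction, ground_truth):
    """Token-level F1 between the normalized prediction and ground truth."""
    pred_tokens = normalize_answer(prediction).split()
    gt_tokens = normalize_answer(ground_truth).split()
    common = Counter(pred_tokens) & Counter(gt_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gt_tokens)
    return 2 * precision * recall / (precision + recall)


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Score the prediction against every accepted answer; keep the best."""
    return max(metric_fn(prediction, gt) for gt in ground_truths)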