in code/run_squad.py
def evaluate(args, model, device, eval_dataset, eval_dataloader,
             eval_examples, eval_features, na_prob_thresh=1.0, pred_only=False):
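    # Run the model over the eval split, decode answer spans, and (unless
    # pred_only) score them with the SQuAD evaluation helpers. Returns
    # (metrics dict, predictions, n-best predictions).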
    all_results = []
    model.eval()
    for idx, (input_ids, input_mask, segment_ids, example_indices) in enumerate(eval_dataloader):
        if pred_only and idx % 10 == 0:
            logger.info("Running test: %d / %d" % (idx, len(eval_dataloader)))
        input_ids = input_ids.to(device)
        input_mask = input_mask.to(device)
        segment_ids = segment_ids.to(device)
        with torch.no_grad():
            batch_start_logits, batch_end_logits = model(input_ids, segment_ids, input_mask)
        for i, example_index in enumerate(example_indices):
            start_logits = batch_start_logits[i].detach().cpu().tolist()
            end_logits = batch_end_logits[i].detach().cpu().tolist()
            eval_feature = eval_features[example_index.item()]
            unique_id = int(eval_feature.unique_id)
            all_results.append(RawResult(unique_id=unique_id,
                                         start_logits=start_logits,
                                         end_logits=end_logits))
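
    # Turn the accumulated start/end logits into text predictions, n-best
    # lists, and (for SQuAD 2.0) per-question no-answer scores.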
    preds, nbest_preds, na_probs = \
        make_predictions(eval_examples, eval_features, all_results,
                         args.n_best_size, args.max_answer_length,
                         args.do_lower_case, args.verbose_logging,
                         args.version_2_with_negative)
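
    # Prediction-only mode: skip scoring; for SQuAD 2.0 data, blank out answers
    # whose no-answer score exceeds the fixed threshold.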
    if pred_only:
        if args.version_2_with_negative:
            for k in preds:
                if na_probs[k] > na_prob_thresh:
                    preds[k] = ''
        return {}, preds, nbest_preds
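
    # SQuAD 2.0 evaluation: apply the no-answer threshold, score the HasAns and
    # NoAns subsets separately, and search for the best exact/F1 thresholds.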
    if args.version_2_with_negative:
        qid_to_has_ans = make_qid_to_has_ans(eval_dataset)
        has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
        no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
        exact_raw, f1_raw = get_raw_scores(eval_dataset, preds)
        exact_thresh = apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans, na_prob_thresh)
        f1_thresh = apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans, na_prob_thresh)
        result = make_eval_dict(exact_thresh, f1_thresh)
        if has_ans_qids:
            has_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=has_ans_qids)
            merge_eval(result, has_ans_eval, 'HasAns')
        if no_ans_qids:
            no_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids)
            merge_eval(result, no_ans_eval, 'NoAns')
        find_all_best_thresh(result, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans)
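        # Blank out predictions whose no-answer score exceeds the tuned
        # best-F1 threshold before returning them.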
        for k in preds:
            if na_probs[k] > result['best_f1_thresh']:
                preds[k] = ''
    else:
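        # SQuAD 1.1-style data with no unanswerable questions: score the raw
        # predictions directly.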
        exact_raw, f1_raw = get_raw_scores(eval_dataset, preds)
        result = make_eval_dict(exact_raw, f1_raw)

    logger.info("***** Eval results *****")
    for key in sorted(result.keys()):
        logger.info(" %s = %s", key, str(result[key]))
    return result, preds, nbest_preds
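
For reference, the no-answer thresholding above relies on apply_no_ans_threshold from the SQuAD 2.0 evaluation helpers. Below is a minimal, self-contained sketch of the standard logic as it appears in the official evaluate-v2.0.py script; the helper this repo actually uses may differ in details, and the example data is made up for illustration.

def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh):
    # scores: per-question exact-match or F1 scores for the predicted answers.
    # na_probs: per-question no-answer scores; qid_to_has_ans: gold answerability.
    new_scores = {}
    for qid, s in scores.items():
        if na_probs[qid] > na_prob_thresh:
            # The model effectively predicts "no answer": give full credit only
            # if the question truly has no answer.
            new_scores[qid] = float(not qid_to_has_ans[qid])
        else:
            new_scores[qid] = s
    return new_scores

# Example: "q1" is answerable, "q2" is not; with threshold 0.5 the confident
# no-answer prediction for "q2" is counted as correct.
scores = {"q1": 1.0, "q2": 0.0}
na_probs = {"q1": 0.1, "q2": 0.9}
qid_to_has_ans = {"q1": True, "q2": False}
print(apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, 0.5))
# {'q1': 1.0, 'q2': 1.0}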