in mdr/qa/qa_trainer.py [0:0]
def _eval(self) -> dict:
    """Evaluate the QA model: rank retrieved chains, extract answer spans and
    supporting sentences, and report EM/F1 over a sweep of score-combination weights."""
    print("Start evaluation of the model", flush=True)
    job_env = submitit.JobEnvironment()
    args = self._train_cfg
    eval_dataloader = self._test_loader
    model = self._state.model
    model.eval()

    id2result = collections.defaultdict(list)   # qid -> [(label, rank_score), ...] per retrieved chain
    id2answer = collections.defaultdict(list)   # qid -> [(answer_text, rank_score, span_score, pred_sp), ...]
    id2gold = {}     # qid -> gold answer(s)
    id2goldsp = {}   # qid -> gold supporting facts
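    # Loop over the eval set: score each retrieved passage chain and decode answer spans / supporting sentences.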
    for batch in tqdm(eval_dataloader):
        batch_to_feed = move_to_cuda(batch["net_inputs"])
        batch_qids = batch["qids"]
        batch_labels = batch["net_inputs"]["label"].view(-1).tolist()
        with torch.no_grad():
            outputs = model(batch_to_feed)
            scores = outputs["rank_score"]
            scores = scores.view(-1).tolist()
            sp_scores = outputs["sp_score"]
            sp_scores = sp_scores.float().masked_fill(batch_to_feed["sent_offsets"].eq(0), float("-inf")).type_as(sp_scores)
            batch_sp_scores = sp_scores.sigmoid()
            # ans_type_predicted = torch.argmax(outputs["ans_type_logits"], dim=1).view(-1).tolist()
            outs = [outputs["start_logits"], outputs["end_logits"]]
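        # Record (label, rank_score) for every chain of this question; used for chain-ranking accuracy below.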
        for qid, label, score in zip(batch_qids, batch_labels, scores):
            id2result[qid].append((label, score))

        # answer prediction
        span_scores = outs[0][:, :, None] + outs[1][:, None]
        max_seq_len = span_scores.size(1)
        # valid spans only: end >= start and end - start <= args.max_ans_len
        span_mask = np.tril(np.triu(np.ones((max_seq_len, max_seq_len)), 0), args.max_ans_len)
        span_mask = span_scores.data.new(max_seq_len, max_seq_len).copy_(torch.from_numpy(span_mask))
        span_scores_masked = span_scores.float().masked_fill((1 - span_mask[None].expand_as(span_scores)).bool(), -1e10).type_as(span_scores)
        start_position = span_scores_masked.max(dim=2)[0].max(dim=1)[1]
        end_position = span_scores_masked.max(dim=2)[1].gather(
            1, start_position.unsqueeze(1)).squeeze(1)
        answer_scores = span_scores_masked.max(dim=2)[0].max(dim=1)[0].tolist()
        # shift positions from the full input sequence to the context (passage) tokens
        para_offset = batch['para_offsets']
        start_position_ = list(
            np.array(start_position.tolist()) - np.array(para_offset))
        end_position_ = list(
            np.array(end_position.tolist()) - np.array(para_offset))
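        # Decode per-example predictions: map the best wordpiece span back to the original document text.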
        for idx, qid in enumerate(batch_qids):
            id2gold[qid] = batch["gold_answer"][idx]
            id2goldsp[qid] = batch["sp_gold"][idx]
            rank_score = scores[idx]
            sp_score = batch_sp_scores[idx].tolist()
            start = start_position_[idx]
            end = end_position_[idx]
            span_score = answer_scores[idx]

            tok_to_orig_index = batch['tok_to_orig_index'][idx]
            doc_tokens = batch['doc_tokens'][idx]
            wp_tokens = batch['wp_tokens'][idx]
            orig_doc_start = tok_to_orig_index[start]
            orig_doc_end = tok_to_orig_index[end]
            orig_tokens = doc_tokens[orig_doc_start:(orig_doc_end + 1)]
            tok_tokens = wp_tokens[start:end + 1]
            # strip WordPiece "##" markers and normalize whitespace
            tok_text = " ".join(tok_tokens)
            tok_text = tok_text.replace(" ##", "")
            tok_text = tok_text.replace("##", "")
            tok_text = tok_text.strip()
            tok_text = " ".join(tok_text.split())
            orig_text = " ".join(orig_tokens)
            pred_str = get_final_text(tok_text, orig_text, do_lower_case=True, verbose_logging=False)
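            # Supporting-fact prediction: a sentence is kept when its sigmoid score exceeds 0.5.
            # The sentence offsets below assume a two-passage chain: 0 for the first passage,
            # len(passages[0]["sents"]) for the second.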
            pred_sp = []
            passages = batch["passages"][idx]
            for passage, sent_offset in zip(passages, [0, len(passages[0]["sents"])]):
                for sent_idx, _ in enumerate(passage["sents"]):
                    try:
                        if sp_score[sent_idx + sent_offset] > 0.5:
                            pred_sp.append([passage["title"], sent_idx])
                    except IndexError:
                        # more sentences in the passage than sp_score entries (truncated input)
                        continue
            id2answer[qid].append((pred_str.strip(), rank_score, span_score, pred_sp))
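    # Chain-ranking accuracy: a question is correct if its highest-scoring chain is a positive (label == 1) chain.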
    acc = []
    for qid, res in id2result.items():
        res.sort(key=lambda x: x[1], reverse=True)
        acc.append(res[0][0] == 1)
    print(f"evaluated {len(id2result)} questions...", flush=True)
    print(f'chain ranking em: {np.mean(acc)}', flush=True)
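    # Sweep the weight that combines chain rank score and answer span score
    # (lambda_ * rank_score + (1 - lambda_) * span_score) and keep the best setting by joint F1.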
    best_em, best_f1, best_joint_em, best_joint_f1 = 0, 0, 0, 0
    lambdas = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
    for lambda_ in lambdas:
        ems, f1s = [], []
        sp_ems, sp_f1s = [], []
        joint_ems, joint_f1s = [], []
        for qid, res in id2result.items():
            ans_res = id2answer[qid]
            ans_res.sort(key=lambda x: lambda_ * x[1] + (1 - lambda_) * x[2], reverse=True)
            top_pred = ans_res[0][0]
            ems.append(exact_match_score(top_pred, id2gold[qid][0]))
            f1, prec, recall = f1_score(top_pred, id2gold[qid][0])
            f1s.append(f1)
            top_pred_sp = ans_res[0][3]
            metrics = {'sp_em': 0, 'sp_f1': 0, 'sp_prec': 0, 'sp_recall': 0}
            update_sp(metrics, top_pred_sp, id2goldsp[qid])
            sp_ems.append(metrics['sp_em'])
            sp_f1s.append(metrics['sp_f1'])
            # joint metrics
            joint_prec = prec * metrics["sp_prec"]
            joint_recall = recall * metrics["sp_recall"]
            if joint_prec + joint_recall > 0:
                joint_f1 = 2 * joint_prec * joint_recall / (joint_prec + joint_recall)
            else:
                joint_f1 = 0
            joint_em = ems[-1] * sp_ems[-1]
            joint_ems.append(joint_em)
            joint_f1s.append(joint_f1)
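        # Track the best combination weight, selected by mean joint F1.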
        if best_joint_f1 < np.mean(joint_f1s):
            best_joint_f1 = np.mean(joint_f1s)
            best_joint_em = np.mean(joint_ems)
            best_f1 = np.mean(f1s)
            best_em = np.mean(ems)
print(f".......Using combination factor {lambda_}......", flush=True)
print(f'answer em: {np.mean(ems)}, count: {len(ems)}', flush=True)
print(f'answer f1: {np.mean(f1s)}, count: {len(f1s)}', flush=True)
print(f'sp em: {np.mean(sp_ems)}, count: {len(sp_ems)}', flush=True)
print(f'sp f1: {np.mean(sp_f1s)}, count: {len(sp_f1s)}', flush=True)
print(f'joint em: {np.mean(joint_ems)}, count: {len(joint_ems)}', flush=True)
print(f'joint f1: {np.mean(joint_f1s)}, count: {len(joint_f1s)}', flush=True)
print(f"Best joint EM/F1 from combination {best_em}/{best_f1}", flush=True)
    model.train()
    return {"em": best_em, "f1": best_f1, "joint_em": best_joint_em, "joint_f1": best_joint_f1}