in scripts/train_qa.py [0:0]
def eval_final(args, model, eval_dataloader, weight=0.8, gpu=True):
    """
    Run inference over ``eval_dataloader`` and build the final-submission predictions.

    For every question, each candidate passage chain is scored; the best chain is
    chosen by a convex combination of the chain rank score and the answer span
    score, ``weight * rank_score + (1 - weight) * span_score``.

    Args:
        args: namespace with at least ``sp_pred`` (bool, predict supporting
            sentences), ``max_ans_len`` (int, max answer span length in
            wordpieces) and ``save_prediction`` (str, output path or "").
        model: QA model; its forward returns a dict with ``rank_score``,
            ``start_logits``, ``end_logits`` and (if ``args.sp_pred``)
            ``sp_score``.
        eval_dataloader: yields batches with ``net_inputs``, ``qids``,
            ``para_offsets``, ``tok_to_orig_index``, ``doc_tokens``,
            ``wp_tokens`` and ``passages``.
        weight: mixing coefficient between rank score and span score.
        gpu: move ``net_inputs`` to CUDA before the forward pass.

    Returns:
        dict with keys ``"answer"``, ``"sp"`` and ``"titles"``, each mapping
        qid -> prediction. Also written to ``args.save_prediction`` if set.
    """
    model.eval()
    id2answer = collections.defaultdict(list)
    encode_times = []  # per-batch forward-pass wall time, kept for profiling
    for batch in tqdm(eval_dataloader):
        batch_to_feed = move_to_cuda(batch["net_inputs"]) if gpu else batch["net_inputs"]
        batch_qids = batch["qids"]
        with torch.no_grad():
            start_t = time.time()
            outputs = model(batch_to_feed)
            encode_times.append(time.time() - start_t)

            scores = outputs["rank_score"].view(-1).tolist()

            if args.sp_pred:
                # Mask padded sentence slots (offset == 0) before the sigmoid so
                # they come out as probability 0.
                sp_scores = outputs["sp_score"]
                sp_scores = sp_scores.float().masked_fill(
                    batch_to_feed["sent_offsets"].eq(0), float("-inf")).type_as(sp_scores)
                batch_sp_scores = sp_scores.sigmoid()

            outs = [outputs["start_logits"], outputs["end_logits"]]

        # Answer span prediction: score every (start, end) pair, then keep only
        # spans with 0 <= end - start <= max_ans_len via a banded upper-triangular mask.
        span_scores = outs[0][:, :, None] + outs[1][:, None]
        max_seq_len = span_scores.size(1)
        span_mask = np.tril(np.triu(np.ones((max_seq_len, max_seq_len)), 0), args.max_ans_len)
        span_mask = span_scores.data.new(max_seq_len, max_seq_len).copy_(
            torch.from_numpy(span_mask))
        span_scores_masked = span_scores.float().masked_fill(
            (1 - span_mask[None].expand_as(span_scores)).bool(), -1e10).type_as(span_scores)

        # Best start = argmax over rows of the per-row best end score;
        # best end = the argmax end for that chosen start.
        start_position = span_scores_masked.max(dim=2)[0].max(dim=1)[1]
        end_position = span_scores_masked.max(dim=2)[1].gather(
            1, start_position.unsqueeze(1)).squeeze(1)
        answer_scores = span_scores_masked.max(dim=2)[0].max(dim=1)[0].tolist()

        # Shift positions from the full-sequence frame into the passage frame.
        para_offset = batch['para_offsets']
        start_position_ = list(np.array(start_position.tolist()) - np.array(para_offset))
        end_position_ = list(np.array(end_position.tolist()) - np.array(para_offset))

        for idx, qid in enumerate(batch_qids):
            rank_score = scores[idx]
            start = start_position_[idx]
            end = end_position_[idx]
            span_score = answer_scores[idx]

            tok_to_orig_index = batch['tok_to_orig_index'][idx]
            doc_tokens = batch['doc_tokens'][idx]
            wp_tokens = batch['wp_tokens'][idx]

            # Map wordpiece span back to original whitespace tokens, then
            # de-wordpiece the text for alignment in get_final_text.
            orig_doc_start = tok_to_orig_index[start]
            orig_doc_end = tok_to_orig_index[end]
            orig_tokens = doc_tokens[orig_doc_start:(orig_doc_end + 1)]
            tok_tokens = wp_tokens[start:end + 1]
            tok_text = " ".join(tok_tokens)
            tok_text = tok_text.replace(" ##", "")
            tok_text = tok_text.replace("##", "")
            tok_text = tok_text.strip()
            tok_text = " ".join(tok_text.split())
            orig_text = " ".join(orig_tokens)
            pred_str = get_final_text(tok_text, orig_text, do_lower_case=True, verbose_logging=False)

            chain_titles = [_["title"] for _ in batch["passages"][idx]]

            # Supporting-fact prediction: a sentence is supporting if its
            # sigmoid score exceeds 0.5. sp_score indexes sentences of all
            # passages in the chain concatenated, so each passage needs the
            # cumulative offset of the sentences before it.
            pred_sp = []
            if args.sp_pred:
                sp_score = batch_sp_scores[idx].tolist()
                passages = batch["passages"][idx]
                sent_offsets = [0]
                for passage in passages[:-1]:
                    sent_offsets.append(sent_offsets[-1] + len(passage["sents"]))
                for passage, sent_offset in zip(passages, sent_offsets):
                    for sent_idx, _ in enumerate(passage["sents"]):
                        try:
                            if sp_score[sent_idx + sent_offset] > 0.5:
                                pred_sp.append([passage["title"], sent_idx])
                        except IndexError:
                            # Sentence fell beyond the max sequence length and
                            # was never scored; skip it.
                            continue

            id2answer[qid].append({
                "pred_str": pred_str.strip(),
                "rank_score": rank_score,
                "span_score": span_score,
                "pred_sp": pred_sp,
                "chain_titles": chain_titles
            })

    # Rerank the candidate chains per question by the combined score and keep
    # the top one for each output field.
    lambda_ = weight
    results = collections.defaultdict(dict)
    for qid in id2answer.keys():
        ans_res = id2answer[qid]
        ans_res.sort(
            key=lambda x: lambda_ * x["rank_score"] + (1 - lambda_) * x["span_score"],
            reverse=True)
        results["answer"][qid] = ans_res[0]["pred_str"]
        results["sp"][qid] = ans_res[0]["pred_sp"]
        results["titles"][qid] = ans_res[0]["chain_titles"]

    if args.save_prediction != "":
        with open(f"{args.save_prediction}", "w") as f:
            json.dump(results, f)

    return results