def eval_final()

in scripts/train_qa.py [0:0]


def eval_final(args, model, eval_dataloader, weight=0.8, gpu=True):
    """
    Run inference for the final submission and keep one prediction per question.

    Each batch item is one candidate (a question paired with a passage chain).
    For every candidate the best answer span is extracted from the start/end
    logits, and the candidates of a question are then compared with

        weight * rank_score + (1 - weight) * span_score

    The highest-scoring candidate provides the answer string, the supporting
    sentences and the chain titles reported for that question.

    Args:
        args: options object; reads ``sp_pred`` (also predict supporting
            sentences), ``max_ans_len`` (largest allowed ``end - start`` of a
            span, in tokens) and ``save_prediction`` (output path; nothing is
            written when it is empty or ``None``).
        model: callable taking ``batch["net_inputs"]`` and returning a mapping
            with ``"rank_score"``, ``"start_logits"``, ``"end_logits"`` and,
            when ``args.sp_pred`` is set, ``"sp_score"``.
        eval_dataloader: iterable of batches. Keys read here: ``"net_inputs"``,
            ``"qids"``, ``"para_offsets"``, ``"tok_to_orig_index"``,
            ``"doc_tokens"``, ``"wp_tokens"`` and ``"passages"``.
        weight: interpolation weight given to the rank score.
        gpu: when true, the network inputs go through ``move_to_cuda`` first.

    Returns:
        A ``defaultdict`` with the keys ``"answer"``, ``"sp"`` and ``"titles"``,
        each mapping a question id to the corresponding field of its best
        candidate. The mapping is empty if the dataloader yields nothing.
    """
    model.eval()
    id2answer = collections.defaultdict(list)
    for batch in tqdm(eval_dataloader):
        batch_to_feed = move_to_cuda(batch["net_inputs"]) if gpu else batch["net_inputs"]
        batch_qids = batch["qids"]
        with torch.no_grad():
            outputs = model(batch_to_feed)

            scores = outputs["rank_score"].view(-1).tolist()

            if args.sp_pred:
                sp_scores = outputs["sp_score"]
                # Positions whose sentence offset is 0 are pushed to -inf so
                # that their sigmoid is exactly 0 and never passes the cut-off.
                sp_scores = sp_scores.float().masked_fill(
                    batch_to_feed["sent_offsets"].eq(0), float("-inf")
                ).type_as(sp_scores)
                batch_sp_scores = sp_scores.sigmoid()

            start_logits = outputs["start_logits"]
            end_logits = outputs["end_logits"]

        # answer prediction
        # span_scores[b, i, j] = start_logits[b, i] + end_logits[b, j]
        span_scores = start_logits[:, :, None] + end_logits[:, None]
        max_seq_len = span_scores.size(1)
        # Keep only spans with 0 <= j - i <= args.max_ans_len.
        span_mask = np.tril(np.triu(np.ones((max_seq_len, max_seq_len)), 0), args.max_ans_len)
        span_mask = span_scores.data.new(max_seq_len, max_seq_len).copy_(torch.from_numpy(span_mask))
        span_scores_masked = span_scores.float().masked_fill(
            (1 - span_mask[None].expand_as(span_scores)).bool(), -1e10
        ).type_as(span_scores)

        # Reduce over end positions once, then over start positions.
        best_per_start, best_end_per_start = span_scores_masked.max(dim=2)
        best_span_scores, start_position = best_per_start.max(dim=1)
        end_position = best_end_per_start.gather(1, start_position.unsqueeze(1)).squeeze(1)
        answer_scores = best_span_scores.tolist()

        # Shift model positions by the per-example offset so they can index
        # wp_tokens / tok_to_orig_index below.
        para_offset = np.array(batch['para_offsets'])
        start_position_ = list(np.array(start_position.tolist()) - para_offset)
        end_position_ = list(np.array(end_position.tolist()) - para_offset)

        for idx, qid in enumerate(batch_qids):
            span_start = start_position_[idx]
            span_end = end_position_[idx]
            tok_to_orig_index = batch['tok_to_orig_index'][idx]
            doc_tokens = batch['doc_tokens'][idx]
            wp_tokens = batch['wp_tokens'][idx]

            orig_doc_start = tok_to_orig_index[span_start]
            orig_doc_end = tok_to_orig_index[span_end]
            orig_text = " ".join(doc_tokens[orig_doc_start:(orig_doc_end + 1)])

            # Undo word-piece splitting and normalise whitespace.
            tok_text = " ".join(wp_tokens[span_start:span_end + 1])
            tok_text = tok_text.replace(" ##", "").replace("##", "")
            tok_text = " ".join(tok_text.strip().split())

            pred_str = get_final_text(tok_text, orig_text, do_lower_case=True, verbose_logging=False)

            passages = batch["passages"][idx]
            chain_titles = [passage["title"] for passage in passages]

            # get the sp sentences
            pred_sp = []
            if args.sp_pred:
                sp_score = batch_sp_scores[idx].tolist()
                # Sentence scores of the first passage come first, followed by
                # those of the second passage.
                sent_starts = [0, len(passages[0]["sents"])]
                for passage, sent_offset in zip(passages, sent_starts):
                    for sent_idx, _ in enumerate(passage["sents"]):
                        score_pos = sent_idx + sent_offset
                        # Sentences beyond the scored range have no
                        # prediction; skip them instead of failing.
                        if score_pos < len(sp_score) and sp_score[score_pos] > 0.5:
                            pred_sp.append([passage["title"], sent_idx])

            id2answer[qid].append({
                "pred_str": pred_str.strip(),
                "rank_score": scores[idx],
                "span_score": answer_scores[idx],
                "pred_sp": pred_sp,
                "chain_titles": chain_titles
            })

    lambda_ = weight

    def combined_score(candidate):
        return lambda_ * candidate["rank_score"] + (1 - lambda_) * candidate["span_score"]

    results = collections.defaultdict(dict)
    for qid, candidates in id2answer.items():
        # max() returns the first best candidate, the same element a stable
        # descending sort would place at index 0.
        best = max(candidates, key=combined_score)
        results["answer"][qid] = best["pred_str"]
        results["sp"][qid] = best["pred_sp"]
        results["titles"][qid] = best["chain_titles"]

    if args.save_prediction:
        # Context manager guarantees the file is flushed and closed.
        with open(f"{args.save_prediction}", "w") as pred_file:
            json.dump(results, pred_file)

    return results