def predict()

in scripts/train_qa.py


def predict(args, model, eval_dataloader, logger, fixed_thresh=None):
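    """Evaluate the QA model: rerank candidate chains, extract answer spans,
    optionally predict supporting-fact sentences, and sweep the rank/span score
    combination weight to report the best metrics."""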
    model.eval()
    id2result = collections.defaultdict(list)
    id2answer = collections.defaultdict(list)
    id2gold = {}
    id2goldsp = {}
    for batch in tqdm(eval_dataloader):
        batch_to_feed = move_to_cuda(batch["net_inputs"])
        batch_qids = batch["qids"]
        batch_labels = batch["net_inputs"]["label"].view(-1).tolist()
        with torch.no_grad():
            outputs = model(batch_to_feed)
            scores = outputs["rank_score"]
            scores = scores.view(-1).tolist()
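            # supporting-fact head: mask sentence slots whose offset is 0 (padding),
            # then turn the per-sentence logits into probabilities with a sigmoid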
            if args.sp_pred:
                sp_scores = outputs["sp_score"]
                sp_scores = sp_scores.float().masked_fill(batch_to_feed["sent_offsets"].eq(0), float("-inf")).type_as(sp_scores)
                batch_sp_scores = sp_scores.sigmoid()

            # ans_type_predicted = torch.argmax(outputs["ans_type_logits"], dim=1).view(-1).tolist()
            outs = [outputs["start_logits"], outputs["end_logits"]]
        for qid, label, score in zip(batch_qids, batch_labels, scores):
            id2result[qid].append((label, score))

        # answer prediction
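        # pairwise span scores: span_scores[b, i, j] = start_logits[i] + end_logits[j];
        # the mask keeps only spans with 0 <= j - i <= max_ans_len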
        span_scores = outs[0][:, :, None] + outs[1][:, None]
        max_seq_len = span_scores.size(1)
        span_mask = np.tril(np.triu(np.ones((max_seq_len, max_seq_len)), 0), args.max_ans_len)
        span_mask = span_scores.data.new(max_seq_len, max_seq_len).copy_(torch.from_numpy(span_mask))
        span_scores_masked = span_scores.float().masked_fill((1 - span_mask[None].expand_as(span_scores)).bool(), -1e10).type_as(span_scores)
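        # pick the start whose best span score is highest, then gather the best end for that start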
        start_position = span_scores_masked.max(dim=2)[0].max(dim=1)[1]
        end_position = span_scores_masked.max(dim=2)[1].gather(
            1, start_position.unsqueeze(1)).squeeze(1)
        answer_scores = span_scores_masked.max(dim=2)[0].max(dim=1)[0].tolist()
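        # subtract the paragraph offset so start/end index into the context wordpieces
        # rather than the full question+context input sequence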
        para_offset = batch['para_offsets']
        start_position_ = list(
            np.array(start_position.tolist()) - np.array(para_offset))
        end_position_ = list(
            np.array(end_position.tolist()) - np.array(para_offset)) 

        for idx, qid in enumerate(batch_qids):
            id2gold[qid] = batch["gold_answer"][idx]
            id2goldsp[qid] = batch["sp_gold"][idx]

            rank_score = scores[idx]
            start = start_position_[idx]
            end = end_position_[idx]
            span_score = answer_scores[idx]
            
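            # map the predicted wordpiece span back to the original document tokens and
            # strip the "##" continuation markers to recover readable text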
            tok_to_orig_index = batch['tok_to_orig_index'][idx]
            doc_tokens = batch['doc_tokens'][idx]
            wp_tokens = batch['wp_tokens'][idx]
            orig_doc_start = tok_to_orig_index[start]
            orig_doc_end = tok_to_orig_index[end]
            orig_tokens = doc_tokens[orig_doc_start:(orig_doc_end + 1)]
            tok_tokens = wp_tokens[start:end+1]
            tok_text = " ".join(tok_tokens)
            tok_text = tok_text.replace(" ##", "")
            tok_text = tok_text.replace("##", "")
            tok_text = tok_text.strip()
            tok_text = " ".join(tok_text.split())
            orig_text = " ".join(orig_tokens)
            pred_str = get_final_text(tok_text, orig_text, do_lower_case=True, verbose_logging=False)

            # get the sp sentences
            pred_sp = []
            if args.sp_pred:
                sp_score = batch_sp_scores[idx].tolist()
                passages = batch["passages"][idx]
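                # the chain holds two passages; sent_offset maps each passage's local
                # sentence index into the concatenated sp_score vector, and sentences
                # with probability >= 0.5 are kept as supporting facts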
                for passage, sent_offset in zip(passages, [0, len(passages[0]["sents"])]):
                    for sent_idx, _ in enumerate(passage["sents"]):
                        try:
                            if sp_score[sent_idx + sent_offset] >= 0.5:
                                pred_sp.append([passage["title"], sent_idx])
                        except IndexError:
                            # sentence falls beyond the truncated sequence
                            continue
            id2answer[qid].append({
                "pred_str": pred_str.strip(),
                "rank_score": rank_score,
                "span_score": span_score,
                "pred_sp": pred_sp
            })
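    # chain reranking accuracy: the top-scored candidate chain should be a positive one (label == 1)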
    acc = []
    for qid, res in id2result.items():
        res.sort(key=lambda x: x[1], reverse=True)
        acc.append(res[0][0] == 1)
    logger.info(f"evaluated {len(id2result)} questions...")
    logger.info(f'chain ranking em: {np.mean(acc)}')

    best_em, best_f1, best_joint_em, best_joint_f1, best_sp_em, best_sp_f1 = 0, 0, 0, 0, 0, 0
    best_res = None
    if fixed_thresh:
        lambdas = [fixed_thresh]
    else:
        # selecting the threshold on the dev data
        lambdas = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]

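    # rerank answers by lambda * rank_score + (1 - lambda) * span_score and keep the
    # combination weight that maximizes (joint) F1 on this set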
    for lambda_ in lambdas:
        ems, f1s, sp_ems, sp_f1s, joint_ems, joint_f1s = [], [], [], [], [], []
        results = collections.defaultdict(dict)
        for qid, res in id2result.items():
            ans_res = id2answer[qid]
            ans_res.sort(key=lambda x: lambda_ * x["rank_score"] + (1 - lambda_) * x["span_score"], reverse=True)
            top_pred = ans_res[0]["pred_str"]
            top_pred_sp = ans_res[0]["pred_sp"]

            results["answer"][qid] = top_pred
            results["sp"][qid] = top_pred_sp

            ems.append(exact_match_score(top_pred, id2gold[qid][0]))
            f1, prec, recall = f1_score(top_pred, id2gold[qid][0])
            f1s.append(f1)

            if args.sp_pred:
                metrics = {'sp_em': 0, 'sp_f1': 0, 'sp_prec': 0, 'sp_recall': 0}
                update_sp(metrics, top_pred_sp, id2goldsp[qid])
                sp_ems.append(metrics['sp_em'])
                sp_f1s.append(metrics['sp_f1'])
                # joint metrics
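                # joint precision/recall are the products of the answer and SP precision/recall (HotpotQA joint metric)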
                joint_prec = prec * metrics["sp_prec"]
                joint_recall = recall * metrics["sp_recall"]
                if joint_prec + joint_recall > 0:
                    joint_f1 = 2 * joint_prec * joint_recall / (joint_prec + joint_recall)
                else:
                    joint_f1 = 0.
                joint_em = ems[-1] * sp_ems[-1]
                joint_ems.append(joint_em)
                joint_f1s.append(joint_f1)

        if args.sp_pred:
            if best_joint_f1 < np.mean(joint_f1s):
                best_joint_f1 = np.mean(joint_f1s)
                best_joint_em = np.mean(joint_ems)
                best_sp_f1 = np.mean(sp_f1s)
                best_sp_em = np.mean(sp_ems)
                best_f1 = np.mean(f1s)
                best_em = np.mean(ems)
                best_res = results
        else:
            if best_f1 < np.mean(f1s):
                best_f1 = np.mean(f1s)
                best_em = np.mean(ems)
                best_res = results

        logger.info(f".......Using combination factor {lambda_}......")
        logger.info(f'answer em: {np.mean(ems)}, count: {len(ems)}')
        logger.info(f'answer f1: {np.mean(f1s)}, count: {len(f1s)}')
        if args.sp_pred:
            logger.info(f'sp em: {np.mean(sp_ems)}, count: {len(sp_ems)}')
            logger.info(f'sp f1: {np.mean(sp_f1s)}, count: {len(sp_f1s)}')
            logger.info(f'joint em: {np.mean(joint_ems)}, count: {len(joint_ems)}')
            logger.info(f'joint f1: {np.mean(joint_f1s)}, count: {len(joint_f1s)}')
    logger.info(f"Best joint F1 from combination {best_f1}")
    if args.save_prediction != "":
        with open(args.save_prediction, "w") as f:
            json.dump(best_res, f)

    model.train()
    return {"em": best_em, "f1": best_f1, "joint_em": best_joint_em, "joint_f1": best_joint_f1, "sp_em": best_sp_em, "sp_f1": best_sp_f1}