def _eval()

in mdr/qa/qa_trainer.py [0:0]

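Assumed module-level imports for this excerpt (not shown here): collections, numpy as np, torch, submitit, tqdm, plus the MDR helpers move_to_cuda, get_final_text, exact_match_score, f1_score, and update_sp, which the method body references.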

    def _eval(self) -> dict:
        print("Start evaluation of the model", flush=True)
        job_env = submitit.JobEnvironment()
        args = self._train_cfg
        eval_dataloader = self._test_loader
        model = self._state.model
        model.eval()
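        # per-question accumulators: chain-ranking (label, score) pairs, answer candidates, gold answers, gold supporting facts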
        id2result = collections.defaultdict(list)
        id2answer = collections.defaultdict(list)
        id2gold = {}
        id2goldsp = {}
        for batch in tqdm(eval_dataloader):
            batch_to_feed = move_to_cuda(batch["net_inputs"])
            batch_qids = batch["qids"]
            batch_labels = batch["net_inputs"]["label"].view(-1).tolist()
            with torch.no_grad():
                outputs = model(batch_to_feed)
                scores = outputs["rank_score"]
                scores = scores.view(-1).tolist()
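                # mask padded sentence slots (sent_offsets == 0) before sigmoid so they can never pass the 0.5 threshold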
                sp_scores = outputs["sp_score"]
                sp_scores = sp_scores.float().masked_fill(batch_to_feed["sent_offsets"].eq(0), float("-inf")).type_as(sp_scores)
                batch_sp_scores = sp_scores.sigmoid()
                # ans_type_predicted = torch.argmax(outputs["ans_type_logits"], dim=1).view(-1).tolist()
                outs = [outputs["start_logits"], outputs["end_logits"]]
            for qid, label, score in zip(batch_qids, batch_labels, scores):
                id2result[qid].append((label, score))

            # answer prediction
            span_scores = outs[0][:, :, None] + outs[1][:, None]
            max_seq_len = span_scores.size(1)
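            # band mask: keep only spans with start <= end and end - start <= max_ans_len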
            span_mask = np.tril(np.triu(np.ones((max_seq_len, max_seq_len)), 0), args.max_ans_len)
            span_mask = span_scores.data.new(max_seq_len, max_seq_len).copy_(torch.from_numpy(span_mask))
            span_scores_masked = span_scores.float().masked_fill((1 - span_mask[None].expand_as(span_scores)).bool(), -1e10).type_as(span_scores)
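            # pick the start with the highest achievable span score, then the argmax end for that start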
            start_position = span_scores_masked.max(dim=2)[0].max(dim=1)[1]
            end_position = span_scores_masked.max(dim=2)[1].gather(
                1, start_position.unsqueeze(1)).squeeze(1)
            answer_scores = span_scores_masked.max(dim=2)[0].max(dim=1)[0].tolist()
            para_offset = batch['para_offsets']
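            # shift predicted positions from the full input sequence back to context-relative token indices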
            start_position_ = list(
                np.array(start_position.tolist()) - np.array(para_offset))
            end_position_ = list(
                np.array(end_position.tolist()) - np.array(para_offset))

            for idx, qid in enumerate(batch_qids):
                id2gold[qid] = batch["gold_answer"][idx]
                id2goldsp[qid] = batch["sp_gold"][idx]
                rank_score = scores[idx]
                sp_score = batch_sp_scores[idx].tolist()
                start = start_position_[idx]
                end = end_position_[idx]
                span_score = answer_scores[idx]

                tok_to_orig_index = batch['tok_to_orig_index'][idx]
                doc_tokens = batch['doc_tokens'][idx]
                wp_tokens = batch['wp_tokens'][idx]
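                # map the wordpiece span back to the original whitespace tokens and detokenize the "##" pieces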
                orig_doc_start = tok_to_orig_index[start]
                orig_doc_end = tok_to_orig_index[end]
                orig_tokens = doc_tokens[orig_doc_start:(orig_doc_end + 1)]
                tok_tokens = wp_tokens[start:end+1]
                tok_text = " ".join(tok_tokens)
                tok_text = tok_text.replace(" ##", "")
                tok_text = tok_text.replace("##", "")
                tok_text = tok_text.strip()
                tok_text = " ".join(tok_text.split())
                orig_text = " ".join(orig_tokens)
                pred_str = get_final_text(tok_text, orig_text, do_lower_case=True, verbose_logging=False)

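                # supporting-fact prediction: sentences with sigmoid score > 0.5; sent_offset shifts indices for the second passage in the chain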
                pred_sp = []
                passages = batch["passages"][idx]
                for passage, sent_offset in zip(passages, [0, len(passages[0]["sents"])]):
                    for sent_idx, _ in enumerate(passage["sents"]):
                        try:
                            if sp_score[sent_idx + sent_offset] > 0.5:
                                pred_sp.append([passage["title"], sent_idx])
                        except IndexError:
                            continue
                id2answer[qid].append((pred_str.strip(), rank_score, span_score, pred_sp))

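        # chain-ranking EM: does the top-scoring chain for each question carry label 1?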
        acc = []
        for qid, res in id2result.items():
            res.sort(key=lambda x: x[1], reverse=True)
            acc.append(res[0][0] == 1)
        print(f"evaluated {len(id2result)} questions...", flush=True)
        print(f'chain ranking em: {np.mean(acc)}', flush=True)

        best_em, best_f1, best_joint_em, best_joint_f1 = 0, 0, 0, 0
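        # sweep the interpolation weight between chain rank score and answer span score; keep the setting with the best joint F1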
        lambdas = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
        for lambda_ in lambdas:
            ems, f1s = [], []
            sp_ems, sp_f1s = [], []
            joint_ems, joint_f1s = [], []
            for qid, res in id2result.items():
                ans_res = id2answer[qid]
                ans_res.sort(key=lambda x: lambda_ * x[1] + (1 - lambda_) * x[2], reverse=True)
                top_pred = ans_res[0][0]
                ems.append(exact_match_score(top_pred, id2gold[qid][0]))
                f1, prec, recall = f1_score(top_pred, id2gold[qid][0])
                f1s.append(f1)

                top_pred_sp = ans_res[0][3]
                metrics = {'sp_em': 0, 'sp_f1': 0, 'sp_prec': 0, 'sp_recall': 0}
                update_sp(metrics, top_pred_sp, id2goldsp[qid])
                sp_ems.append(metrics['sp_em'])
                sp_f1s.append(metrics['sp_f1'])

                # joint metrics
                joint_prec = prec * metrics["sp_prec"]
                joint_recall = recall * metrics["sp_recall"]
                if joint_prec + joint_recall > 0:
                    joint_f1 = 2 * joint_prec * joint_recall / (joint_prec + joint_recall)
                else:
                    joint_f1 = 0
                joint_em = ems[-1] * sp_ems[-1]
                joint_ems.append(joint_em)
                joint_f1s.append(joint_f1)

            if best_joint_f1 < np.mean(joint_f1s):
                best_joint_f1 = np.mean(joint_f1s)
                best_joint_em = np.mean(joint_ems)
                best_f1 = np.mean(f1s)
                best_em = np.mean(ems)

            print(f".......Using combination factor {lambda_}......", flush=True)
            print(f'answer em: {np.mean(ems)}, count: {len(ems)}', flush=True)
            print(f'answer f1: {np.mean(f1s)}, count: {len(f1s)}', flush=True)
            print(f'sp em: {np.mean(sp_ems)}, count: {len(sp_ems)}', flush=True)
            print(f'sp f1: {np.mean(sp_f1s)}, count: {len(sp_f1s)}', flush=True)
            print(f'joint em: {np.mean(joint_ems)}, count: {len(joint_ems)}', flush=True)
            print(f'joint f1: {np.mean(joint_f1s)}, count: {len(joint_f1s)}', flush=True)
        print(f"Best joint EM/F1 from combination {best_em}/{best_f1}", flush=True)

        model.train()
        return {"em": best_em, "f1": best_f1, "joint_em": best_joint_em, "joint_f1": best_joint_f1}
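
For reference, below is a minimal, self-contained sketch of the span-decoding step used above (band-mask the start/end logit sums, then take the start/end argmax). The function name decode_spans, the toy inputs, and the max_ans_len default are illustrative assumptions and not part of the trainer.

    import numpy as np
    import torch

    def decode_spans(start_logits, end_logits, max_ans_len=30):
        """Toy sketch of the span decoding above; not part of qa_trainer."""
        # [batch, seq, seq] score for every (start, end) pair
        span_scores = start_logits[:, :, None] + end_logits[:, None]
        seq_len = span_scores.size(1)
        # keep spans with start <= end and end - start <= max_ans_len
        band = np.tril(np.triu(np.ones((seq_len, seq_len)), 0), max_ans_len)
        band = torch.from_numpy(band).to(span_scores)
        masked = span_scores.float().masked_fill(
            (1 - band[None].expand_as(span_scores)).bool(), -1e10)
        # best start overall, then the argmax end for that start
        start = masked.max(dim=2)[0].max(dim=1)[1]
        end = masked.max(dim=2)[1].gather(1, start.unsqueeze(1)).squeeze(1)
        score = masked.max(dim=2)[0].max(dim=1)[0]
        return start, end, score

    # toy usage: batch of 2 sequences of length 8
    s, e, sc = decode_spans(torch.randn(2, 8), torch.randn(2, 8), max_ans_len=3)
    print(s.tolist(), e.tolist(), sc.tolist())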