def find_best_thresh()

in utils_nlp/eval/question_answering.py [0:0]


def find_best_thresh(preds, scores, na_probs, qid_to_has_ans, unanswerable_exists=False):
    """
    Find the best threshold to determine a question is impossible to answer.

    A question is treated as unanswerable when its na_prob is strictly greater than the
    threshold. The candidate thresholds are the na_prob values themselves, visited in
    ascending order, and the one that maximizes the total score is kept.

    Args:
        preds (dict): Dictionary with qa_id as keys and predicted answers as values. An empty
            (falsy) prediction is treated as predicting "unanswerable".
        scores (dict): Dictionary with qa_id as keys and raw evaluation scores (exact_match or
            f1) as values.
        na_probs (dict): Dictionary with qa_id as keys and unanswerable probabilities as values.
        qid_to_has_ans (dict): Dictionary with qa_id as keys and boolean values indicating
            whether the question has a ground truth answer.
        unanswerable_exists (bool, optional): Whether there are unanswerable questions in the
            data. Defaults to False.

    Returns:
        tuple: ``(best_score, best_thresh)`` if ``unanswerable_exists`` is False, otherwise
            ``(best_score, best_thresh, has_ans_score)``, where:

            - ``best_score`` is the total score at ``best_thresh`` as a percentage of
              ``len(scores)``. It is 0.0 if ``scores`` is empty.
            - ``best_thresh`` is the selected na_prob threshold. It is 0.0 if no candidate
              beats treating every question as unanswerable.
            - ``has_ans_score`` is the mean raw score over the questions in ``na_probs`` that
              have a ground truth answer. Questions missing from ``scores`` count as 0. The
              value is on the same scale as ``scores`` (NOT multiplied by 100) and does NOT
              depend on ``best_thresh``. It is 0.0 when there are no such questions.

    Note:
        Questions that share the same na_prob are processed one at a time, matching the
        official SQuAD v2.0 evaluation script. With ties, ``best_score`` may therefore
        reflect a partially processed tie group that no single threshold realizes.
    """
    num_no_ans = sum(1 for has_ans in qid_to_has_ans.values() if not has_ans)

    # Start below every na_prob: every question is predicted unanswerable, so the score equals
    # the number of questions that truly have no answer. This state is reported as threshold 0.0.
    cur_score = num_no_ans
    best_score = cur_score
    best_thresh = 0.0

    # Raise the threshold past one question at a time (ascending na_prob). Each step flips
    # that question from "predicted unanswerable" to "predicted answerable" (its text is used).
    qid_list = sorted(na_probs, key=lambda k: na_probs[k])
    for qid in qid_list:
        if qid not in scores:
            continue
        if qid_to_has_ans[qid]:
            # Answerable question is now answered: gain its raw score.
            diff = scores[qid]
        elif preds[qid]:
            # Unanswerable question answered with non-empty text: lose the point it was
            # credited in num_no_ans.
            diff = -1
        else:
            # Empty prediction still matches "no answer": the credit is kept.
            diff = 0
        cur_score += diff
        if cur_score > best_score:
            best_score = cur_score
            best_thresh = na_probs[qid]

    # Guard against an empty evaluation set instead of raising ZeroDivisionError.
    best_pct = 100.0 * best_score / len(scores) if scores else 0.0
    if not unanswerable_exists:
        return best_pct, best_thresh

    has_ans_qids = [qid for qid in qid_list if qid_to_has_ans[qid]]
    # Questions without an entry in `scores` still count in the denominator (contribute 0).
    has_ans_total = sum(scores[qid] for qid in has_ans_qids if qid in scores)
    # An all-unanswerable dataset has no answerable questions; report 0.0 rather than crash.
    has_ans_avg = 1.0 * has_ans_total / len(has_ans_qids) if has_ans_qids else 0.0
    return best_pct, best_thresh, has_ans_avg