def evaluate_disambiguation()

in model/utils/disambiguator_evaluation.py


import json

import numpy as np


def evaluate_disambiguation(gt_labels, model_results, record_instance_results=None):
    """Evaluates disambiguation using ground truth labels and model predictions.

    Args:
        gt_labels: Ground truth disambiguation labels.
        model_results: Model-generated disambiguation predictions.
        record_instance_results: Optional path to save instance-level results.

    Returns:
        Tuple of (mean accuracy, standard error of the mean).
    """
    # Index ground truth dialogues by dialogue id for fast lookup.
    gt_label_pool = {ii["dialogue_idx"]: ii for ii in gt_labels["dialogue_data"]}

    predictions = []
    for model_datum in model_results:
        dialog_id = model_datum["dialog_id"]
        for round_datum in model_datum["predictions"]:
            round_id = round_datum["turn_id"]
            predicted_label = round_datum["disambiguation_label"]
            gt_datum = gt_label_pool[dialog_id]["dialogue"][round_id]

            assert "disambiguation_label" in gt_datum, "Turn not to be evaluated!"
            gt_label = gt_datum["disambiguation_label"]
            predictions.append(gt_label == predicted_label)

            # Add the result to datum and save it back.
            if record_instance_results:
                round_datum["disambiguation_accuracy"] = gt_label == predicted_label

    print(f"# Instances evaluated: {len(predictions)}")
    # Record and save per-instance results.
    if record_instance_results:
        print(f"Saving per-instance results: {record_instance_results}")
        with open(record_instance_results, "w") as file_id:
            json.dump(model_results, file_id)
    # Return mean accuracy and the standard error of the mean.
    return np.mean(predictions), np.std(predictions) / np.sqrt(len(predictions))
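

A minimal usage sketch, assuming the module above is importable as model.utils.disambiguator_evaluation; the tiny in-memory dialogue/prediction payloads and the output file name below are illustrative assumptions, not part of the dataset:

from model.utils.disambiguator_evaluation import evaluate_disambiguation

# Illustrative inputs only: one dialogue with two labeled turns.
gt_labels = {
    "dialogue_data": [
        {
            "dialogue_idx": 0,
            "dialogue": [
                {"disambiguation_label": 1},  # turn 0
                {"disambiguation_label": 0},  # turn 1
            ],
        }
    ]
}
model_results = [
    {
        "dialog_id": 0,
        "predictions": [
            {"turn_id": 0, "disambiguation_label": 1},
            {"turn_id": 1, "disambiguation_label": 1},
        ],
    }
]

accuracy, std_err = evaluate_disambiguation(
    gt_labels, model_results, record_instance_results="instance_results.json"
)
print(f"Accuracy: {accuracy:.3f} +/- {std_err:.3f}")

With these inputs, one of the two predictions matches the ground truth, so the call reports an accuracy of 0.5 and writes the per-instance correctness flags back into instance_results.json.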