in model/utils/disambiguator_evaluation.py [0:0]
import json

import numpy as np


def evaluate_disambiguation(gt_labels, model_results, record_instance_results=None):
    """Evaluates disambiguation using golden labels and model predictions.

    Args:
        gt_labels: Ground truth labels with a "dialogue_data" list.
        model_results: Model-generated labels, one entry per dialog.
        record_instance_results: Optional path to save instance-level results.

    Returns:
        Tuple of (mean accuracy, standard error of the mean).
    """
    gt_label_pool = {ii["dialogue_idx"]: ii for ii in gt_labels["dialogue_data"]}
    predictions = []
    for model_datum in model_results:
        dialog_id = model_datum["dialog_id"]
        for round_datum in model_datum["predictions"]:
            round_id = round_datum["turn_id"]
            predicted_label = round_datum["disambiguation_label"]
            # Look up the matching ground-truth turn for this prediction.
            gt_datum = gt_label_pool[dialog_id]["dialogue"][round_id]
            assert "disambiguation_label" in gt_datum, "Turn not to be evaluated!"
            gt_label = gt_datum["disambiguation_label"]
            predictions.append(gt_label == predicted_label)
            # Add the result to the datum and save it back.
            if record_instance_results:
                round_datum["disambiguation_accuracy"] = gt_label == predicted_label
    print(f"# Instances evaluated: {len(predictions)}")
    # Record and save per-instance results.
    if record_instance_results:
        print(f"Saving per instance results: {record_instance_results}")
        with open(record_instance_results, "w") as file_id:
            json.dump(model_results, file_id)
    # Report mean accuracy and its standard error.
    return np.mean(predictions), np.std(predictions) / np.sqrt(len(predictions))
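

# Minimal usage sketch (not part of the original module): it assumes the ground
# truth and the model predictions are stored as JSON files whose structure
# matches the fields read above; the file names below are hypothetical.
if __name__ == "__main__":
    with open("gt_labels.json", "r") as file_id:
        gt_labels = json.load(file_id)
    with open("model_results.json", "r") as file_id:
        model_results = json.load(file_id)

    accuracy, accuracy_std_err = evaluate_disambiguation(
        gt_labels, model_results, record_instance_results=None
    )
    print(f"Disambiguation accuracy: {accuracy:.4f} +/- {accuracy_std_err:.4f}")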