in src/sal/utils/math.py [0:0]
def compute_pass_at_k(x, k):
    """
    Compute the pass@k score for one problem's set of sampled predictions.

    Each prediction is mapped through ``memoized_canonical_form`` and compared
    for equality with the canonical form of the reference answer. The number of
    matches, together with the total number of predictions, is handed to
    ``pass_at_k``.

    Args:
        x (dict): Must contain ``"preds"`` (list of predicted answers) and
            ``"answer"`` (the reference answer).
        k (int): The cutoff for pass@k.

    Returns:
        dict: A single-entry dict mapping ``"pass@{k}"`` to the value returned
            by ``pass_at_k``.

    Raises:
        ValueError: If ``x["preds"]`` is empty or ``x["answer"]`` is ``""``.
    """
    predictions = x["preds"]
    num_samples = len(predictions)
    if not num_samples:
        raise ValueError("No predictions found")

    reference = x["answer"]
    if reference == "":
        raise ValueError("Answer is empty")

    # Canonicalize once for the reference, then once per prediction; the
    # comparison results are summed, so each match contributes 1.
    target = memoized_canonical_form(reference)
    matches = [memoized_canonical_form(candidate) == target for candidate in predictions]
    num_correct = sum(matches)

    score = pass_at_k(num_samples, num_correct, k)
    return {f"pass@{k}": score}