in evals/elsuite/skill_acquisition/eval.py [0:0]
def run(self, recorder: evals.record.Recorder) -> dict[str, Union[float, int]]:
    """Run each sample with and without retrieval and report aggregate metrics.

    The non-retrieval pass serves as the baseline; the retrieval pass is the
    condition of interest. The returned dict contains both accuracies, their
    bootstrap standard deviations, the accuracy delta, and failure-mode rates.
    """
    samples = self.get_samples()
    self.rng.shuffle(samples)
    # Optionally subsample for cheaper runs.
    samples = samples[: self.n_samples] if self.n_samples is not None else samples
    results = self.eval_all_samples(recorder, samples)

    # Each per-sample result pairs a baseline ("non_retrieval") outcome with a
    # retrieval outcome.
    non_retrieval_results = [result["non_retrieval"] for result in results]
    retrieval_results = [result["retrieval"] for result in results]

    baseline_accuracy = get_accuracy(non_retrieval_results)
    baseline_std = get_bootstrap_accuracy_std(non_retrieval_results)
    retrieval_accuracy = get_accuracy(retrieval_results)
    retrieval_std = get_bootstrap_accuracy_std(retrieval_results)
    delta_accuracy = retrieval_accuracy - baseline_accuracy
    # TODO: decide which metric to report – propagated standard deviation
    # from bootstrapping or standard error of the mean estimated from repeats
    # of the eval experiments.
    delta_std = get_std_of_difference(baseline_std, retrieval_std)

    # Failure-mode rates, measured on the retrieval runs only.
    ctx_len_exceeded_rate = sum(
        1 for result in retrieval_results if result["ctx_len_exceeded"]
    ) / len(retrieval_results)
    timeout_rate = sum(
        1 for result in retrieval_results if result["interaction_timed_out"]
    ) / len(retrieval_results)

    # Sample counts by question type.
    num_translation_samples = len(
        [result for result in retrieval_results if result["question_type"] == "translation"]
    )
    num_non_translation_samples = len(
        [result for result in retrieval_results if result["question_type"] == "non-translation"]
    )

    result = {
        "baseline_accuracy": baseline_accuracy,
        "baseline_std": baseline_std,
        "retrieval_accuracy": retrieval_accuracy,
        "retrieval_std": retrieval_std,
        "delta_accuracy": delta_accuracy,
        "delta_std": delta_std,
        "average_retrieval_precision": get_average_retrieval_precision(retrieval_results),
        "average_non_retrieval_bleu_score": get_average_bleu_score(non_retrieval_results),
        "average_retrieval_bleu_score": get_average_bleu_score(retrieval_results),
        "average_retrieval_calls": get_average_retrieval_calls(retrieval_results),
        "average_invalid_retrieval_calls": get_average_invalid_retrieval_calls(
            retrieval_results
        ),
        "ctx_len_exceeded_rate": ctx_len_exceeded_rate,
        "timeout_rate": timeout_rate,
        "num_samples": len(retrieval_results),
        "num_translation_samples": num_translation_samples,
        "num_non_translation_samples": num_non_translation_samples,
    }
    return result
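
# Minimal, hypothetical sketches of the two statistics helpers referenced above
# (`get_bootstrap_accuracy_std` and `get_std_of_difference`), which are defined
# elsewhere and not shown in this excerpt. The function names, the `correct`
# field, and the resample count below are illustrative assumptions, not the
# eval's actual API; treat these as sketches of the technique, not the
# implementation.

def sketch_bootstrap_accuracy_std(results: list[dict], num_resamples: int = 1000) -> float:
    """Sketch of a bootstrap accuracy std: resample per-sample correctness with
    replacement and take the spread of the resampled accuracies."""
    import random  # local import keeps the sketch self-contained

    correct = [1.0 if r["correct"] else 0.0 for r in results]  # assumed field
    rng = random.Random(0)
    accuracies = [
        sum(rng.choices(correct, k=len(correct))) / len(correct)
        for _ in range(num_resamples)
    ]
    mean = sum(accuracies) / num_resamples
    return (sum((a - mean) ** 2 for a in accuracies) / num_resamples) ** 0.5


def sketch_std_of_difference(baseline_std: float, retrieval_std: float) -> float:
    """Sketch of the "propagated standard deviation" option from the TODO,
    assuming the two estimates are treated as independent errors:
    Var(X - Y) = Var(X) + Var(Y), so the std of the delta is the
    root-sum-of-squares of the two stds. The alternative mentioned in the TODO
    is the standard error of the mean over repeated eval runs."""
    return (baseline_std**2 + retrieval_std**2) ** 0.5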