evals/elsuite/modelgraded/classify.py (97 lines of code) (raw):

""" Generic eval that uses a prompt + classification. """ from collections import Counter from random import Random from typing import Any, Optional, Union import evals import evals.record from evals.elsuite.modelgraded.classify_utils import classify, sample_and_concat_n_completions from evals.elsuite.utils import PromptFn, scrub_formatting_from_prompt class ModelBasedClassify(evals.Eval): def __init__( self, modelgraded_spec: str, *args, modelgraded_spec_args: Optional[dict[str, dict[str, str]]] = None, sample_kwargs: Optional[dict[str, Any]] = None, eval_kwargs: Optional[dict[str, Any]] = None, multicomp_n: Union[int, str] = 1, eval_type: Optional[str] = None, match_fn: Optional[str] = None, metaeval: bool = False, **kwargs, ): super().__init__(*args, **kwargs) # treat last completion_fn as eval_completion_fn self.eval_completion_fn = self.completion_fns[-1] if len(self.completion_fns) > 1: self.completion_fns = self.completion_fns[:-1] n_models = len(self.completion_fns) self.sample_kwargs = {"max_tokens": 1024} self.sample_kwargs.update(sample_kwargs or {}) self.eval_kwargs = {"max_tokens": 1024} self.eval_kwargs.update(eval_kwargs or {}) self.metaeval = metaeval self.modelgraded_spec_args = modelgraded_spec_args or {} self.eval_type = eval_type self.match_fn = match_fn if multicomp_n == "from_models": assert n_models > 1 self.multicomp_n = n_models else: assert isinstance(multicomp_n, int) self.multicomp_n = multicomp_n if len(self.completion_fns) > 1: assert self.multicomp_n == n_models self.mg = self.registry.get_modelgraded_spec(modelgraded_spec) def eval_sample(self, test_sample: dict, rng: Random) -> None: """Evaluate a single sample. Recorded metrics are always: one of the self.choice_strings, or "__invalid__". """ # process test_sample for k in self.mg.input_outputs: test_sample[k] = scrub_formatting_from_prompt(test_sample[k]) # run policy completions completions = {} for k, v in self.mg.input_outputs.items(): if v in test_sample: # test_sample already has completion, skip. 
continue if self.multicomp_n > 1: completion = sample_and_concat_n_completions( self.completion_fns, prompt=test_sample[k], template_i=self.mg.output_template, sample_kwargs=self.sample_kwargs, n=self.multicomp_n, ) else: get_input_completion = PromptFn( test_sample[k], completion_fn=self.completion_fn, **self.sample_kwargs ) completion, _ = get_input_completion() completions[v] = completion # run modelgraded eval metrics = {} choice, info = classify( mg=self.mg, completion_fn=self.eval_completion_fn, completion_kwargs=self.eval_kwargs, eval_type=self.eval_type, n=self.multicomp_n, match_fn=self.match_fn, format_kwargs={**completions, **test_sample, **self.modelgraded_spec_args}, ) metrics.update(dict(choice=choice, score=info["score"])) # run metaeval if requested if self.metaeval: assert "choice" in test_sample metrics["metascore"] = choice == test_sample["choice"] evals.record.record_metrics(**metrics) return choice def run(self, recorder): samples = self.get_samples() self.eval_all_samples(recorder, samples) record_metrics = {} all_sample_metrics = recorder.get_metrics() if not all_sample_metrics: return record_metrics # record the counts choices = [m["choice"] for m in all_sample_metrics] counts = dict(Counter(choices)) record_metrics.update({f"counts/{k}": v for k, v in counts.items()}) # record the scores scores = [m["score"] for m in all_sample_metrics if m["score"] is not None] if scores: record_metrics["score"] = sum(scores) / len(scores) metascores = [m["metascore"] for m in all_sample_metrics if "metascore" in m] if metascores: record_metrics["metascore"] = sum(metascores) / len(metascores) return record_metrics
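
For reference, below is a minimal standalone sketch of the aggregation pattern that run() applies to the per-sample metrics recorded by eval_sample: choices are tallied into counts, and score/metascore are averaged when present. The helper name aggregate_sample_metrics and the demo data are illustrative assumptions, not part of the evals API.

from collections import Counter
from typing import Any


def aggregate_sample_metrics(all_sample_metrics: list[dict[str, Any]]) -> dict[str, Any]:
    """Illustrative reduction mirroring ModelBasedClassify.run(); not part of the evals API."""
    record_metrics: dict[str, Any] = {}
    if not all_sample_metrics:
        return record_metrics

    # Count how often the grader picked each choice string (or "__invalid__").
    counts = Counter(m["choice"] for m in all_sample_metrics)
    record_metrics.update({f"counts/{k}": v for k, v in counts.items()})

    # Average the numeric scores, skipping samples where the grader produced none.
    scores = [m["score"] for m in all_sample_metrics if m.get("score") is not None]
    if scores:
        record_metrics["score"] = sum(scores) / len(scores)

    # Metascores only exist when metaeval=True compared the grader choice to a label.
    metascores = [m["metascore"] for m in all_sample_metrics if "metascore" in m]
    if metascores:
        record_metrics["metascore"] = sum(metascores) / len(metascores)
    return record_metrics


if __name__ == "__main__":
    # Hypothetical per-sample metrics with the shape eval_sample records.
    demo = [
        {"choice": "Y", "score": 1.0, "metascore": True},
        {"choice": "N", "score": 0.0, "metascore": False},
        {"choice": "__invalid__", "score": None},
    ]
    print(aggregate_sample_metrics(demo))
    # {'counts/Y': 1, 'counts/N': 1, 'counts/__invalid__': 1, 'score': 0.5, 'metascore': 0.5}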