# mmlu_eval.py

""" Measuring Massive Multitask Language Understanding Dan Hendrycks, Collin Burns, Steven Basart, Andy Zou, Mantas Mazeika, Dawn Song, Jacob Steinhardt https://arxiv.org/abs/2009.03300 """ import random import re import pandas from . import common from .common import ( HTML_JINJA, MULTILINGUAL_ANSWER_PATTERN_TEMPLATE, MULTILINGUAL_ANSWER_REGEXES, format_multichoice_question, normalize_extracted_answer, normalize_response, ) from .types import Eval, EvalResult, SamplerBase, SingleEvalResult subject2category = { "abstract_algebra": "stem", "anatomy": "other", "astronomy": "stem", "business_ethics": "other", "clinical_knowledge": "other", "college_biology": "stem", "college_chemistry": "stem", "college_computer_science": "stem", "college_mathematics": "stem", "college_medicine": "other", "college_physics": "stem", "computer_security": "stem", "conceptual_physics": "stem", "econometrics": "social_sciences", "electrical_engineering": "stem", "elementary_mathematics": "stem", "formal_logic": "humanities", "global_facts": "other", "high_school_biology": "stem", "high_school_chemistry": "stem", "high_school_computer_science": "stem", "high_school_european_history": "humanities", "high_school_geography": "social_sciences", "high_school_government_and_politics": "social_sciences", "high_school_macroeconomics": "social_sciences", "high_school_mathematics": "stem", "high_school_microeconomics": "social_sciences", "high_school_physics": "stem", "high_school_psychology": "social_sciences", "high_school_statistics": "stem", "high_school_us_history": "humanities", "high_school_world_history": "humanities", "human_aging": "other", "human_sexuality": "social_sciences", "international_law": "humanities", "jurisprudence": "humanities", "logical_fallacies": "humanities", "machine_learning": "stem", "management": "other", "marketing": "other", "medical_genetics": "other", "miscellaneous": "other", "moral_disputes": "humanities", "moral_scenarios": "humanities", "nutrition": "other", "philosophy": "humanities", "prehistory": "humanities", "professional_accounting": "other", "professional_law": "humanities", "professional_medicine": "other", "professional_psychology": "social_sciences", "public_relations": "social_sciences", "security_studies": "social_sciences", "sociology": "social_sciences", "us_foreign_policy": "social_sciences", "virology": "other", "world_religions": "humanities", } class MMLUEval(Eval): def __init__(self, num_examples: int | None = None, language: str = "EN-US"): if language != "EN-US": url = f"https://openaipublic.blob.core.windows.net/simple-evals/mmlu_{language}.csv" else: url = "https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv" df = pandas.read_csv(url) examples = [row.to_dict() for _, row in df.iterrows()] if num_examples: examples = random.Random(0).sample(examples, num_examples) self.examples = examples def __call__(self, sampler: SamplerBase) -> EvalResult: def fn(row: dict): prompt_messages = [ sampler._pack_message( content=format_multichoice_question(row), role="user" ) ] response_text = normalize_response(sampler(prompt_messages)) extracted_answer = None for answer_regex in MULTILINGUAL_ANSWER_REGEXES: regex = MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(answer_regex) match = re.search(regex, response_text) if match: extracted_answer = normalize_extracted_answer(match.group(1)) break score = 1.0 if extracted_answer == row["Answer"] else 0.0 html = common.jinja_env.from_string(HTML_JINJA).render( prompt_messages=prompt_messages, 
next_message=dict(content=response_text, role="assistant"), score=score, correct_answer=row["Answer"], extracted_answer=extracted_answer, ) convo = prompt_messages + [dict(content=response_text, role="assistant")] category = subject2category.get(row["Subject"], "other") return SingleEvalResult( html=html, score=score, metrics={category: score}, convo=convo ) results = common.map_with_progress(fn, self.examples) return common.aggregate_results(results)
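

# ---------------------------------------------------------------------------
# Usage sketch (not part of the upstream file): MMLUEval only relies on the
# SamplerBase interface -- `_pack_message` plus being callable on a message
# list -- so any model-backed sampler from this repo can be passed in. The
# `ConstantSampler` below is a hypothetical stand-in used purely to show the
# wiring; the `result.score` / `result.metrics` fields follow the EvalResult
# type imported above.
# ---------------------------------------------------------------------------
class ConstantSampler(SamplerBase):
    """Toy sampler that always answers "A"; useful only as a smoke test."""

    def _pack_message(self, content: str, role: str) -> dict:
        return {"content": content, "role": role}

    def __call__(self, message_list: list[dict]) -> str:
        return "Answer: A"


if __name__ == "__main__":
    # Run on a small, seeded subsample so the sketch finishes quickly.
    mmlu_eval = MMLUEval(num_examples=16)
    result = mmlu_eval(ConstantSampler())
    print(result.score, result.metrics)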