mlebench/registry.py (113 lines of code) (raw):

from dataclasses import dataclass
from pathlib import Path
from typing import Callable

from appdirs import user_cache_dir

from mlebench.grade_helpers import Grader
from mlebench.utils import get_logger, get_module_dir, get_repo_dir, import_fn, load_yaml

logger = get_logger(__name__)


# Default location for competition data: "<user cache dir>/mle-bench/data".
DEFAULT_DATA_DIR = (Path(user_cache_dir()) / "mle-bench" / "data").resolve()


@dataclass(frozen=True)
class Competition:
    """Immutable description of a single competition.

    Instances are normally built by `Registry.get_competition`, which fills the
    path fields from the registry's data directory and the competition's
    `config.yaml`.
    """

    id: str
    name: str
    description: str
    grader: Grader
    answers: Path
    gold_submission: Path
    sample_submission: Path
    competition_type: str
    # NOTE(review): presumably called with the raw, public and private dirs;
    # the argument order is not visible in this module -- confirm at call sites.
    prepare_fn: Callable[[Path, Path, Path], Path]
    raw_dir: Path
    private_dir: Path
    public_dir: Path
    checksums: Path
    leaderboard: Path

    def __post_init__(self):
        """Sanity-check field types and reject empty string fields.

        The checks are `assert` statements, so they raise `AssertionError` and
        are skipped entirely when Python runs with `-O`. `prepare_fn`,
        `raw_dir`, `private_dir` and `public_dir` are not checked here.
        """
        assert isinstance(self.id, str), "Competition id must be a string."
        assert isinstance(self.name, str), "Competition name must be a string."
        assert isinstance(self.description, str), "Competition description must be a string."
        assert isinstance(self.grader, Grader), "Competition grader must be of type Grader."
        assert isinstance(self.answers, Path), "Competition answers must be a Path."
        assert isinstance(self.gold_submission, Path), "Gold submission must be a Path."
        assert isinstance(self.sample_submission, Path), "Sample submission must be a Path."
        assert isinstance(self.competition_type, str), "Competition type must be a string."
        assert isinstance(self.checksums, Path), "Checksums must be a Path."
        assert isinstance(self.leaderboard, Path), "Leaderboard must be a Path."
        assert len(self.id) > 0, "Competition id cannot be empty."
        assert len(self.name) > 0, "Competition name cannot be empty."
        assert len(self.description) > 0, "Competition description cannot be empty."
        assert len(self.competition_type) > 0, "Competition type cannot be empty."

    @staticmethod
    def from_dict(data: dict) -> "Competition":
        """Build a `Competition` from a mapping of field names to values.

        `data["grader"]` is converted with `Grader.from_dict`; every other
        field is passed through unchanged. Keys in `data` that are not
        `Competition` fields are ignored.

        Raises:
            ValueError: if any required key (including "grader") is missing.
        """
        # Look the grader config up under the same guard as the other keys so a
        # missing "grader" is reported as ValueError rather than a raw KeyError.
        try:
            grader_config = data["grader"]
        except KeyError as e:
            raise ValueError(f"Missing key {e} in competition config!") from e

        grader = Grader.from_dict(grader_config)

        try:
            return Competition(
                id=data["id"],
                name=data["name"],
                description=data["description"],
                grader=grader,
                answers=data["answers"],
                sample_submission=data["sample_submission"],
                gold_submission=data["gold_submission"],
                competition_type=data["competition_type"],
                prepare_fn=data["prepare_fn"],
                raw_dir=data["raw_dir"],
                public_dir=data["public_dir"],
                private_dir=data["private_dir"],
                checksums=data["checksums"],
                leaderboard=data["leaderboard"],
            )
        except KeyError as e:
            raise ValueError(f"Missing key {e} in competition config!") from e


class Registry:
    """Looks up competitions and resolves their files against a data directory."""

    def __init__(self, data_dir: Path = DEFAULT_DATA_DIR):
        """Create a registry whose data lives under `data_dir` (stored resolved)."""
        self._data_dir = data_dir.resolve()

    def get_competition(self, competition_id: str) -> Competition:
        """Fetch the competition from the registry.

        Loads `<competitions dir>/<competition_id>/config.yaml`, reads the
        description file it points to (relative to the repository directory),
        imports the preparer function named by `config["preparer"]`, and
        resolves the dataset paths against the data directory.

        Raises:
            ValueError: if the assembled config lacks a key `Competition` needs.
        """
        competition_dir = self.get_competitions_dir() / competition_id
        config_path = competition_dir / "config.yaml"
        config = load_yaml(config_path)

        checksums_path = competition_dir / "checksums.yaml"
        leaderboard_path = competition_dir / "leaderboard.csv"

        description_path = get_repo_dir() / config["description"]
        # Decode explicitly as UTF-8 so the result does not depend on the
        # platform's locale encoding (assumes repository text files are UTF-8).
        description = description_path.read_text(encoding="utf-8")

        preparer_fn = import_fn(config["preparer"])

        answers = self.get_data_dir() / config["dataset"]["answers"]

        # The gold submission falls back to the answers file unless the config
        # names a dedicated one.
        gold_submission = answers
        if "gold_submission" in config["dataset"]:
            gold_submission = self.get_data_dir() / config["dataset"]["gold_submission"]

        sample_submission = self.get_data_dir() / config["dataset"]["sample_submission"]

        raw_dir = self.get_data_dir() / competition_id / "raw"
        private_dir = self.get_data_dir() / competition_id / "prepared" / "private"
        public_dir = self.get_data_dir() / competition_id / "prepared" / "public"

        # Entries listed after `**config` override same-named config keys
        # (e.g. "description" becomes the file's text instead of its path).
        return Competition.from_dict(
            {
                **config,
                "description": description,
                "answers": answers,
                "sample_submission": sample_submission,
                "gold_submission": gold_submission,
                "prepare_fn": preparer_fn,
                "raw_dir": raw_dir,
                "private_dir": private_dir,
                "public_dir": public_dir,
                "checksums": checksums_path,
                "leaderboard": leaderboard_path,
            }
        )

    def get_competitions_dir(self) -> Path:
        """Retrieves the competition directory within the registry."""
        return get_module_dir() / "competitions"

    def get_splits_dir(self) -> Path:
        """Retrieves the splits directory within the repository."""
        return get_repo_dir() / "experiments" / "splits"

    def get_lite_competition_ids(self) -> list[str]:
        """List all competition IDs for the lite version (low complexity competitions).

        The IDs are the lines of `low.txt` in the splits directory.
        """
        lite_competitions_file = self.get_splits_dir() / "low.txt"
        with open(lite_competitions_file, "r", encoding="utf-8") as f:
            competition_ids = f.read().splitlines()
        return competition_ids

    def get_data_dir(self) -> Path:
        """Retrieves the data directory within the registry."""
        return self._data_dir

    def set_data_dir(self, new_data_dir: Path) -> "Registry":
        """Return a new `Registry` that uses `new_data_dir`.

        This instance is left unchanged; callers must use the returned registry.
        """
        return Registry(new_data_dir)

    def list_competition_ids(self) -> list[str]:
        """List all competition IDs available in the registry.

        An ID is the name of the directory containing a `config.yaml` found
        anywhere under the competitions directory; results follow the sorted
        order of the config file paths.
        """
        competition_configs = self.get_competitions_dir().rglob("config.yaml")
        # Use `.name`, not `.stem`: `.stem` would drop everything after the
        # last dot in a directory name (e.g. "foo.v2" -> "foo").
        competition_ids = [f.parent.name for f in sorted(competition_configs)]
        return competition_ids


# Module-level registry using the default data directory.
registry = Registry()