mlebench/grade_helpers.py

"""Helper classes related to grading""" import inspect from dataclasses import dataclass from datetime import datetime from typing import Any, Optional import pandas as pd from mlebench.utils import get_logger, import_fn logger = get_logger(__name__) class Grader: def __init__(self, name: str, grade_fn: str) -> None: self.name = name self.grade_fn = import_fn(grade_fn) assert isinstance(self.name, str), "Grader name must be a string." assert len(self.name) > 0, "Grader name cannot be empty." def is_lower_better(self, leaderboard: pd.DataFrame) -> bool: """ Determines if a lower score is better based on the leaderboard. Returns True if lower scores are better, False otherwise. """ scores = leaderboard["score"] top_score = scores.iloc[0] bottom_score = scores.iloc[-1] return bool(top_score < bottom_score) @staticmethod def from_dict(data: dict) -> "Grader": return Grader(**data) def __call__(self, submission: pd.DataFrame, answers: Any) -> Optional[float]: """ Runs the grading function on a submission, returning the score rounded to 5 decimal places. """ try: score = self.grade_fn(submission, answers) except InvalidSubmissionError as e: logger.warning(f"Invalid submission: {e}") return None except Exception as e: try: fpath = inspect.getfile(self.grade_fn) line_number = inspect.getsourcelines(self.grade_fn)[1] fn_info = f"{fpath}:{line_number}" except TypeError: fn_info = str(self.grade_fn) logger.error(f"Unexpected error during grading: {e}. Check {fn_info}") return None rounded_score = round(score, 5) return rounded_score def rank_score(self, score: Optional[float], leaderboard: pd.DataFrame) -> dict: """ Ranks a score based on the leaderboard. Returns a dictionary of bools with the following keys: - gold_medal: bool - silver_medal: bool - bronze_medal: bool - above_median: bool - gold_threshold: float - silver_threshold: float - bronze_threshold: float - median_threshold: float """ assert "score" in leaderboard.columns, "Leaderboard must have a `score` column." lower_is_better = self.is_lower_better(leaderboard) num_teams = len(leaderboard) scores = leaderboard["score"] def get_score_at_position(position: int) -> float: """ Returns the score at the given position in the leaderboard. Raises an IndexError if the position is out of bounds. 
""" if position - 1 >= len(scores) or position < 1: raise IndexError("Position out of bounds in the leaderboard.") return scores.iloc[position - 1] def get_thresholds(num_teams: int) -> tuple[float, float, float, float]: """ Returns the thresholds for medals based on kaggle.com/progression Returns a tuple of thresholds in the order of gold, silver, bronze, median """ if 1 <= num_teams < 100: gold_threshold = get_score_at_position(max(1, int(num_teams * 0.1))) silver_threshold = get_score_at_position(max(1, int(num_teams * 0.2))) bronze_threshold = get_score_at_position(max(1, int(num_teams * 0.4))) elif 100 <= num_teams < 250: gold_threshold = get_score_at_position(10) silver_threshold = get_score_at_position(max(1, int(num_teams * 0.2))) bronze_threshold = get_score_at_position(max(1, int(num_teams * 0.4))) elif 250 <= num_teams < 1000: gold_threshold = get_score_at_position(10 + int(num_teams * 0.002)) silver_threshold = get_score_at_position(50) bronze_threshold = get_score_at_position(100) elif num_teams >= 1000: gold_threshold = get_score_at_position(10 + int(num_teams * 0.002)) silver_threshold = get_score_at_position(max(1, int(num_teams * 0.05))) bronze_threshold = get_score_at_position(max(1, int(num_teams * 0.1))) else: raise ValueError("Number of teams in leaderboard must be greater than 0.") median_threshold = scores.median() return ( float(gold_threshold), float(silver_threshold), float(bronze_threshold), float(median_threshold), ) gold_threshold, silver_threshold, bronze_threshold, median_threshold = get_thresholds( num_teams ) if score is None: return { "gold_medal": False, "silver_medal": False, "bronze_medal": False, "above_median": False, "gold_threshold": gold_threshold, "silver_threshold": silver_threshold, "bronze_threshold": bronze_threshold, "median_threshold": median_threshold, } assert isinstance( score, (float, int) ), f"Expected `score` to be a `float` or `int` but got a {type(score)}." 
        gold_medal = score <= gold_threshold if lower_is_better else score >= gold_threshold
        silver_medal = not gold_medal and (
            score <= silver_threshold if lower_is_better else score >= silver_threshold
        )
        bronze_medal = (
            not gold_medal
            and not silver_medal
            and (score <= bronze_threshold if lower_is_better else score >= bronze_threshold)
        )
        above_median = score < median_threshold if lower_is_better else score > median_threshold

        return {
            "gold_medal": gold_medal,
            "silver_medal": silver_medal,
            "bronze_medal": bronze_medal,
            "above_median": above_median,
            "gold_threshold": gold_threshold,
            "silver_threshold": silver_threshold,
            "bronze_threshold": bronze_threshold,
            "median_threshold": median_threshold,
        }


@dataclass(frozen=True)
class CompetitionReport:
    competition_id: str
    score: float | None
    gold_threshold: float
    silver_threshold: float
    bronze_threshold: float
    median_threshold: float
    any_medal: bool
    gold_medal: bool
    silver_medal: bool
    bronze_medal: bool
    above_median: bool
    submission_exists: bool
    valid_submission: bool
    is_lower_better: bool
    created_at: datetime
    submission_path: str

    def to_dict(self) -> dict:
        # Convert all values to JSON-compatible types explicitly
        return {
            "competition_id": self.competition_id,
            "score": float(self.score) if self.score is not None else None,
            "gold_threshold": float(self.gold_threshold),
            "silver_threshold": float(self.silver_threshold),
            "bronze_threshold": float(self.bronze_threshold),
            "median_threshold": float(self.median_threshold),
            "any_medal": bool(self.any_medal),
            "gold_medal": bool(self.gold_medal),
            "silver_medal": bool(self.silver_medal),
            "bronze_medal": bool(self.bronze_medal),
            "above_median": bool(self.above_median),
            "submission_exists": bool(self.submission_exists),
            "valid_submission": bool(self.valid_submission),
            "is_lower_better": bool(self.is_lower_better),
            "created_at": self.created_at.isoformat(),  # Serialize datetime
            "submission_path": self.submission_path,
        }

    @classmethod
    def from_dict(cls, data: dict) -> "CompetitionReport":
        data = data.copy()  # Avoid accidentally mutating the original dictionary

        typed_data = {
            "competition_id": data["competition_id"],
            "score": float(data["score"]) if data["score"] is not None else None,
            "gold_threshold": float(data["gold_threshold"]),
            "silver_threshold": float(data["silver_threshold"]),
            "bronze_threshold": float(data["bronze_threshold"]),
            "median_threshold": float(data["median_threshold"]),
            "any_medal": bool(data["any_medal"]),
            "gold_medal": bool(data["gold_medal"]),
            "silver_medal": bool(data["silver_medal"]),
            "bronze_medal": bool(data["bronze_medal"]),
            "above_median": bool(data["above_median"]),
            "submission_exists": bool(data["submission_exists"]),
            "valid_submission": bool(data["valid_submission"]),
            "is_lower_better": bool(data["is_lower_better"]),
            "created_at": datetime.fromisoformat(data["created_at"]),
            "submission_path": data["submission_path"],
        }

        return cls(**typed_data)


class InvalidSubmissionError(Exception):
    """
    A custom exception for when the agent submission cannot be graded.
    """

    pass
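

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): how a Grader might be wired up and used.
# The dotted path passed as `grade_fn`, the submission/answers columns, and the
# leaderboard values below are hypothetical; `import_fn` is assumed to resolve
# such a path to a grading callable accepting (submission, answers).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    grader = Grader(name="example", grade_fn="my_competition.grade")  # hypothetical path

    # Toy leaderboard with the best score first, as `is_lower_better` expects.
    leaderboard = pd.DataFrame({"score": [0.99, 0.97, 0.95, 0.90, 0.80]})

    submission = pd.DataFrame({"id": [0, 1], "prediction": [0.2, 0.8]})  # hypothetical columns
    answers = pd.DataFrame({"id": [0, 1], "target": [0, 1]})  # hypothetical columns

    score = grader(submission, answers)  # None if the submission is invalid or grading errors
    report = grader.rank_score(score, leaderboard)
    logger.info(report)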