mlebench/cli.py (200 lines of code) (raw):

import argparse
import json
from pathlib import Path

from mlebench.data import download_and_prepare_dataset, ensure_leaderboard_exists
from mlebench.grade import grade_csv, grade_jsonl
from mlebench.registry import registry
from mlebench.utils import get_logger

logger = get_logger(__name__)


def _read_competition_ids(path):
    """Return the competition IDs listed one per line in the text file at ``path``.

    Surrounding whitespace is stripped from each line and blank lines are
    skipped, so a trailing empty line or stray spaces in the list file do not
    produce a bogus competition ID (e.g. ``""``) that cannot be looked up.
    """
    text = Path(path).read_text(encoding="utf-8")
    return [line.strip() for line in text.splitlines() if line.strip()]


def main():
    """Entry point for the MLE-bench command-line interface.

    Sub-commands:
      * ``prepare``: download and prepare one or more competitions.
      * ``grade``: grade a JSONL file of submissions via ``grade_jsonl``.
      * ``grade-sample``: grade a single CSV submission for one competition
        and log the resulting report.
      * ``dev download-leaderboard``: ensure leaderboard files exist for one
        or all competitions.

    If no sub-command (or no ``dev`` sub-command) is given, the corresponding
    help text is printed instead of silently doing nothing.
    """
    parser = argparse.ArgumentParser(description="Runs agents on Kaggle competitions.")
    subparsers = parser.add_subparsers(dest="command", help="Sub-command to run.")

    # Prepare sub-parser
    parser_prepare = subparsers.add_parser(
        name="prepare",
        help="Download and prepare competitions for the MLE-bench dataset.",
    )
    parser_prepare.add_argument(
        "-c",
        "--competition-id",
        help=f"ID of the competition to prepare. Valid options: {registry.list_competition_ids()}",
        type=str,
        required=False,
    )
    parser_prepare.add_argument(
        "-a",
        "--all",
        help="Prepare all competitions.",
        action="store_true",
    )
    parser_prepare.add_argument(
        "--lite",
        help="Prepare all the low complexity competitions (MLE-bench Lite).",
        action="store_true",
        required=False,
    )
    parser_prepare.add_argument(
        "-l",
        "--list",
        help="Prepare a list of competitions specified line by line in a text file.",
        type=str,
        required=False,
    )
    parser_prepare.add_argument(
        "--keep-raw",
        help="Keep the raw competition files after the competition has been prepared.",
        action="store_true",
        required=False,
        default=False,
    )
    parser_prepare.add_argument(
        "--data-dir",
        help="Path to the directory where the data will be stored.",
        required=False,
        default=registry.get_data_dir(),
    )
    parser_prepare.add_argument(
        "--overwrite-checksums",
        help="[For Developers] Overwrite the checksums file for the competition.",
        action="store_true",
        required=False,
        default=False,
    )
    parser_prepare.add_argument(
        "--overwrite-leaderboard",
        help="[For Developers] Overwrite the leaderboard file for the competition.",
        action="store_true",
        required=False,
        default=False,
    )
    parser_prepare.add_argument(
        "--skip-verification",
        help="[For Developers] Skip the verification of the checksums.",
        action="store_true",
        required=False,
        default=False,
    )

    # Grade eval sub-parser
    parser_grade_eval = subparsers.add_parser(
        "grade",
        help="Grade a submission to the eval, comprising of several competition submissions",
    )
    parser_grade_eval.add_argument(
        "--submission",
        help="Path to the JSONL file of submissions. Refer to README.md#submission-format for the required format.",
        type=str,
        required=True,
    )
    parser_grade_eval.add_argument(
        "--output-dir",
        help="Path to the directory where the evaluation metrics will be saved.",
        type=str,
        required=True,
    )
    parser_grade_eval.add_argument(
        "--data-dir",
        help="Path to the directory where the data used for grading is stored.",
        required=False,
        default=registry.get_data_dir(),
    )

    # Grade sample sub-parser
    parser_grade_sample = subparsers.add_parser(
        name="grade-sample",
        help="Grade a single sample (competition) in the eval",
    )
    parser_grade_sample.add_argument(
        "submission",
        help="Path to the submission CSV file.",
        type=str,
    )
    parser_grade_sample.add_argument(
        "competition_id",
        help=f"ID of the competition to grade. Valid options: {registry.list_competition_ids()}",
        type=str,
    )
    parser_grade_sample.add_argument(
        "--data-dir",
        help="Path to the directory where the data will be stored.",
        required=False,
        default=registry.get_data_dir(),
    )

    # Dev tools sub-parser
    parser_dev = subparsers.add_parser("dev", help="Developer tools for extending MLE-bench.")
    dev_subparsers = parser_dev.add_subparsers(dest="dev_command", help="Developer command to run.")

    # Set up 'download-leaderboard' under 'dev'
    parser_download_leaderboard = dev_subparsers.add_parser(
        "download-leaderboard",
        help="Download the leaderboard for a competition.",
    )
    parser_download_leaderboard.add_argument(
        "-c",
        "--competition-id",
        help=f"Name of the competition to download the leaderboard for. Valid options: {registry.list_competition_ids()}",
        type=str,
        required=False,
    )
    parser_download_leaderboard.add_argument(
        "--all",
        help="Download the leaderboard for all competitions.",
        action="store_true",
    )
    parser_download_leaderboard.add_argument(
        "--force",
        help="Force download the leaderboard, even if it already exists.",
        action="store_true",
    )

    args = parser.parse_args()

    if args.command is None:
        # Sub-commands are optional in argparse; show usage instead of a silent no-op.
        parser.print_help()
        return

    if args.command == "prepare":
        new_registry = registry.set_data_dir(Path(args.data_dir))
        # Resolve the selected competition IDs first (precedence: --lite, --all,
        # --list, --competition-id), then look them all up in the same registry.
        if args.lite:
            competition_ids = new_registry.get_lite_competition_ids()
        elif args.all:
            competition_ids = new_registry.list_competition_ids()
        elif args.list:
            competition_ids = _read_competition_ids(args.list)
        else:
            if not args.competition_id:
                # argparse's error() prints usage and exits, so no fall-through.
                parser_prepare.error(
                    "One of --lite, --all, --list, or --competition-id must be specified."
                )
            competition_ids = [args.competition_id]

        competitions = [
            new_registry.get_competition(competition_id) for competition_id in competition_ids
        ]
        for competition in competitions:
            download_and_prepare_dataset(
                competition=competition,
                keep_raw=args.keep_raw,
                overwrite_checksums=args.overwrite_checksums,
                overwrite_leaderboard=args.overwrite_leaderboard,
                skip_verification=args.skip_verification,
            )

    elif args.command == "grade":
        new_registry = registry.set_data_dir(Path(args.data_dir))
        submission = Path(args.submission)
        output_dir = Path(args.output_dir)
        grade_jsonl(submission, output_dir, new_registry)

    elif args.command == "grade-sample":
        new_registry = registry.set_data_dir(Path(args.data_dir))
        competition = new_registry.get_competition(args.competition_id)
        submission = Path(args.submission)
        report = grade_csv(submission, competition)
        logger.info("Competition report:")
        logger.info(json.dumps(report.to_dict(), indent=4))

    elif args.command == "dev":
        if args.dev_command == "download-leaderboard":
            if args.all:
                for competition_id in registry.list_competition_ids():
                    competition = registry.get_competition(competition_id)
                    ensure_leaderboard_exists(competition, force=args.force)
            elif args.competition_id:
                competition = registry.get_competition(args.competition_id)
                ensure_leaderboard_exists(competition, force=args.force)
            else:
                parser_download_leaderboard.error(
                    "Either --all or --competition-id must be specified."
                )
        else:
            # `dev` was given without a developer sub-command.
            parser_dev.print_help()


if __name__ == "__main__":
    main()