scripts/download_all_datasets.py

#!/usr/bin/env python3
"""Download and cache all datasets from the JAT dataset."""
import argparse
import os

from datasets import get_dataset_config_names, load_dataset
from datasets.config import HF_DATASETS_CACHE

from jat.eval.rl.core import TASK_NAME_TO_ENV_ID

parser = argparse.ArgumentParser()
parser.add_argument("--tasks", nargs="+", default=[])
tasks = parser.parse_args().tasks

if tasks == ["all"]:
    # Get all task names from the JAT dataset
    tasks = get_dataset_config_names("jat-project/jat-dataset-tokenized")

# Expand a domain name (e.g. "atari") into its individual task names
for domain in ["atari", "babyai", "metaworld", "mujoco"]:
    if domain in tasks:
        tasks.remove(domain)
        tasks.extend([env_id for env_id in TASK_NAME_TO_ENV_ID.keys() if env_id.startswith(domain)])

# Download each task and save it to the local cache, skipping tasks already cached
for task in tasks:
    print(f"Loading {task}...")
    cache_path = f"{HF_DATASETS_CACHE}/jat-project/jat-dataset-tokenized/{task}"
    if not os.path.exists(cache_path):
        dataset = load_dataset("jat-project/jat-dataset-tokenized", task)
        dataset.save_to_disk(cache_path)
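
Typical invocations, assuming the jat package and its dependencies are installed and the script is run from the repository root:

    python scripts/download_all_datasets.py --tasks all
    python scripts/download_all_datasets.py --tasks atari mujoco

The first form downloads every config of jat-project/jat-dataset-tokenized; the second expands the two domain names into their individual task names before downloading.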
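
Once a task has been cached, downstream code can reload it without contacting the Hub. Below is a minimal sketch using datasets.load_from_disk, the counterpart of save_to_disk; the helper name load_cached_task is hypothetical, and the path layout simply mirrors the script above:

import os

from datasets import load_from_disk
from datasets.config import HF_DATASETS_CACHE


def load_cached_task(task: str):
    """Reload a task previously saved by download_all_datasets.py (hypothetical helper)."""
    cache_path = f"{HF_DATASETS_CACHE}/jat-project/jat-dataset-tokenized/{task}"
    if not os.path.exists(cache_path):
        raise FileNotFoundError(f"Task not cached yet, run the download script first: {cache_path}")
    return load_from_disk(cache_path)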