scripts/eval_jat.py (143 lines of code) (raw):

#!/usr/bin/env python3
"""Eval a JAT model.

Runs the model on the requested RL tasks, writes the per-episode scores to
``<model_name_or_path>/evaluations.json`` (and optionally a replay video), and optionally pushes
the model, processor and results to the Hugging Face Hub.
"""

import json
import logging
import os
import sys
import warnings
from dataclasses import dataclass, field
from typing import List, Optional

import numpy as np
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoProcessor, HfArgumentParser

from jat.eval.rl import TASK_NAME_TO_ENV_ID, make
from jat.utils import normalize, push_to_hub, save_video_grid


# Domain names accepted in `--tasks`; each one is expanded into every task whose name starts with it.
DOMAINS = ("atari", "babyai", "metaworld", "mujoco")


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config we are going to evaluate.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
                "should only be set to `True` for repositories you trust and in which you have read the code, as it "
                "will execute code present on the Hub on your local machine."
            )
        },
    )


@dataclass
class EvaluationArguments:
    """
    Arguments pertaining to which tasks we evaluate the model on and what we do with the results.
    """

    tasks: List[str] = field(
        default_factory=list,
        metadata={
            "help": (
                "Tasks to evaluate on. A domain name (atari, babyai, metaworld, mujoco) selects all the tasks of "
                "that domain."
            )
        },
    )
    use_cpu: bool = field(default=False, metadata={"help": "Use CPU instead of GPU."})
    save_video: bool = field(default=False, metadata={"help": "Save video of the evaluation."})
    num_episodes: int = field(default=2, metadata={"help": "Number of episodes to evaluate on."})
    push_to_hub: bool = field(default=False, metadata={"help": "Push the model to the hub."})
    repo_id: Optional[str] = field(default=None, metadata={"help": "Repository ID to push to."})


def get_default_device() -> torch.device:
    """Return the preferred device: Apple MPS if available and built, else CUDA if available, else CPU."""
    if torch.backends.mps.is_available() and torch.backends.mps.is_built():
        return torch.device("mps")
    elif torch.cuda.is_available():
        return torch.device("cuda")
    else:
        return torch.device("cpu")


def eval_rl(model, processor, task, eval_args):
    """Roll out `model` on the RL `task` for `eval_args.num_episodes` episodes.

    Args:
        model: JAT model; must expose `reset_rl()` and `get_next_action(...)`.
        processor: Processor forwarded to `model.get_next_action`.
        task: Task name, a key of `TASK_NAME_TO_ENV_ID`.
        eval_args: `EvaluationArguments`; `num_episodes` and `save_video` are read.

    Returns:
        `(scores, frames, fps)`: the sum of rewards of each episode, the rendered frames of all episodes
        downscaled by 3 (empty unless `save_video`), and the environment's `metadata["render_fps"]`.
    """
    # Create the environment
    env_kwargs = {}
    if task.startswith("atari"):
        # Evaluate on the unclipped game score
        env_kwargs["clip_reward"] = False
    if eval_args.save_video:
        env_kwargs["render_mode"] = "rgb_array"
    env = make(task, **env_kwargs)
    context_window = 32 if task.startswith("atari") else 256

    scores = []
    frames = []
    try:
        for _ in tqdm(range(eval_args.num_episodes), desc=task, unit="episode", leave=False):
            observation, _info = env.reset()
            reward = None  # no reward is available before the first action
            rewards = []
            done = False
            model.reset_rl()  # remove KV Cache
            while not done:
                action = model.get_next_action(
                    processor, **observation, reward=reward, action_space=env.action_space, context_window=context_window
                )
                observation, reward, terminated, truncated, info = env.step(action)
                done = terminated or truncated

                # Handle "fake done" for atari: keep playing until "episode" appears in `info`.
                # NOTE(review): presumably the atari wrapper reports a life loss as done and only adds
                # "episode" at the real end of the game — confirm in jat.eval.rl.
                if done and task.startswith("atari") and "episode" not in info:
                    observation, info = env.reset()
                    done = False

                # Update the return
                rewards.append(reward)

                # Render the environment
                if eval_args.save_video:
                    frames.append(np.array(env.render(), dtype=np.uint8))

            scores.append(sum(rewards))
    finally:
        # Release the environment even if a rollout fails
        env.close()

    raw_mean, raw_std = np.mean(scores), np.std(scores)

    # Normalize the scores
    norm_scores = normalize(scores, task, "expert")
    if norm_scores is not None:  # Can be None if random is better than expert
        norm_mean, norm_std = np.mean(norm_scores), np.std(norm_scores)
        tqdm.write(
            f"Task {task} Raw score: {raw_mean:.2f} ± {raw_std:.2f}\t"
            f"Normalized score: {norm_mean:.2f} ± {norm_std:.2f}"
        )
    else:
        tqdm.write(f"Task {task} Raw score: {raw_mean:.2f} ± {raw_std:.2f}")

    # Resize images by 1/3 to limit memory usage (the video is reduced anyway when aggregated with the others)
    if eval_args.save_video:
        import cv2

        frames = [cv2.resize(frame, (0, 0), fx=1 / 3, fy=1 / 3) for frame in frames]

    return scores, frames, env.metadata["render_fps"]


def _expand_tasks(tasks: List[str]) -> List[str]:
    """Replace each domain name in `tasks` by all its tasks and drop duplicates (first occurrence wins).

    A domain name (see `DOMAINS`) is replaced, at its position, by every key of `TASK_NAME_TO_ENV_ID`
    starting with it. Other names are kept as-is. The input list is not modified.
    """
    expanded = []
    for task in tasks:
        if task in DOMAINS:
            expanded.extend(name for name in TASK_NAME_TO_ENV_ID if name.startswith(task))
        else:
            expanded.append(task)
    return list(dict.fromkeys(expanded))


def _to_json_compatible(value):
    """`json.dump` fallback converting NumPy scalars (e.g. `np.float32` rewards) to built-in Python numbers.

    Raises:
        TypeError: for any other non-serializable object, like `json.dump` itself would.
    """
    if isinstance(value, np.generic):
        return value.item()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


def main():
    """Parse arguments, evaluate the model on every requested task, then save and optionally push the results."""
    parser = HfArgumentParser((ModelArguments, EvaluationArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, eval_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, eval_args = parser.parse_args_into_dataclasses()

    # Fail fast: validate the push target before spending time on the evaluation
    if eval_args.push_to_hub and eval_args.repo_id is None:
        raise ValueError("You need to specify a repo_id to push to.")

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    # Set the tasks
    tasks = _expand_tasks(eval_args.tasks)

    device = torch.device("cpu") if eval_args.use_cpu else get_default_device()
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path, cache_dir=model_args.cache_dir, trust_remote_code=model_args.trust_remote_code
    ).to(device)
    processor = AutoProcessor.from_pretrained(
        model_args.model_name_or_path, cache_dir=model_args.cache_dir, trust_remote_code=model_args.trust_remote_code
    )

    evaluations = {}
    video_list = []
    input_fps = []

    for task in tqdm(tasks, desc="Evaluation", unit="task", leave=True):
        if task not in TASK_NAME_TO_ENV_ID:
            warnings.warn(f"Task {task} is not supported.")
            continue
        scores, frames, fps = eval_rl(model, processor, task, eval_args)
        evaluations[task] = scores
        # Save the video
        if eval_args.save_video:
            video_list.append(frames)
            input_fps.append(fps)

    # Save the scores dict next to the model
    output_dir = model_args.model_name_or_path
    os.makedirs(output_dir, exist_ok=True)
    eval_path = os.path.join(output_dir, "evaluations.json")
    if evaluations:
        with open(eval_path, "w", encoding="utf-8") as file:
            json.dump(evaluations, file, default=_to_json_compatible)

    # Save the video (only if at least one task produced frames to aggregate)
    if eval_args.save_video and video_list:
        replay_path = os.path.join(output_dir, "replay.mp4")
        save_video_grid(video_list, input_fps, replay_path, output_fps=30, max_length_seconds=180)
    else:
        replay_path = None

    # Push the model to the hub
    if eval_args.push_to_hub:
        push_to_hub(model, processor, eval_args.repo_id, replay_path=replay_path, eval_path=eval_path)


if __name__ == "__main__":
    main()