in scripts/eval_jat.py [0:0]
def main():
parser = HfArgumentParser((ModelArguments, EvaluationArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, eval_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, eval_args = parser.parse_args_into_dataclasses()
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
# Set the tasks
tasks = eval_args.tasks
for domain in ["atari", "babyai", "metaworld", "mujoco"]:
if domain in tasks:
tasks.remove(domain)
tasks.extend([env_id for env_id in TASK_NAME_TO_ENV_ID.keys() if env_id.startswith(domain)])
device = torch.device("cpu") if eval_args.use_cpu else get_default_device()
model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path, cache_dir=model_args.cache_dir, trust_remote_code=model_args.trust_remote_code
).to(device)
processor = AutoProcessor.from_pretrained(
model_args.model_name_or_path, cache_dir=model_args.cache_dir, trust_remote_code=model_args.trust_remote_code
)
evaluations = {}
video_list = []
input_fps = []
for task in tqdm(tasks, desc="Evaluation", unit="task", leave=True):
if task in TASK_NAME_TO_ENV_ID.keys():
scores, frames, fps = eval_rl(model, processor, task, eval_args)
evaluations[task] = scores
# Save the video
if eval_args.save_video:
video_list.append(frames)
input_fps.append(fps)
else:
warnings.warn(f"Task {task} is not supported.")
# Extract mean and std, and save scores dict
eval_path = f"{model_args.model_name_or_path}/evaluations.json"
if not os.path.exists(f"{model_args.model_name_or_path}"):
os.makedirs(f"{model_args.model_name_or_path}")
if evaluations:
with open(eval_path, "w") as file:
json.dump(evaluations, file)
# Save the video
if eval_args.save_video:
replay_path = f"{model_args.model_name_or_path}/replay.mp4"
save_video_grid(video_list, input_fps, replay_path, output_fps=30, max_length_seconds=180)
else:
replay_path = None
# Push the model to the hub
if eval_args.push_to_hub:
assert eval_args.repo_id is not None, "You need to specify a repo_id to push to."
push_to_hub(model, processor, eval_args.repo_id, replay_path=replay_path, eval_path=eval_path)