data/envs/atari/generate_random_score.py (114 lines of code) (raw):

import json import os from multiprocessing import Pool import numpy as np import torch from sample_factory.algo.utils.make_env import make_env_func_batched from sample_factory.utils.attr_dict import AttrDict from sf_examples.atari.train_atari import parse_atari_args, register_atari_components FILENAME = "jat/eval/rl/scores_dict.json" TASK_NAMES = [ "atari-alien", "atari-amidar", "atari-assault", "atari-asterix", "atari-asteroids", "atari-atlantis", "atari-bankheist", "atari-battlezone", "atari-beamrider", "atari-berzerk", "atari-bowling", "atari-boxing", "atari-breakout", "atari-centipede", "atari-choppercommand", "atari-crazyclimber", "atari-defender", "atari-demonattack", "atari-doubledunk", "atari-enduro", "atari-fishingderby", "atari-freeway", "atari-frostbite", "atari-gopher", "atari-gravitar", "atari-hero", "atari-icehockey", "atari-jamesbond", "atari-kangaroo", "atari-krull", "atari-kungfumaster", "atari-montezumarevenge", "atari-mspacman", "atari-namethisgame", "atari-phoenix", "atari-pitfall", "atari-pong", "atari-privateeye", "atari-qbert", "atari-riverraid", "atari-roadrunner", "atari-robotank", "atari-seaquest", "atari-skiing", "atari-solaris", "atari-spaceinvaders", "atari-stargunner", "atari-surround", "atari-tennis", "atari-timepilot", "atari-tutankham", "atari-upndown", "atari-venture", "atari-videopinball", "atari-wizardofwor", "atari-yarsrevenge", "atari-zaxxon", ] TOT_NUM_EPISODES = 100 def generate_random_score(task_name): cfg = parse_atari_args(evaluation=True) env_id = task_name.replace("-", "_") if env_id == "atari_asteroids": env_id = "atari_asteroid" if env_id == "atari_montezumarevenge": env_id = "atari_montezuma" if env_id == "atari_kungfumaster": env_id = "atari_kongfumaster" if env_id == "atari_privateeye": env_id = "atari_privateye" cfg.env = env_id eval_env_frameskip = cfg.env_frameskip if cfg.eval_env_frameskip is None else cfg.eval_env_frameskip cfg.env_frameskip = cfg.eval_env_frameskip = eval_env_frameskip cfg.num_envs = 16 env = make_env_func_batched(cfg, env_config=AttrDict(worker_index=0, vector_index=0, env_id=0)) ep_rewards = [] env.reset() with torch.no_grad(): while len(ep_rewards) < TOT_NUM_EPISODES: _, _, _, _, infos = env.step(np.array([env.action_space.sample() for _ in range(cfg.num_envs)])) for info in infos: if "episode" in info: ep_rewards.append(info["episode"]["r"][0]) if len(ep_rewards) % 10 == 0: print(f"Task {task_name} - progress {int(len(ep_rewards) / TOT_NUM_EPISODES * 100)}%") env.close() # Load the scores dictionary if not os.path.exists(FILENAME): scores_dict = {} else: with open(FILENAME, "r") as file: scores_dict = json.load(file) # Add the random scores to the dictionary if task_name not in scores_dict: scores_dict[task_name] = {} scores_dict[task_name]["random"] = {"mean": float(np.mean(ep_rewards)), "std": float(np.std(ep_rewards))} # Save the dictionary to a file with open(FILENAME, "w") as file: scores_dict = { task: {agent: scores_dict[task][agent] for agent in sorted(scores_dict[task])} for task in sorted(scores_dict) } json.dump(scores_dict, file, indent=4) if __name__ == "__main__": register_atari_components() with Pool(32) as p: p.map(generate_random_score, TASK_NAMES)