data/envs/atari/create_atari_dataset.py

import datasets
import numpy as np
import torch
from datasets import Dataset, concatenate_datasets
from huggingface_hub import HfApi, upload_folder
from PIL import Image
from sample_factory.algo.learning.learner import Learner
from sample_factory.algo.sampling.batched_sampling import preprocess_actions
from sample_factory.algo.utils.action_distributions import argmax_actions
from sample_factory.algo.utils.env_info import extract_env_info
from sample_factory.algo.utils.make_env import make_env_func_batched
from sample_factory.algo.utils.rl_utils import make_dones, prepare_and_normalize_obs
from sample_factory.algo.utils.tensor_utils import unsqueeze_tensor
from sample_factory.cfg.arguments import load_from_checkpoint
from sample_factory.enjoy import visualize_policy_inputs
from sample_factory.model.actor_critic import create_actor_critic
from sample_factory.model.model_utils import get_rnn_size
from sample_factory.utils.attr_dict import AttrDict
from sample_factory.utils.typing import Config
from sample_factory.utils.utils import log
from sf_examples.envpool.atari.train_envpool_atari import parse_atari_args, register_atari_components


def push_to_hf(dir_path: str, repo_name: str):
    """Upload a local folder to a Hugging Face dataset repo."""
    _ = HfApi().create_repo(repo_id=repo_name, private=False, exist_ok=True, repo_type="dataset")
    upload_folder(
        repo_id=repo_name, folder_path=dir_path, path_in_repo=".", ignore_patterns=[".git/*"], repo_type="dataset"
    )


# most of this function is redundant as it is copied from sample_factory.enjoy.enjoy
def create_atari_dataset(cfg: Config):
    cfg = load_from_checkpoint(cfg)

    eval_env_frameskip: int = cfg.env_frameskip if cfg.eval_env_frameskip is None else cfg.eval_env_frameskip
    assert (
        cfg.env_frameskip % eval_env_frameskip == 0
    ), f"{cfg.env_frameskip=} must be divisible by {eval_env_frameskip=}"
    render_action_repeat: int = cfg.env_frameskip // eval_env_frameskip
    cfg.env_frameskip = cfg.eval_env_frameskip = eval_env_frameskip
    log.debug(f"Using frameskip {cfg.env_frameskip} and {render_action_repeat=} for evaluation")

    cfg.num_envs = 1
    cfg.env_agents = 1

    render_mode = "human"
    if cfg.save_video:
        render_mode = "rgb_array"
    elif cfg.no_render:
        render_mode = None

    env = make_env_func_batched(
        cfg, env_config=AttrDict(worker_index=0, vector_index=0, env_id=0), render_mode=render_mode
    )
    env_info = extract_env_info(env, cfg)

    if hasattr(env.unwrapped, "reset_on_init"):
        # reset call ruins the demo recording for VizDoom
        env.unwrapped.reset_on_init = False

    actor_critic = create_actor_critic(cfg, env.observation_space, env.action_space)
    actor_critic.eval()

    device = torch.device("cpu" if cfg.device == "cpu" else "cuda")
    actor_critic.model_to_device(device)

    policy_id = cfg.policy_index
    name_prefix = {"latest": "checkpoint", "best": "best"}[cfg.load_checkpoint_kind]
    checkpoints = Learner.get_checkpoints(Learner.checkpoint_dir(cfg, policy_id), f"{name_prefix}_*")
    checkpoint_dict = Learner.load_checkpoint(checkpoints, device)
    actor_critic.load_state_dict(checkpoint_dict["model"])

    num_frames = 0
    obs, infos = env.reset()
    rnn_states = torch.zeros([env.num_agents, get_rnn_size(cfg)], dtype=torch.float32, device=device)

    # completed episodes
    image_observations = []
    rewards = []
    discrete_actions = []
    # current-episode buffers
    ep_image_observations = []
    ep_rewards = []
    ep_discrete_actions = []

    with torch.no_grad():
        while num_frames < cfg.max_num_frames:
            obs["obs"] = obs["obs"][0]
            normalized_obs = prepare_and_normalize_obs(actor_critic, obs)

            if not cfg.no_render:
                visualize_policy_inputs(normalized_obs)
            policy_outputs = actor_critic(normalized_obs, rnn_states)

            # sample actions from the distribution by default
            actions = policy_outputs["actions"]

            if cfg.eval_deterministic:
                action_distribution = actor_critic.action_distribution()
                actions = argmax_actions(action_distribution)

            # actions shape should be [num_agents, num_actions] even if it's [1, 1]
            if actions.ndim == 1:
                actions = unsqueeze_tensor(actions, dim=-1)
            actions = preprocess_actions(env_info, actions)

            rnn_states = policy_outputs["new_rnn_states"]

            # store the observation in the episode buffer
            ep_image_observations.append(Image.fromarray(np.transpose(obs["obs"][0].cpu().numpy(), (1, 2, 0))))

            obs, rew, terminated, truncated, infos = env.step([actions])
            done = make_dones(terminated, truncated).item()

            # store the action and reward in the episode buffer
            ep_rewards.append(rew.item())
            ep_discrete_actions.append(actions.item())

            num_frames += 1

            if done:  # possibly a fictitious done (episode not necessarily over)
                # always reset the RNN state, but only keep the episode if it truly terminated
                rnn_states[0] = torch.zeros([get_rnn_size(cfg)], dtype=torch.float32, device=device)
                if infos[0]["terminated"].item():
                    image_observations.append(ep_image_observations)
                    discrete_actions.append(np.array(ep_discrete_actions).astype(np.int64))
                    rewards.append(np.array(ep_rewards).astype(np.float32))
                    ep_image_observations = []
                    ep_discrete_actions = []
                    ep_rewards = []
                    log.info(f"Episode rewards: {np.sum(rewards[-1]):.3f}")

    env.close()

    task = cfg.env.split("_")[1]
    # Fix task names (see https://huggingface.co/datasets/jat-project/jat-dataset/discussions/21 to 24)
    task = "asteroids" if task == "asteroid" else task
    task = "kungfumaster" if task == "kongfumaster" else task
    task = "montezumarevenge" if task == "montezuma" else task
    task = "privateeye" if task == "privateye" else task

    d = {
        "image_observations": image_observations,
        "discrete_actions": discrete_actions,
        "rewards": rewards,
    }
    features = datasets.Features(
        {
            "image_observations": datasets.Sequence(datasets.Image()),
            "discrete_actions": datasets.Sequence(datasets.Value("int64")),
            "rewards": datasets.Sequence(datasets.Value("float32")),
        }
    )
    # build one single-row dataset per episode, then concatenate them
    ds = [
        Dataset.from_dict({k: [v[idx]] for k, v in d.items()}, features=features)
        for idx in range(len(d["image_observations"]))
    ]
    dataset = concatenate_datasets(ds)
    dataset = dataset.train_test_split(test_size=0.1, writer_batch_size=1)
    HfApi().create_branch("jat-project/jat-dataset", branch="new_breakout", exist_ok=True, repo_type="dataset")
    dataset.push_to_hub(
        "jat-project/jat-dataset",
        config_name=f"atari-{task}",
        branch="new_breakout",
    )


def main():
    """Script entry point."""
    register_atari_components()
    cfg = parse_atari_args(evaluation=True)
    status = create_atari_dataset(cfg)
    return status


if __name__ == "__main__":
    main()
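
# Example (illustrative, commented out so the script's behavior is unchanged): once pushed,
# a per-game config can be loaded back with the `datasets` library. The repo id, branch and
# config-name pattern mirror the values used above ("jat-project/jat-dataset", "new_breakout",
# f"atari-{task}"); "atari-breakout" is only an assumed example of a resulting config name.
#
#   from datasets import load_dataset
#
#   ds = load_dataset("jat-project/jat-dataset", "atari-breakout", revision="new_breakout")
#   episode = ds["train"][0]                  # one row per episode
#   frames = episode["image_observations"]    # per-step frames stored as images
#   actions = episode["discrete_actions"]     # int64 action ids, one per frame
#   episode_return = sum(episode["rewards"])  # float32 rewards, one per frame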