in data/envs/atari/create_atari_dataset.py [0:0]
def create_atari_dataset(cfg: Config):
    """Roll out a trained policy in a single Atari env and push the recorded
    episodes (frames, actions, rewards) to the Hugging Face Hub as a dataset.

    High-level flow:
      1. Restore the run config and a trained actor-critic checkpoint.
      2. Step one env for up to ``cfg.max_num_frames`` frames, buffering
         per-episode image observations, discrete actions, and rewards.
      3. Assemble the completed episodes into a ``datasets.Dataset`` with a
         90/10 train/test split and push it to the ``jat-project/jat-dataset``
         repo on the ``new_breakout`` branch.

    Args:
        cfg: experiment config (sample-factory style); must point at an
            existing checkpoint directory. Mutated in place (frameskip,
            num_envs, env_agents are overwritten for evaluation).
    """
    cfg = load_from_checkpoint(cfg)
    # Evaluation may use a finer frameskip than training, but it must divide
    # the training frameskip evenly so action repetition stays an integer.
    eval_env_frameskip: int = cfg.env_frameskip if cfg.eval_env_frameskip is None else cfg.eval_env_frameskip
    assert (
        cfg.env_frameskip % eval_env_frameskip == 0
    ), f"{cfg.env_frameskip=} must be divisible by {eval_env_frameskip=}"
    render_action_repeat: int = cfg.env_frameskip // eval_env_frameskip
    cfg.env_frameskip = cfg.eval_env_frameskip = eval_env_frameskip
    log.debug(f"Using frameskip {cfg.env_frameskip} and {render_action_repeat=} for evaluation")
    # Record from exactly one env with one agent so buffers are flat lists.
    cfg.num_envs = 1
    cfg.env_agents = 1
    # Render mode: on-screen by default; off-screen frames if saving video;
    # no rendering at all when explicitly disabled.
    render_mode = "human"
    if cfg.save_video:
        render_mode = "rgb_array"
    elif cfg.no_render:
        render_mode = None
    env = make_env_func_batched(
        cfg, env_config=AttrDict(worker_index=0, vector_index=0, env_id=0), render_mode=render_mode
    )
    env_info = extract_env_info(env, cfg)
    if hasattr(env.unwrapped, "reset_on_init"):
        # reset call ruins the demo recording for VizDoom
        env.unwrapped.reset_on_init = False
    # Build the policy network and load the checkpoint weights.
    actor_critic = create_actor_critic(cfg, env.observation_space, env.action_space)
    actor_critic.eval()
    # Anything other than an explicit "cpu" config is treated as CUDA.
    device = torch.device("cpu" if cfg.device == "cpu" else "cuda")
    actor_critic.model_to_device(device)
    policy_id = cfg.policy_index
    # "latest" -> files named checkpoint_*, "best" -> files named best_*.
    # NOTE(review): any other load_checkpoint_kind raises KeyError here.
    name_prefix = {"latest": "checkpoint", "best": "best"}[cfg.load_checkpoint_kind]
    checkpoints = Learner.get_checkpoints(Learner.checkpoint_dir(cfg, policy_id), f"{name_prefix}_*")
    checkpoint_dict = Learner.load_checkpoint(checkpoints, device)
    actor_critic.load_state_dict(checkpoint_dict["model"])
    num_frames = 0
    obs, infos = env.reset()
    rnn_states = torch.zeros([env.num_agents, get_rnn_size(cfg)], dtype=torch.float32, device=device)
    # Completed-episode buffers (lists of per-episode arrays) ...
    image_observations = []
    rewards = []
    discrete_actions = []
    # ... and in-progress buffers for the current episode.
    ep_image_observations = []
    ep_rewards = []
    ep_discrete_actions = []
    with torch.no_grad():
        while num_frames < cfg.max_num_frames:
            # Drop the leading batch dimension from the batched env's obs.
            # TODO(review): confirm obs["obs"] is [1, C, H, W] here so the
            # second [0] below yields a HWC-transposable frame.
            obs["obs"] = obs["obs"][0]
            normalized_obs = prepare_and_normalize_obs(actor_critic, obs)
            if not cfg.no_render:
                visualize_policy_inputs(normalized_obs)
            policy_outputs = actor_critic(normalized_obs, rnn_states)
            # sample actions from the distribution by default
            actions = policy_outputs["actions"]
            if cfg.eval_deterministic:
                action_distribution = actor_critic.action_distribution()
                actions = argmax_actions(action_distribution)
            # actions shape should be [num_agents, num_actions] even if it's [1, 1]
            if actions.ndim == 1:
                actions = unsqueeze_tensor(actions, dim=-1)
            actions = preprocess_actions(env_info, actions)
            rnn_states = policy_outputs["new_rnn_states"]
            # store s in buffer (CHW tensor -> HWC numpy -> PIL image)
            ep_image_observations.append(Image.fromarray(np.transpose(obs["obs"][0].cpu().numpy(), (1, 2, 0))))
            obs, rew, terminated, truncated, infos = env.step([actions])
            done = make_dones(terminated, truncated).item()
            # store a, r, d in buffer
            ep_rewards.append(rew.item())
            ep_discrete_actions.append(actions.item())
            num_frames += 1
            if done:  # fictitious done (e.g. life lost), not necessarily end of game
                # Reset the recurrent state at every done signal.
                rnn_states[0] = torch.zeros([get_rnn_size(cfg)], dtype=torch.float32, device=device)
                # Only flush the episode buffers on a real game-over, as
                # reported by the env in infos; life-loss dones keep accumulating.
                if infos[0]["terminated"].item():
                    image_observations.append(ep_image_observations)
                    discrete_actions.append(np.array(ep_discrete_actions).astype(np.int64))
                    rewards.append(np.array(ep_rewards).astype(np.float32))
                    ep_image_observations = []
                    ep_discrete_actions = []
                    ep_rewards = []
                    log.info(f"Episode rewards: {np.sum(rewards[-1]):.3f}")
    env.close()
    # Derive the task name from the env id, e.g. "atari_breakout" -> "breakout".
    # Assumes cfg.env contains at least one underscore.
    task = cfg.env.split("_")[1]
    # Fix task names (see https://huggingface.co/datasets/jat-project/jat-dataset/discussions/21 to 24)
    task = "asteroids" if task == "asteroid" else task
    task = "kungfumaster" if task == "kongfumaster" else task
    task = "montezumarevenge" if task == "montezuma" else task
    task = "privateeye" if task == "privateye" else task
    d = {
        "image_observations": image_observations,
        "discrete_actions": discrete_actions,
        "rewards": rewards,
    }
    features = datasets.Features(
        {
            "image_observations": datasets.Sequence(datasets.Image()),
            "discrete_actions": datasets.Sequence(datasets.Value("int64")),
            "rewards": datasets.Sequence(datasets.Value("float32")),
        }
    )
    # Build one single-row Dataset per episode, then concatenate; presumably
    # this keeps peak memory low versus materializing all episodes at once.
    ds = [
        Dataset.from_dict({k: [v[idx]] for k, v in d.items()}, features=features)
        for idx in range(len(d["image_observations"]))
    ]
    dataset = concatenate_datasets(ds)
    # writer_batch_size=1 flushes each (large, image-heavy) row individually.
    dataset = dataset.train_test_split(test_size=0.1, writer_batch_size=1)
    # Upload to a dedicated branch of the shared dataset repo.
    HfApi().create_branch("jat-project/jat-dataset", branch="new_breakout", exist_ok=True, repo_type="dataset")
    dataset.push_to_hub(
        "jat-project/jat-dataset",
        config_name=f"atari-{task}",
        branch="new_breakout",
    )