# interaction_exploration/viz_trainer.py
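# NOTE: assumed module-level imports for this excerpt (they live at the top of
# the file, outside this method's class):
#   import numpy as np
#   import torch
#   import matplotlib.pyplot as plt
#   from PIL import Image, ImageDraw
#   from torchvision.utils import make_grid
#   plus repo/habitat-baselines helpers: logger, construct_envs, get_env_class, util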
def enjoy(self):
self.init_viz()
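# sample 10 evaluation episodes: AI2-THOR kitchen scenes (FloorPlan1-30), each
# paired with a random integer that appears to serve as the episode seed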
test_episodes = [(f'FloorPlan{np.random.randint(1, 31)}', np.random.randint(10000)) for _ in range(10)]
self.config.defrost()
self.config.ENV.TEST_EPISODES = test_episodes
self.config.ENV.TEST_EPISODE_COUNT = len(test_episodes)
self.config.NUM_PROCESSES = 1
self.config.MODE = 'eval'
self.config.freeze()
checkpoint_path = self.config.LOAD
ckpt_dict = self.load_checkpoint(checkpoint_path, map_location="cpu")
ppo_cfg = self.config.RL.PPO
logger.info(f"env config: {self.config}")
# choose the visualization env wrapper that matches the configured environment class
env_name = 'ThorEnjoyVanilla'
if self.config.ENV.ENV_NAME in ['ThorObjectCoverage-v0']:
env_name = 'ThorEnjoyCycler'
elif self.config.ENV.ENV_NAME in ['ThorNavigationNovelty-v0']:
env_name = 'ThorEnjoyCyclerFixedView'
self.envs = construct_envs(self.config, get_env_class(env_name))
self._setup_actor_critic_agent(ppo_cfg)
logger.info(checkpoint_path)
logger.info(f"num_steps: {self.config.ENV.NUM_STEPS}")
self.agent.load_state_dict(ckpt_dict["state_dict"])
self.actor_critic = self.agent.actor_critic
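# initialize rollout state: per-env reward accumulators, recurrent hidden
# states, previous actions, and not-done masks, mirroring habitat-baselines'
# PPO evaluation loop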
observations = self.envs.reset()
batch = self.batch_obs(observations, self.device)
current_episode_reward = torch.zeros(
self.envs.num_envs, 1, device=self.device
)
test_recurrent_hidden_states = torch.zeros(
self.actor_critic.net.num_recurrent_layers,
self.config.NUM_PROCESSES,
ppo_cfg.hidden_size,
device=self.device,
)
prev_actions = torch.zeros(
self.config.NUM_PROCESSES, 1, device=self.device, dtype=torch.long
)
not_done_masks = torch.zeros(
self.config.NUM_PROCESSES, 1, device=self.device
)
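# not_done_masks is 0.0 at episode boundaries; the policy uses it to reset its
# recurrent state between episodes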
stats_episodes = dict()
self.actor_critic.eval()
while True:
infos = None
last_pt = None
self.reset_fig()
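# visualization loop: each step renders the egocentric frame beside an
# annotated top-down map for env 0, then samples and executes a policy action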
for step in range(ppo_cfg.num_steps):
# ------------------------------------------- #
action = 'init' if infos is None else infos[0]['action']
ep_reward = current_episode_reward[0].item()
print(f'step: {step} | R: {ep_reward:.2f} | {action}')
viz_data = self.envs.call_at(0, 'get_viz_data')
frame = torch.from_numpy(viz_data['frame']).float().permute(2, 0, 1) / 255
frame = self.add_rectangle(frame)
topdown = viz_data['topdown']
if action != 'init':
# flip y: the 300x300 top-down map and the matplotlib axes use opposite y origins
traj_pt = list(viz_data['pts'][0])
traj_pt[1] = 300 - traj_pt[1]
# clamp the interaction point to the map with a 5 px margin, flipping y as above
int_pt = (max(min(viz_data['pts'][-1][0], 295), 5), 300 - max(min(viz_data['pts'][-1][1], 295), 5))
if last_pt is not None:
# draw the latest trajectory segment, colored by progress through the episode
self.ax.plot((last_pt[0], traj_pt[0]), (last_pt[1], traj_pt[1]), color=self.cmap(step / ppo_cfg.num_steps), lw=2)
last_pt = traj_pt
if viz_data['action'] in self.interactions:
# mark interactions that earned reward in green, the rest as faint yellow dots
if viz_data['reward'] > 0:
plt.plot(int_pt[0], int_pt[1], marker='o', color='Lime', alpha=0.8)
else:
plt.plot(int_pt[0], int_pt[1], marker='o', color='yellow', alpha=0.05)
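# rasterize the matplotlib canvas to an RGBA buffer and composite the
# annotations (trajectory, interaction markers, and what appears to be the
# agent's view triangle from pts[:3]) onto the top-down map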
self.canvas.draw()
s, (width, height) = self.canvas.print_to_buffer()
# decode the RGBA canvas buffer into a PIL image for annotation
annots = Image.frombytes("RGBA", (width, height), s)
draw = ImageDraw.Draw(annots)
draw.polygon(viz_data['pts'][:3], fill=(0, 255, 255, 64))
topdown = Image.fromarray(topdown, "RGB").convert("RGBA")
topdown = Image.alpha_composite(topdown, annots)
topdown = np.array(topdown.convert("RGB"))
topdown = torch.from_numpy(topdown).float().permute(2, 0, 1)/255
grid = make_grid([frame, topdown], nrow=2)
util.show_wait(grid, T=1)
# ------------------------------------------- #
current_episodes = self.envs.current_episodes()
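# sample the next action from the policy (deterministic=False draws from the
# action distribution instead of taking the argmax)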
with torch.no_grad():
(
_,
actions,
_,
test_recurrent_hidden_states,
) = self.actor_critic.act(
batch,
test_recurrent_hidden_states,
prev_actions,
not_done_masks,
deterministic=False,
)
prev_actions.copy_(actions)
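# step every env with its sampled action; each returns (obs, reward, done, info)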
outputs = self.envs.step([a[0].item() for a in actions])
observations, rewards, dones, infos = [
list(x) for x in zip(*outputs)
]
batch = self.batch_obs(observations, self.device)
not_done_masks = torch.tensor(
[[0.0] if done else [1.0] for done in dones],
dtype=torch.float,
device=self.device,
)
rewards = torch.tensor(
rewards, dtype=torch.float, device=self.device
).unsqueeze(1)
current_episode_reward += rewards
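# whenever an env finishes an episode, record its stats and zero its reward accumulator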
n_envs = self.envs.num_envs
for i in range(n_envs):
# episode ended
if not_done_masks[i].item() == 0:
episode_stats = dict()
episode_stats["reward"] = current_episode_reward[i].item()
episode_stats.update(
self._extract_scalars_from_info(infos[i])
)
current_episode_reward[i] = 0
stats_episodes[
(
current_episodes[i]['scene_id'],
current_episodes[i]['episode_id'],
)
] = episode_stats
# Log info so far
num_episodes = len(stats_episodes)
aggregated_stats = dict()
for stat_key in next(iter(stats_episodes.values())).keys():
aggregated_stats[stat_key] = (
sum([v[stat_key] for v in stats_episodes.values()])
/ num_episodes
)
for k, v in aggregated_stats.items():
logger.info(f"Average episode {k}: {v:.4f} ({num_episodes} episodes)")