# eval-vis-model.py
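# Assumed imports for this excerpt (not shown in the snippet); `utils`
# refers to the project's local helper module used throughout below.
import os
from multiprocessing import Process

import matplotlib.pyplot as plt
import numpy as np
import torch

import utils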
def run_episode(self, seed, create_vid):
    """Run one evaluation episode from `seed`, optionally saving
    per-step visualization frames and assembling them into a video."""
    if create_vid:
        episode_dir = f'{self.eval_dir}/{seed:02d}'
        os.makedirs(episode_dir, exist_ok=True)

    self.env.set_seed(seed)
    obs = self.env.reset()

    domain_name = self.domain_name
    horizon = self.exp.agent.horizon
    device = 'cuda'

    done = False
    total_reward = 0.
    reward = 0.

    args = self.args
    env = self.env
    exp = self.exp
    replay_buffer = self.exp.replay_buffer

    step = 0
    ps = []
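    # Main episode loop: plan an action sequence from the current
    # observation, optionally save visualization frames, and step the
    # real environment with the first planned action. `ps` collects the
    # asynchronous plotting processes spawned for the video frames.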
    while not done:
        if create_vid:
            # The quadruped domain needs a different camera to keep the
            # robot in frame.
            camera_id = 2 if 'quadruped' in domain_name else 0
            frame = env.render(
                mode='rgb_array',
                height=256,
                width=256,
                camera_id=camera_id,
            )
            env_fname = f'{episode_dir}/env_{step:04d}.png'
            plt.imsave(env_fname, frame)
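        # Normalize the observation with the replay buffer's running
        # statistics so it matches the scale the models were trained on.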
        if self.exp.cfg.normalize_obs:
            mu, sigma = replay_buffer.get_obs_stats()
            obs = (obs - mu) / sigma

        obs = torch.FloatTensor(obs).to(device)
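        # Plan an action sequence in one of three modes: the policy's mean
        # action, a sample from the policy, or the agent's sampling-based
        # controller over the learned dynamics model `dx`.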
        if args.mode == 'mean':
            action_seq, _, _ = exp.agent.dx.unroll_policy(
                obs.unsqueeze(0), exp.agent.actor,
                sample=False, last_u=True)
        elif args.mode == 'sample':
            action_seq, _, _ = exp.agent.dx.unroll_policy(
                obs.unsqueeze(0), exp.agent.actor,
                sample=True, last_u=True)
        elif args.mode == 'ctrl':
            exp.agent.ctrl.num_samples = 100  # TODO
            action_seq = exp.agent.ctrl.forward(
                exp.agent.dx,
                exp.agent.actor,
                obs,
                exp.agent.critic,
                return_seq=True,
            )
            action_seq = action_seq[:-1]
        else:
            raise ValueError(f'Unrecognized mode: {args.mode}')

        if action_seq.ndimension() == 3:
            action_seq = action_seq.squeeze(dim=1)

        # Execute only the first action of the planned sequence, clamped
        # to the environment's action bounds.
        action = action_seq[0]
        action = action.clamp(min=env.action_space.low.min(),
                              max=env.action_space.high.max())
        if action.ndimension() == 1:
            # TODO: This is messy, shouldn't be so sensitive to the dim here.
            action = action.unsqueeze(0)
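        # For the video, compare the model's predicted rollout against a
        # ground-truth rollout obtained by stepping the (frozen) real
        # environment with the same planned action sequence.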
        if create_vid:
            def get_nominal_states(obs, actions):
                # Unroll the learned dynamics model from `obs` along the
                # planned `actions`, prepending the initial observation.
                assert obs.ndimension() == 1
                assert actions.ndimension() == 2
                obs = obs.unsqueeze(0)
                pred_obs = exp.agent.dx.unroll(obs, actions.unsqueeze(1)).squeeze(1)
                pred_obs = torch.cat((obs, pred_obs), dim=0)
                return pred_obs
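            # The ground-truth rollout is only collected when at least
            # `horizon` steps remain in the episode. `freeze` snapshots
            # the environment state and restores it afterwards, so the
            # lookahead steps don't affect the real episode.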
            # if env._max_episode_steps - env._elapsed_steps > exp.agent.horizon:
            if env._max_episode_steps - step > exp.agent.horizon:
                true_xs = [obs.cpu()]
                true_rews = [reward]
                if 'gym' in domain_name:
                    freeze = utils.freeze_mbbl_env
                elif domain_name == 'Humanoid-v2' or 'mbpo' in domain_name:
                    freeze = utils.freeze_gym_env
                else:
                    freeze = utils.freeze_env
                with freeze(env):
                    for t in range(horizon):
                        # Use a throwaway done flag here so the lookahead
                        # doesn't clobber the episode's `done`.
                        xt, rt, _done, _ = env.step(utils.to_np(action_seq[t]))
                        if self.exp.cfg.normalize_obs:
                            mu, sigma = replay_buffer.get_obs_stats()
                            xt = (xt - mu) / sigma
                        true_xs.append(xt)
                        true_rews.append(rt)
                true_xs = np.stack(true_xs)
                true_rews = np.stack(true_rews)

                # Expand the plotting bounds (with 10% slack) whenever the
                # true trajectory exceeds them.
                max_obs = torch.from_numpy(true_xs).abs().max(axis=0).values.float().detach()
                I = max_obs > self.max_obs
                self.max_obs[I] = 1.1 * max_obs[I]
                if true_rews.min() < self.reward_bounds[0]:
                    self.reward_bounds[0] = 1.1 * true_rews.min().item()
                if true_rews.max() > self.reward_bounds[1]:
                    self.reward_bounds[1] = 1.1 * true_rews.max().item()
            else:
                true_xs = true_rews = None
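            # Predict the same quantities with the learned models: nominal
            # states from the dynamics model `dx`, rewards from `rew`, and
            # termination probabilities from the `done` classifier.
            # `n_sample` is left at 1 here since the nominal rollout below
            # is not resampled.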
            n_sample = 1
            pred_xs = []
            pred_rews = []
            pred_dones = []
            for _ in range(n_sample):
                pred_x = get_nominal_states(obs.squeeze(), action_seq[:-1])
                max_obs = pred_x.abs().max(axis=0).values.cpu().detach()
                I = max_obs > self.max_obs
                self.max_obs[I] = 1.1 * max_obs[I]

                xu = torch.cat((pred_x, action_seq), dim=-1)
                # xu = pred_x
                pred_rew = exp.agent.rew(xu)
                if pred_rew.min() < self.reward_bounds[0]:
                    self.reward_bounds[0] = 1.1 * pred_rew.min().item()
                if pred_rew.max() > self.reward_bounds[1]:
                    self.reward_bounds[1] = 1.1 * pred_rew.max().item()
                pred_done = exp.agent.done(xu).sigmoid()

                pred_xs.append(pred_x.squeeze())
                pred_rews.append(pred_rew.squeeze())
                pred_dones.append(pred_done.squeeze())

            # Move everything to the CPU for the plotting processes.
            pred_xs = [x.cpu() for x in pred_xs]
            pred_rews = [x.cpu() for x in pred_rews]
            pred_dones = [x.cpu() for x in pred_dones]
            action_seq = action_seq.cpu()
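            # Render the prediction and control plots in a closure so it
            # can run in a background process; ImageMagick's `convert` is
            # then used to trim, resize, and tile the images into a single
            # frame per step.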
            def f():
                preds_fname = f'{episode_dir}/preds_{step:04d}.png'
                self.plot_obs_rew(
                    true_xs, pred_xs, true_rews, pred_rews, pred_dones, preds_fname)

                ctrl_fname = f'{episode_dir}/ctrl_{step:04d}.png'
                self.plot_ctrl(action_seq, fname=ctrl_fname)

                fname = f'{episode_dir}/{step:04d}.png'
                os.system(f'convert {preds_fname} -trim {preds_fname}')
                os.system(f'convert {ctrl_fname} -trim {ctrl_fname}')
                if self.args.vid_mode == 'highlight':
                    # os.system('convert -gravity west -append '
                    #           f'{preds_fname} {ctrl_fname} {fname}')
                    os.system(f'convert {preds_fname} -resize x300 {fname}')
                    os.system(f'convert {env_fname} -resize 300x300! {env_fname}')
                    os.system(f'convert +append {env_fname} {fname} -resize x300 {fname}')
                    # os.system(f'convert {fname} -resize 1328x150! {fname}')
                elif 'pendulum' in domain_name:
                    os.system('convert -gravity center -append '
                              f'{preds_fname} {ctrl_fname} {fname}')
                    os.system(f'convert {fname} -resize x700 {fname}')
                    os.system(f'convert -gravity center {env_fname} -resize x700 {env_fname}')
                    os.system('convert -gravity center +append '
                              f'{env_fname} {fname} {fname}')
                else:
                    os.system('convert -gravity center +append -resize x700 '
                              f'{env_fname} {preds_fname} {fname}')
                    os.system('convert -gravity center -append -resize 1200x '
                              f'{fname} {ctrl_fname} {fname}')

            if self.args.no_mp:
                f()
            else:
                # Plot asynchronously so rendering doesn't block the episode.
                p = Process(target=f)
                p.start()
                ps.append(p)
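        # Step the real environment with the (clamped) first planned action.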
        obs, reward, done, _ = env.step(utils.to_np(action.squeeze(0)))
        total_reward += reward
        print(
            f'--- Step {step} -- Total Rew: {total_reward:.2f} -- Step Rew: {reward:.2f}'
        )
        step += 1
        if args.n_steps is not None and step > args.n_steps:
            done = True
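    # Wait for all background plotting processes, then stitch the
    # per-step frames into a video with ffmpeg.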
    if create_vid:
        for p in ps:
            p.join()
        os.system(
            f'ffmpeg -y -framerate {self.args.framerate} '
            f'-i {episode_dir}/%04d.png -q 3 {episode_dir}/vid.mp4'
        )

    return total_reward