def run_episode()

in eval-vis-model.py


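    # Module-level dependencies assumed by this excerpt (not shown here):
    # os, numpy as np, torch, matplotlib.pyplot as plt,
    # multiprocessing.Process, and the project's utils module.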
    def run_episode(self, seed, create_vid):
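        """Roll out one evaluation episode for the given seed.

        When create_vid is set, each step renders the environment frame and
        plots of model predictions vs. the true rollout; the frames are
        stitched into a video at the end. Returns the episode's total reward.
        """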
        if create_vid:
            episode_dir = f'{self.eval_dir}/{seed:02d}'
            os.makedirs(episode_dir, exist_ok=True)

        self.env.set_seed(seed)
        obs = self.env.reset()
        domain_name = self.domain_name
        horizon = self.exp.agent.horizon

        device = 'cuda'
        done = False
        total_reward = 0.
        reward = 0.

        args = self.args
        env = self.env
        exp = self.exp
        replay_buffer = self.exp.replay_buffer
        step = 0

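        # Background workers that render per-step plots; joined at episode end.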
        ps = []
        while not done:
            if create_vid:
                if 'quadruped' in domain_name:
                    camera_id = 2
                else:
                    camera_id = 0
                frame = env.render(
                    mode='rgb_array',
                    height=256,
                    width=256,
                    camera_id=camera_id,
                )
                env_fname = f'{episode_dir}/env_{step:04d}.png'
                plt.imsave(env_fname, frame)

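            # Normalize the observation with the replay buffer's running statistics.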
            if self.exp.cfg.normalize_obs:
                mu, sigma = replay_buffer.get_obs_stats()
                obs = (obs - mu) / sigma
            obs = torch.FloatTensor(obs).to(device)

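            # Action selection: 'mean'/'sample' unroll the policy through the
            # learned dynamics model; 'ctrl' plans with the sampling-based controller.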
            if args.mode == 'mean':
                action_seq, _, _ = self.exp.agent.dx.unroll_policy(
                    obs.unsqueeze(0), exp.agent.actor,
                    sample=False, last_u=True)
            elif args.mode == 'sample':
                action_seq, _, _ = self.exp.agent.dx.unroll_policy(
                    obs.unsqueeze(0), exp.agent.actor,
                    sample=True, last_u=True)
            elif args.mode == 'ctrl':
                exp.agent.ctrl.num_samples = 100  # TODO: don't hard-code the sample count
                action_seq = self.exp.agent.ctrl.forward(
                    exp.agent.dx,
                    exp.agent.actor,
                    obs,
                    exp.agent.critic,
                    return_seq=True,
                )
                action_seq = action_seq[:-1]
            else:
                raise ValueError(f'unknown action-selection mode: {args.mode}')

            if action_seq.ndimension() == 3:
                action_seq = action_seq.squeeze(dim=1)

            action = action_seq[0]
            action = action.clamp(min=env.action_space.low.min(),
                                  max=env.action_space.high.max())
            if action.ndimension() == 1:
                # TODO: This is messy, shouldn't be so sensitive to the dim here.
                action = action.unsqueeze(0)

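            # Visualization: compare the model's predicted rollout against a
            # ground-truth rollout of the same action sequence.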
            if create_vid:
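                # Unroll the learned dynamics model dx from obs along the planned
                # actions; returns the initial state plus the predicted trajectory.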
                def get_nominal_states(obs, actions):
                    assert obs.ndimension() == 1
                    assert actions.ndimension() == 2
                    obs = obs.unsqueeze(0)
                    pred_obs = exp.agent.dx.unroll(obs, actions.unsqueeze(1)).squeeze(1)
                    pred_obs = torch.cat((obs, pred_obs), dim=0)
                    return pred_obs

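                # Ground-truth rollout: step a frozen copy of the environment
                # through the planned action sequence, tracking observation and
                # reward ranges for consistent plot axes.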
                if env._max_episode_steps - step > exp.agent.horizon:
                    true_xs = [obs.cpu()]
                    true_rews = [reward]
                    if 'gym' in domain_name:
                        freeze = utils.freeze_mbbl_env
                    elif domain_name == 'Humanoid-v2' or 'mbpo' in domain_name:
                        freeze = utils.freeze_gym_env
                    else:
                        freeze = utils.freeze_env
                    with freeze(env):
                        for t in range(horizon):
                            xt, rt, done, _ = env.step(utils.to_np(action_seq[t]))
                            if self.exp.cfg.normalize_obs:
                                mu, sigma = replay_buffer.get_obs_stats()
                                xt = (xt - mu) / sigma
                            true_xs.append(xt)
                            true_rews.append(rt)
                    true_xs = np.stack(true_xs)
                    true_rews = np.stack(true_rews)

                    max_obs = torch.from_numpy(true_xs).abs().max(dim=0).values.float().detach()
                    I = max_obs > self.max_obs
                    self.max_obs[I] = 1.1*max_obs[I]
                    if true_rews.min() < self.reward_bounds[0]:
                        self.reward_bounds[0] = 1.1*true_rews.min().item()
                    if true_rews.max() > self.reward_bounds[1]:
                        self.reward_bounds[1] = 1.1*true_rews.max().item()
                else:
                    true_xs = true_rews = None

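                # Model predictions: roll out the dynamics model along the planned
                # actions and score them with the learned reward and done heads.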
                n_sample = 1
                pred_xs = []
                pred_rews = []
                pred_dones = []
                for _ in range(n_sample):
                    pred_x = get_nominal_states(obs.squeeze(), action_seq[:-1])
                    max_obs = pred_x.abs().max(dim=0).values.cpu().detach()
                    I = max_obs > self.max_obs
                    self.max_obs[I] = 1.1*max_obs[I]
                    xu = torch.cat((pred_x, action_seq), dim=-1)
                    pred_rew = exp.agent.rew(xu)
                    if pred_rew.min() < self.reward_bounds[0]:
                        self.reward_bounds[0] = 1.1*pred_rew.min().item()
                    if pred_rew.max() > self.reward_bounds[1]:
                        self.reward_bounds[1] = 1.1*pred_rew.max().item()

                    pred_done = exp.agent.done(xu).sigmoid()

                    pred_xs.append(pred_x.squeeze())
                    pred_rews.append(pred_rew.squeeze())
                    pred_dones.append(pred_done.squeeze())

                pred_xs = [x.cpu() for x in pred_xs]
                pred_rews = [x.cpu() for x in pred_rews]
                pred_dones = [x.cpu() for x in pred_dones]
                action_seq = action_seq.cpu()

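                # Render the per-step figures and stitch them with ImageMagick's
                # convert; note: passing a nested function to Process relies on
                # the fork start method so the closure's locals are inherited.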
                def render_step():
                    preds_fname = f'{episode_dir}/preds_{step:04d}.png'
                    self.plot_obs_rew(
                        true_xs, pred_xs, true_rews, pred_rews, pred_dones, preds_fname)

                    ctrl_fname = f'{episode_dir}/ctrl_{step:04d}.png'
                    self.plot_ctrl(action_seq, fname=ctrl_fname)

                    fname = f'{episode_dir}/{step:04d}.png'
                    os.system(f'convert {preds_fname} -trim {preds_fname}')
                    os.system(f'convert {ctrl_fname} -trim {ctrl_fname}')
                    if self.args.vid_mode == 'highlight':
                        os.system(f'convert {preds_fname} -resize x300 {fname}')
                        os.system(f'convert {env_fname} -resize 300x300! {env_fname}')
                        os.system(f'convert +append {env_fname} {fname} -resize x300 {fname}')
                    elif 'pendulum' in domain_name:
                        os.system('convert -gravity center -append '
                                f'{preds_fname} {ctrl_fname} {fname}')
                        os.system(f'convert {fname} -resize x700 {fname}')
                        os.system(f'convert -gravity center {env_fname} -resize x700 {env_fname}')
                        os.system('convert -gravity center +append '
                                f'{env_fname} {fname} {fname}')
                    else:
                        os.system('convert -gravity center +append -resize x700 '
                                f'{env_fname} {preds_fname} {fname}')
                        os.system('convert -gravity center -append -resize 1200x '
                                f'{fname} {ctrl_fname} {fname}')

                if self.args.no_mp:
                    render_step()
                else:
                    p = Process(target=render_step)
                    p.start()
                    ps.append(p)

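            # Receding-horizon control: execute only the first planned action.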
            obs, reward, done, _ = env.step(utils.to_np(action.squeeze(0)))
            total_reward += reward
            print(
                f'--- Step {step} -- Total Rew: {total_reward:.2f} -- Step Rew: {reward:.2f}'
            )
            step += 1
            if args.n_steps is not None and step > args.n_steps:
                done = True

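        # Wait for the plot workers, then encode the frames into a video.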
        if create_vid:
            for p in ps:
                p.join()

            os.system(
                f'ffmpeg -y -framerate {self.args.framerate} -i {episode_dir}/%04d.png -q 3 {episode_dir}/vid.mp4'
            )

        return total_reward
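
Usage sketch (hypothetical; the enclosing class, here called `Evaluator`, and
its constructor are assumptions, not shown in this excerpt):

    # `Evaluator` stands in for whatever class defines run_episode above.
    evaluator = Evaluator(args)
    rewards = [evaluator.run_episode(seed, create_vid=(seed == 0))
               for seed in range(5)]
    print(f'mean reward: {sum(rewards) / len(rewards):.2f}')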