def run()

in mbrl/diagnostics/planet_visualizer.py [0:0]


    def run(self):
        """Compare a real rollout against the model's prediction and save a GIF.

        The agent is rolled out in ``self.env`` for up to
        ``self.start_step + self.lookahead`` steps. Observations, actions and
        rewards are recorded from ``self.start_step`` onwards. The recorded
        actions are then replayed through ``self.model_env`` starting from the
        first recorded observation, and each resulting latent is rendered with
        ``self.model.render``. Both total rewards are printed, and the
        predicted (left) and true (right) frames are written side by side to
        ``visualization_{start_step}_{lookahead}_{seed}.gif`` in
        ``self.vis_dir``. Intermediate ``frame_{idx}.png`` files are removed
        afterwards, even if rendering or GIF writing fails.

        If the episode terminates inside the lookahead window, the GIF simply
        contains the frames that were collected instead of failing.

        Raises:
            RuntimeError: if the episode terminates before ``self.start_step``
                is reached, so there is nothing to visualize.
        """
        true_obs = []
        actions = []
        true_total_reward = 0.0
        obs = self.env.reset()
        self.agent.reset()

        for step in range(self.start_step + self.lookahead):
            action = self.agent.act(obs)
            next_obs, reward, done, _ = self.env.step(action)
            if step >= self.start_step:
                # Record the observation the action was chosen from, so that
                # true_obs[i] is the state *before* actions[i] is applied.
                true_obs.append(obs)
                actions.append(action)
                true_total_reward += reward
            obs = next_obs
            if done:
                break

        if not true_obs:
            raise RuntimeError(
                f"Episode terminated before start_step={self.start_step}; "
                "no observations were recorded to visualize."
            )

        # Now check what the model thinks will happen with the same sequence of actions
        cur_obs = true_obs[0].copy()
        pred_total_reward = 0.0
        # [None, :] adds a leading axis of size 1 (presumably the batch
        # dimension expected by model_env -- TODO confirm).
        latent = self.model_env.reset(cur_obs[None, :], return_as_np=False)
        pred_obs = [self.model.render(latent)[0]]
        for a in actions:
            latent, reward, *_ = self.model_env.step(a.copy()[None, :])
            pred_obs.append(self.model.render(latent)[0])
            pred_total_reward += reward.item()

        print(
            f"True total reward: {true_total_reward}. Predicted total reward: {pred_total_reward}"
        )

        filenames = []
        try:
            # pred_obs has one more entry than true_obs (the render after the
            # last action); zip stops at the shorter list, so exactly one frame
            # is produced per recorded true observation.
            frame_pairs = zip(pred_obs, true_obs)
            for idx, (pred_frame, true_frame) in enumerate(frame_pairs):
                fname = self.vis_dir / f"frame_{idx}.png"
                fig, axs = plt.subplots(1, 2, figsize=(12, 6))
                try:
                    axs[0].imshow(pred_frame.astype(np.uint8))
                    # Move axis 0 to the end for imshow (presumably
                    # channel-first -> channel-last -- TODO confirm).
                    axs[1].imshow(true_frame.transpose(1, 2, 0))

                    # save frame; register the name first so a partially
                    # written file is still cleaned up below
                    filenames.append(fname)
                    fig.savefig(fname)
                finally:
                    # Close this specific figure so figures do not accumulate.
                    plt.close(fig)

            with imageio.get_writer(
                self.vis_dir
                / f"visualization_{self.start_step}_{self.lookahead}_{self.seed}.gif",
                mode="I",
            ) as writer:
                for filename in filenames:
                    image = imageio.imread(filename)
                    writer.append_data(image)
        finally:
            # Remove the temporary frame files, whether or not the GIF was written.
            for filename in filenames:
                try:
                    os.remove(filename)
                except FileNotFoundError:
                    # savefig failed before creating the file; nothing to remove.
                    pass