def plot_obs_rew()

in eval-vis-model.py [0:0]


    def plot_obs_rew(self, true_xs, pred_xs, true_rews, pred_rews, pred_dones, fname):
        """Plot state and reward rollouts on a subplot grid and save the figure.

        One axis is used per state dimension; the last axis of the grid is
        reserved for rewards (and, when ``self.args.show_dones`` is set, the
        predicted done signals on a twin y-axis). State dimensions that do not
        fit in the remaining axes are silently skipped.

        The grid shape and figure size are chosen from ``self.domain_name``,
        unless ``self.args.vid_mode == 'highlight'``, which uses a small fixed
        grid and omits the axis labels and the dashed zero line.

        Args:
            true_xs: Ground-truth states indexed as ``true_xs[:, i]``, or
                ``None`` to skip the ground-truth curves.
            pred_xs: Non-empty sequence of predicted state rollouts. Each
                element must expose a 2-tuple ``.shape`` read as
                ``(horizon + 1, state_dim)`` and be accepted by
                ``utils.to_np``.
            true_rews: Ground-truth rewards, or ``None`` to skip them.
            pred_rews: Sequence of predicted reward rollouts (each passed
                through ``utils.to_np``).
            pred_dones: Sequence of predicted done rollouts; only read when
                ``self.args.show_dones`` is truthy.
            fname: Destination passed to ``fig.savefig``.

        The figure is closed before returning, including when plotting or
        saving raises, so repeated calls do not accumulate open figures.
        """
        domain_name = self.domain_name
        # Per-dimension y-limits: state i is drawn in [-max_obs[i], max_obs[i]].
        # NOTE(review): presumably self.max_obs is a numeric array with one
        # entry per state dimension -- confirm against where it is set.
        bounds = (-self.max_obs, self.max_obs)
        reward_bounds = self.reward_bounds
        highlight = self.args.vid_mode == 'highlight'

        # Pick (rows, cols, figsize). Order matters: the exact 'mbpo_*' names
        # must be tested before the generic 'humanoid' substring match.
        # NOTE(review): only the cheetah/humanoid checks are case-insensitive;
        # kept as in the original to avoid changing which layout a name gets.
        if highlight:
            nrows, ncols, figsize = 3, 4, (4, 3)
        elif 'cheetah' in domain_name.lower():
            nrows, ncols, figsize = 3, 6, (14, 10)
        elif 'walker' in domain_name:
            nrows, ncols, figsize = 5, 5, (14, 10)
        elif domain_name == 'mbpo_humanoid':
            nrows, ncols, figsize = 6, 8, (16, 10)
        elif domain_name == 'mbpo_ant':
            nrows, ncols, figsize = 5, 6, (14, 10)
        elif 'humanoid' in domain_name.lower():
            nrows, ncols, figsize = 8, 9, (16, 10)
        elif 'pendulum' in domain_name:
            nrows, ncols, figsize = 4, 1, (6, 10)
        elif 'hopper' in domain_name:
            nrows, ncols, figsize = 4, 3, (14, 10)
        elif 'swimmer' in domain_name:
            nrows, ncols, figsize = 3, 3, (10, 10)
        else:
            nrows, ncols, figsize = 5, 5, (14, 10)

        fig, axs = plt.subplots(
            nrows, ncols, figsize=figsize,
            gridspec_kw={'wspace': 0, 'hspace': 0},
        )

        # Everything after figure creation runs under try/finally so the
        # figure is released from pyplot's registry even if plotting fails.
        try:
            axs = axs.ravel()
            if not highlight:
                add_label(axs[0], 'States', fontsize=20)
            for ax in axs:
                # Hide tick labels on every panel; keep the frame visible.
                ax.get_xaxis().set_ticklabels([])
                ax.get_yaxis().set_ticklabels([])
                ax.patch.set_edgecolor('black')

            horizon_p1, state_dim = pred_xs[0].shape
            horizon = horizon_p1 - 1

            # The last axis is reserved for rewards, so at most len(axs) - 1
            # state dimensions can be drawn; any remaining ones are dropped.
            for i in range(min(state_dim, len(axs) - 1)):
                ax = axs[i]
                if true_xs is not None:
                    ax.plot(true_xs[:, i], color='k', label='Ground Truth')

                # The first rollout takes the next colour from the cycle; all
                # other rollouts in this panel reuse that same colour.
                color = None
                for pred in pred_xs:
                    line, = ax.plot(utils.to_np(pred[:, i]), alpha=1., color=color)
                    color = line.get_color()
                ax.set_ylim(bounds[0][i], bounds[1][i])
                ax.set_xlim(0, horizon)

                if not highlight:
                    # Dashed reference line at y = 0.
                    ax.axhline(color='k', linestyle='--', alpha=0.4)

            rew_ax = axs[-1]
            if true_rews is not None:
                rew_ax.plot(true_rews, alpha=0.5, color='k')

            # Predicted rewards always use the second colour of the cycle.
            color = plt.rcParams['axes.prop_cycle'].by_key()['color'][1]
            for pred_rew in pred_rews:
                rew_ax.plot(utils.to_np(pred_rew), alpha=1., color=color)
            rew_ax.set_ylim(*reward_bounds)
            rew_ax.set_xlim(0, horizon)

            rew_ax.get_xaxis().set_ticklabels([])
            rew_ax.get_yaxis().set_ticklabels([])
            if not highlight:
                add_label(rew_ax, 'Rewards', fontsize=20)

            if self.args.show_dones:
                # Done signals share the reward panel on a secondary y-axis.
                done_ax = rew_ax.twinx()
                for pred_done in pred_dones:
                    done_ax.plot(utils.to_np(pred_done), alpha=1.)
                done_ax.set_ylim(-0.1, 1.1)
                done_ax.set_xlim(0, horizon)
                done_ax.get_xaxis().set_ticklabels([])
                done_ax.get_yaxis().set_ticklabels([])

            fig.tight_layout()
            fig.savefig(fname)
        finally:
            plt.close(fig)