def plot_exp()

in svg/analysis.py [0:0]


def plot_exp(root, print_cfg=False, print_overrides=True, Qmax=None,
             obsmax=None, suptitle=None, save=None,
             plot_rew=True, N_smooth=200, N_downsample=200,
             smooth_train_rew=True):
    """Plot a 2x3 grid of training diagnostics for one run directory.

    Reads the Hydra config from ``{root}/.hydra/config.yaml`` and the training
    log from ``{root}/train.csv``. It then draws six panels:

    * actor loss;
    * critic loss, with the critic recon loss on a twin axis when logged;
    * model obs loss, with the model reward loss on a twin axis when logged;
    * alpha value;
    * actor entropy vs. target entropy;
    * train and eval episode reward.

    Args:
        root: Run directory containing ``.hydra/`` and ``train.csv``.
        print_cfg: If true, pretty-print the full Hydra config.
        print_overrides: If true, load and pretty-print
            ``{root}/.hydra/overrides.yaml``.
        Qmax: Upper y-limit for the critic-loss panel (``None`` = automatic).
        obsmax: Upper y-limit for the obs-loss panel (``None`` = automatic).
        suptitle: Figure title; defaults to ``"<root>: <env_name>"``.
        save: If truthy, path the figure is saved to (transparent background).
            Afterwards, ImageMagick ``convert -trim`` is run on the saved file
            as a best-effort step.
        plot_rew: Whether to overlay the model reward loss on the obs panel.
        N_smooth: Unused; kept only for backward compatibility with callers.
        N_downsample: Number of evenly spaced steps each curve is resampled to.
        smooth_train_rew: If true, resample the training reward curve like the
            other curves; otherwise plot every logged point.

    Returns:
        ``(fig, axs)`` where ``axs`` is the flattened array of the 6 axes.
    """
    import subprocess
    import warnings

    config = OmegaConf.load(f'{root}/.hydra/config.yaml')
    train_df = pd.read_csv(f'{root}/train.csv')

    def get_smooth(key):
        # Linearly interpolate column `key` onto N_downsample evenly spaced
        # steps. This thins the curve for plotting; it does not smooth it.
        it, v = train_df.step, train_df[key]
        _it = np.linspace(it.min(), it.max(), num=N_downsample)
        _v = sp.interpolate.interp1d(it, v)(_it)
        return _it, _v

    def sci_x(ax):
        # Use scientific notation for the step axis.
        ax.ticklabel_format(style='sci', axis='x', scilimits=(0, 0))

    nrow, ncol = 2, 3
    fig, axs = plt.subplots(nrow, ncol, figsize=(6*ncol, 4*nrow))
    axs = axs.reshape(-1)

    # Panel 0: actor loss.
    ax = axs[0]
    ax.plot(*get_smooth('actor_loss'), label='Total')
    ax.set_title('Actor Loss')
    sci_x(ax)

    # Panel 1: critic loss. Prefer 'critic_Q_loss' and fall back to
    # 'critic_loss'; the recon-loss overlay only applies to the former.
    critic_key = next(
        (k for k in ('critic_Q_loss', 'critic_loss') if k in train_df), None)
    if critic_key is not None:
        ax = axs[1]
        ax.plot(*get_smooth(critic_key))
        ax.set_ylim(0, Qmax)
        ax.set_title('Critic Loss')
        sci_x(ax)

        if critic_key == 'critic_Q_loss' and 'critic_recon_loss' in train_df:
            ax = ax.twinx()
            ax.plot(*get_smooth('critic_recon_loss'), color='red')
            ax.set_ylim(0, None)
            ax.set_ylabel('Recon Loss')

    # Panel 2: model observation loss, optionally with reward loss on a twin axis.
    if 'model_obs_loss' in train_df:
        ax = axs[2]
        ax.plot(*get_smooth('model_obs_loss'), label='Obs Loss')
        ax.set_ylim(0, obsmax)
        ax.set_title('Obs Loss')
        sci_x(ax)

        if 'model_reward_loss' in train_df and plot_rew:
            ax = ax.twinx()
            ax.plot(*get_smooth('model_reward_loss'), label='Rew Loss',
                    color='red')
            ax.set_ylabel('Rew Loss')
            ax.set_ylim(0, None)
            ax.legend()
            sci_x(ax)

    # Panel 3: entropy temperature, on a log scale.
    ax = axs[3]
    ax.plot(*get_smooth('alpha_value'), label='alpha value')
    ax.set_title('Alpha Value')
    ax.set_yscale('log')
    ax.set_xlabel('Interactions')
    sci_x(ax)

    # Panel 4: actor entropy vs. its target.
    ax = axs[4]
    ax.plot(*get_smooth('actor_entropy'))
    ax.plot(*get_smooth('actor_target_entropy'))
    ax.set_title('Actor Entropy')
    ax.set_xlabel('Interactions')
    sci_x(ax)

    # Panel 5: train reward (faint), with eval reward overlaid in the same color.
    ax = axs[5]
    if smooth_train_rew:
        train_line, = ax.plot(*get_smooth('episode_reward'), alpha=0.4)
    else:
        train_line, = ax.plot(train_df.step, train_df.episode_reward,
                              alpha=0.4)
    eval_df = load_eval(root)
    if eval_df is not None and len(eval_df) > 0:
        color = train_line.get_color()
        if len(eval_df) == 1:
            # A single eval point would be invisible as a line.
            ax.scatter(eval_df.step, eval_df.episode_reward, color=color)
        else:
            ax.plot(eval_df.step, eval_df.episode_reward, color=color)
        env_name = config.env_name
        # Fix the y-range to [0, 1000] except for the env families listed here.
        # NOTE(review): presumably the remaining envs have rewards bounded by
        # 1000 — confirm.
        if not (any(tag in env_name for tag in ('gym', 'mbpo', 'pets'))
                or env_name == 'Humanoid-v2'):
            ax.set_ylim(0, 1000)
    ax.set_xlabel('Interactions')
    ax.set_title('Reward')
    sci_x(ax)

    if print_cfg:
        pprint(config)
    if print_overrides:
        o = OmegaConf.load(f'{root}/.hydra/overrides.yaml')
        pprint(o)

    fig.tight_layout()
    fig.subplots_adjust(top=0.9)  # leave room for the suptitle
    if suptitle:
        fig.suptitle(suptitle, fontsize=20)
    else:
        fig.suptitle(root + ': ' + config.env_name, fontsize=20)

    if save:
        fig.savefig(save, transparent=True)
        # Best-effort whitespace trim with ImageMagick. Passing an argv list
        # (no shell) keeps paths with spaces or shell metacharacters safe.
        try:
            subprocess.run(['convert', '-trim', str(save), str(save)],
                           check=False)
        except FileNotFoundError:
            warnings.warn("ImageMagick 'convert' not found; "
                          "saved figure was not trimmed")

    return fig, axs