def plot_nenvs_dimspu()

in scripts/plot_results.py [0:0]


def _plot_dataset_panels(axs, df, x_col, first_ax, model_colors, xlabel, xlim):
    """Draw one panel per dataset in ``df``, starting at ``axs[first_ax]``.

    For every model of a dataset, plot ``mean`` against ``df[x_col] / 5`` and
    shade a band of +/- ``std / 2`` around it.

    Args:
        axs: Flat sequence of matplotlib axes.
        df: DataFrame with columns "dataset", "model", ``x_col``, "mean", "std".
        x_col: Name of the swept column used for the x axis.
        first_ax: Index into ``axs`` of the first panel to draw.
        model_colors: Mapping model name -> matplotlib color spec.
        xlabel: Label for the x axis of each panel.
        xlim: ``(left, right)`` x-axis limits.

    Returns:
        dict mapping model name -> line handle (for building the legend).
    """
    handles = {}
    for offset, dataset in enumerate(sorted(df["dataset"].unique())):
        ax = axs[first_ax + offset]
        df_d = df[df["dataset"] == dataset]
        for model in sorted(df_d["model"].unique()):
            df_d_m = df_d[df_d["model"] == model].sort_values(by=x_col)
            # NOTE(review): the /5 normalization is kept from the original;
            # its meaning is not visible here.
            # The band must share the line's own x values; a hard-coded grid
            # misaligns (or fails on length mismatch) when the sweep differs.
            xs = df_d_m[x_col].to_numpy() / 5
            mean = df_d_m["mean"].to_numpy()
            half_std = df_d_m["std"].to_numpy() / 2
            color = model_colors[model]
            handle, = ax.plot(xs, mean, color=color, label=model, linewidth=2)
            ax.fill_between(xs, mean - half_std, mean + half_std,
                            facecolor=color, alpha=0.2)
            handles[model] = handle

        ax.set_xlabel(xlabel)
        ax.set_title(dataset)
        ax.set_ylim(bottom=-0.005)
        ax.set_xlim(left=xlim[0], right=xlim[1])
    return handles


def plot_nenvs_dimspu(df_nenvs, df_dimspu, dirname, file_name='', save=True, block=False):
    """Plot test error vs. number of environments (top) and spurious dims (bottom).

    Args:
        df_nenvs: DataFrame with columns "dataset", "model", "n_envs", "mean",
            "std"; one top-row panel per dataset.
        df_dimspu: DataFrame with columns "dataset", "model", "dim_spu",
            "mean", "std"; one bottom-row panel per dataset.
        dirname: Sub-directory of "figs/" where the PDF is written.
        file_name: Base name of the PDF (".pdf" is appended).
        save: If True, save the figure as a PDF.
        block: If True, show the figure and wait for Enter before closing
            all figures.
    """
    ncols = 6
    fig, axs = plt.subplots(2, ncols, figsize=(13, 5))
    axs = axs.flat

    # One color per model across the whole figure, so a model keeps the same
    # color in every panel and the shared legend is correct.
    all_models = sorted(set(df_nenvs["model"].unique())
                        | set(df_dimspu["model"].unique()))
    model_colors = {model: f'C{i}' for i, model in enumerate(all_models)}

    # Top: nenvs
    n_top = df_nenvs["dataset"].nunique()
    handles = _plot_dataset_panels(axs, df_nenvs, "n_envs", 0, model_colors,
                                   r'$\delta_{\rm env}$', (0.4, 2))

    # Bottom: dimspu. Start on the second grid row; only spill further when
    # the top row overflowed (same placement as before in that case).
    bottom_start = max(ncols, n_top)
    handles.update(_plot_dataset_panels(axs, df_dimspu, "dim_spu", bottom_start,
                                        model_colors, r'$\delta_{\rm spu}$', (0, 2)))

    # Y labels on the left-most axis of each grid row.
    axs[0].set_ylabel("Test error")
    axs[ncols].set_ylabel("Test error")
    plt.tight_layout(pad=0)
    plt.legend(handles=[handles[model] for model in sorted(handles)],
               ncol=6,
               loc="lower center",
               bbox_to_anchor=(-2.8, -0.7))

    if save:
        fig_dirname = "figs/" + dirname
        os.makedirs(fig_dirname, exist_ok=True)
        # Ensure a separator so the PDF lands inside the directory just created.
        plt.savefig(os.path.join(fig_dirname, "") + file_name + '.pdf',
                    format='pdf', bbox_inches='tight')
    if block:
        plt.show(block=False)
        input('Press to close')
        plt.close('all')