def plot_min_acc_evol()

in parse.py [0:0]


# Display order of the dataset columns in the figure.
_DATASET_ORDER = ["CelebA", "Waterbirds", "MultiNLI", "CivilComments"]
# Keys identifying a single training run; each run is smoothed independently.
_RUN_KEYS = ["dataset", "method", "hparams_seed", "init_seed", "Groups", "#HP"]
# Metric column -> phase label used in the plot legend (order matters: test first).
_PHASE_NAMES = {"min_acc_te": "test", "min_acc_tr": "train"}
# Default trailing rolling-mean window (in rows) per dataset; others are unsmoothed.
_DEFAULT_WINDOWS = {"CelebA": 5, "Waterbirds": 10}


def _select_best_runs(best_df, all_runs):
    """Return the rows of ``all_runs`` for every configuration indexed in ``best_df``.

    Rows are sorted by index and restricted to the datasets in ``_DATASET_ORDER``.

    Raises:
        ValueError: if ``best_df`` has no rows.
    """
    run_groups = all_runs.groupby(best_df.index.names)
    selected = [run_groups.get_group(idx) for idx in best_df.index]
    if not selected:
        raise ValueError("best_df is empty: there are no runs to plot")
    return (
        pd.concat(selected)
        .sort_index()
        .reindex(_DATASET_ORDER, level="dataset")
    )


def _smooth_metrics(runs, windows):
    """Return only the metric columns, with a per-run trailing rolling mean applied.

    Runs whose dataset is a key of ``windows`` are smoothed with that window size;
    the first ``window - 1`` rows of such runs become NaN. Other runs are kept as-is.

    Raises:
        ValueError: if no runs remain (e.g. none belong to ``_DATASET_ORDER``).
    """
    metric_cols = list(_PHASE_NAMES)
    pieces = []
    for key, run in runs.groupby(_RUN_KEYS):
        # Only the metric columns survive the later melt, so smoothing anything
        # else is wasted work (and fails on non-numeric columns in pandas >= 2).
        metrics = run[metric_cols]
        dataset = key[0]
        if dataset in windows:
            metrics = metrics.rolling(window=windows[dataset]).mean()
        pieces.append(metrics)
    if not pieces:
        raise ValueError("no runs left to plot after dataset filtering")
    return pd.concat(pieces)


def _to_long_format(metrics):
    """Reshape the metric columns into long form: one row per (index row, phase).

    The result has the original index levels as columns, plus ``phase``
    ("test"/"train") and ``worst-group-acc``.
    """
    return (
        metrics.melt(
            value_vars=list(_PHASE_NAMES),
            var_name="phase",
            value_name="worst-group-acc",
            ignore_index=False,
        )
        .replace({"phase": _PHASE_NAMES})
        .reset_index()
    )


def plot_min_acc_evol(best_df, all_runs, filename, *, windows=None, out_dir="figures"):
    """Plot the evolution of worst-group accuracy (train and test) over epochs.

    For each configuration indexed in ``best_df``, the matching runs are taken from
    ``all_runs``, optionally smoothed per dataset, and drawn in a grid with one row
    per ``Groups`` value and one column per dataset. The figure is saved as
    ``<out_dir>/<filename>.pdf`` and ``<out_dir>/<filename>.png`` at 300 dpi.

    Args:
        best_df: DataFrame whose index labels select configurations; its index
            level names are used to group ``all_runs``.
        all_runs: per-epoch run records with ``min_acc_te`` / ``min_acc_tr``
            columns. Assumed to carry ``dataset``, ``method``, ``hparams_seed``,
            ``init_seed``, ``Groups``, ``#HP`` and ``epoch`` as index levels
            (only index levels survive the melt) -- confirm against callers.
        filename: output file stem (without extension).
        windows: optional mapping dataset -> rolling-mean window size; defaults
            to ``{"CelebA": 5, "Waterbirds": 10}``.
        out_dir: directory to write the figures to; created if missing.

    Raises:
        ValueError: if there are no runs to plot.
    """
    import os  # local import keeps this block self-contained

    if windows is None:
        windows = _DEFAULT_WINDOWS

    runs = _select_best_runs(best_df, all_runs)
    long_df = _to_long_format(_smooth_metrics(runs, windows))

    # set_theme resets rcParams (including font sizes), so configure fonts here.
    sns.set_theme(context="talk", style="white", font="Times New Roman")

    present = set(long_df["dataset"])
    g = sns.relplot(
        data=long_df,
        x="epoch",
        y="worst-group-acc",
        hue="method",
        style="phase",
        kind="line",
        row="Groups",
        col="dataset",
        # groupby() in _smooth_metrics sorts datasets alphabetically, so the
        # intended column order must be pinned explicitly.
        col_order=[name for name in _DATASET_ORDER if name in present],
        height=3.5,
        aspect=1,
        facet_kws=dict(sharex=False, sharey=False, margin_titles=True),
        alpha=0.7,
    )
    g.set_axis_labels("epoch", "worst-group-acc")
    g.set_titles(row_template="Groups = {row_name}", col_template="{col_name}")
    g.tight_layout()

    os.makedirs(out_dir, exist_ok=True)
    for ext in ("pdf", "png"):
        plt.savefig(os.path.join(out_dir, f"{filename}.{ext}"), dpi=300)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close()