def plot_min_acc_dist()

in parse.py [0:0]


def _save_catplot(grid, stem):
    """Finish a seaborn catplot grid, write it as ``{stem}.pdf`` and ``{stem}.png``, then close it.

    Rotates x tick labels by 45 degrees, sets the shared axis labels, tightens
    the layout, saves both formats at dpi=300, and closes the underlying figure
    so repeated calls do not accumulate open matplotlib figures.
    """
    for ax in grid.fig.axes:
        ax.tick_params(axis="x", labelrotation=45)
    grid.set_axis_labels("Method", "worst-group-acc")
    grid.tight_layout()
    # String formatting instead of Path.with_suffix: ``n`` may itself contain a
    # dot (e.g. 0.5), which with_suffix would treat as an extension.
    for ext in ("pdf", "png"):
        grid.fig.savefig(f"{stem}.{ext}", dpi=300)
    plt.close(grid.fig)


def plot_min_acc_dist(df, run_groups, n):
    """Box-plot the ``min_acc_te`` (worst-group test accuracy) of the runs selected by ``df``.

    For every index value of ``df`` the matching group is fetched from
    ``run_groups`` and its ``min_acc_te`` column is collected. Two figures are
    written:

    * ``figures/worst_group_acc_dist_dataset_{n}.{pdf,png}``: one panel per
      dataset (CelebA, Waterbirds, MultiNLI, CivilComments), x-axis = method,
      each panel with its own y-scale.
    * ``figures/worst_group_acc_dist_{n}.{pdf,png}``: a single panel pooling
      all datasets, x-axis = method.

    Args:
        df: DataFrame whose index values are group keys of ``run_groups``
            (presumably one row per selected configuration; TODO confirm).
        run_groups: object with ``get_group`` (e.g. a pandas GroupBy) whose
            groups have a ``min_acc_te`` column and an index containing
            ``Groups``, ``dataset`` and ``method`` levels.
        n: value inserted into the output filenames.

    Raises:
        ValueError: if ``df`` has no rows (nothing to plot).
    """
    import os

    series = [run_groups.get_group(idx)["min_acc_te"] for idx in df.index]
    if not series:
        raise ValueError("plot_min_acc_dist: df is empty, no runs to plot")

    # Order rows by the "Groups" level, then restrict/order the "dataset" level
    # so the per-dataset panels appear in a fixed order.
    plot_df = (
        pd.concat(series)
        .sort_index(level="Groups")
        .reindex(
            ["CelebA", "Waterbirds", "MultiNLI", "CivilComments"], level="dataset"
        )
        .reset_index()
    )

    os.makedirs("figures", exist_ok=True)
    sns.set(style="whitegrid", context="talk", font="Times New Roman")

    # One panel per dataset; sharey=False gives each dataset its own y-scale.
    per_dataset = sns.catplot(
        data=plot_df,
        x="method",
        y="min_acc_te",
        col="dataset",
        kind="box",
        sharex=True,
        sharey=False,
        height=4.5,
    )
    per_dataset.set_titles(col_template="{col_name}")
    _save_catplot(per_dataset, f"figures/worst_group_acc_dist_dataset_{n}")

    # All datasets pooled in a single panel. catplot creates its own figure, so
    # no explicit plt.figure() is needed (that previously leaked an empty one).
    pooled = sns.catplot(
        data=plot_df, x="method", y="min_acc_te", kind="box", height=5.5
    )
    _save_catplot(pooled, f"figures/worst_group_acc_dist_{n}")