in parse.py [0:0]
def _collect_min_acc_te(df, run_groups):
    """Concatenate the ``min_acc_te`` series of every group keyed by ``df``'s index.

    Returns a flat DataFrame (index levels turned into columns) restricted to
    the four benchmark datasets.

    Raises:
        ValueError: If ``df`` has no rows, so there is nothing to plot.
    """
    series = [run_groups.get_group(key)["min_acc_te"] for key in df.index]
    if not series:
        raise ValueError("plot_min_acc_dist: df has no rows, nothing to plot")
    combined = pd.concat(series).sort_index(level="Groups")
    # Select these datasets on the "dataset" index level (pandas level-reindex
    # semantics); rows belonging to any other dataset are dropped.
    combined = combined.reindex(
        ["CelebA", "Waterbirds", "MultiNLI", "CivilComments"], level="dataset"
    )
    # Index levels (e.g. "method", "dataset") become columns for seaborn.
    return combined.reset_index()


def _save_min_acc_grid(grid, out_dir, stem):
    """Finish ``grid``'s layout, save it as ``{stem}.pdf``/``.png`` in ``out_dir``, then close it."""
    import os

    for ax in grid.axes.flat:
        ax.tick_params(axis="x", labelrotation=45)
    grid.set_axis_labels("Method", "worst-group-acc")
    grid.tight_layout()
    for ext in ("pdf", "png"):
        grid.fig.savefig(os.path.join(out_dir, f"{stem}.{ext}"), dpi=300)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(grid.fig)


def plot_min_acc_dist(df, run_groups, n, out_dir="figures"):
    """Plot box-plot distributions of worst-group test accuracy per method.

    For every index key of ``df`` the matching group of ``run_groups`` is
    looked up and its ``min_acc_te`` column collected. Two figures are then
    written, each as PDF and PNG at 300 dpi:

    * ``{out_dir}/worst_group_acc_dist_dataset_{n}.*`` -- one panel per
      dataset (CelebA, Waterbirds, MultiNLI, CivilComments), each panel with
      its own y-axis range.
    * ``{out_dir}/worst_group_acc_dist_{n}.*`` -- all datasets pooled.

    Args:
        df: Frame whose index values are keys of ``run_groups``; only its
            index is used.
        run_groups: Grouped runs (presumably a pandas ``GroupBy``) supporting
            ``get_group(key)["min_acc_te"]``. The selected series must carry
            index levels named ``Groups``, ``dataset`` and ``method``.
        n: Tag inserted into the output file names.
        out_dir: Directory the figures are written to; created if missing.

    Raises:
        ValueError: If ``df`` has no rows.
    """
    import os

    scores = _collect_min_acc_te(df, run_groups)
    os.makedirs(out_dir, exist_ok=True)
    # NOTE: sns.set changes global matplotlib/seaborn style for the process.
    sns.set(style="whitegrid", context="talk", font="Times New Roman")

    per_dataset = sns.catplot(
        data=scores,
        x="method",
        y="min_acc_te",
        col="dataset",
        kind="box",
        sharex=True,
        sharey=False,
        height=4.5,
    )
    per_dataset.set_titles(col_template="{col_name}")
    _save_min_acc_grid(per_dataset, out_dir, f"worst_group_acc_dist_dataset_{n}")

    # catplot creates its own figure, so no explicit plt.figure() is needed.
    pooled = sns.catplot(
        data=scores, x="method", y="min_acc_te", kind="box", height=5.5
    )
    _save_min_acc_grid(pooled, out_dir, f"worst_group_acc_dist_{n}")