in parse.py [0:0]
def plot_min_acc_evol(best_df, all_runs, filename):
    """Plot train/test worst-group accuracy per epoch for the selected runs.

    For every index entry of ``best_df`` the matching group of rows is pulled
    out of ``all_runs``, the curves of some datasets are smoothed with a
    rolling mean, and the result is drawn as a grid of line plots (one row per
    ``Groups`` value, one column per dataset) that is written to
    ``figures/<filename>.pdf`` and ``figures/<filename>.png``.

    Args:
        best_df: DataFrame whose index identifies the runs to plot. Its index
            names are used as the grouping keys on ``all_runs``.
        all_runs: DataFrame with the per-epoch metrics of every run. It must
            provide the keys named by ``best_df.index.names``, the keys
            "dataset", "method", "hparams_seed", "init_seed", "Groups" and
            "#HP", and the columns "min_acc_te" and "min_acc_tr". "dataset"
            must be an index level (it is reindexed by level). "epoch",
            "method" and "Groups" are expected to be index levels as well:
            ``melt`` is called without ``id_vars``, so only index levels
            survive into the plotted frame.
        filename: Base name (without extension) of the output files inside
            the ``figures`` directory.

    Returns:
        None. The figure is saved to disk as a side effect.
    """
    # Local import so this block does not depend on the file's import header.
    import os

    # Collect, for each selected configuration, all of its rows in all_runs.
    runs_by_key = all_runs.groupby(best_df.index.names)
    selected = [runs_by_key.get_group(key) for key in best_df.index]
    df = (
        pd.concat(selected)
        .sort_index()
        # Fixes the dataset order (and therefore the column order of the grid).
        .reindex(["CelebA", "Waterbirds", "MultiNLI", "CivilComments"], level="dataset")
    )

    # Rolling-mean window per dataset; datasets not listed are left unsmoothed.
    windows = {
        "CelebA": 5,
        "Waterbirds": 10,
    }
    per_run = df.groupby(
        ["dataset", "method", "hparams_seed", "init_seed", "Groups", "#HP"]
    )
    smoothed = []
    for key, run in per_run:
        # key[0] is the dataset name, the first grouping key.
        window = windows.get(key[0])
        if window is None:
            smoothed.append(run)
        else:
            # With the default min_periods, the first ``window - 1`` rows of
            # each run become NaN after smoothing.
            smoothed.append(run.rolling(window=window).mean())
    df = pd.concat(smoothed)

    # NOTE(review): sns.set_theme below also sets font rcParams, so this size
    # is presumably overridden -- confirm before relying on it.
    plt.rc("font", size=11)

    # Long format: one row per (index entry, phase) with a single value column.
    df = (
        df.melt(
            value_vars=["min_acc_te", "min_acc_tr"],
            var_name="phase",
            value_name="worst-group-acc",
            ignore_index=False,
        )
        .replace({"min_acc_te": "test", "min_acc_tr": "train"})
        .reset_index()
    )

    sns.set_theme(context="talk", style="white", font="Times New Roman")
    scale = 1  # multiplier for the height of each facet
    g = sns.relplot(
        data=df,
        x="epoch",
        y="worst-group-acc",
        hue="method",
        style="phase",
        kind="line",
        row="Groups",
        col="dataset",
        height=scale * 3.5,
        aspect=1,
        facet_kws=dict(sharex=False, sharey=False, margin_titles=True),
        alpha=0.7,
    )
    g.set_axis_labels("epoch", "worst-group-acc")
    g.set_titles(row_template="Groups = {row_name}", col_template="{col_name}")
    g.tight_layout()

    pdf_path = f"figures/{filename}.pdf"
    png_path = f"figures/{filename}.png"
    # Create the output directory (including any sub-directory implied by
    # ``filename``) so saving does not fail on a fresh checkout.
    os.makedirs(os.path.dirname(pdf_path), exist_ok=True)
    plt.savefig(pdf_path, dpi=300)
    plt.savefig(png_path, dpi=300)