in notebooks/utils.py
# Imports required by this function (assumed to be present at module level in
# notebooks/utils.py; get_epic_marginalize_verb_noun and save_graph are
# helpers defined elsewhere in this module):
import operator
import os

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from tqdm import tqdm


def plot_per_cls_perf(run_infos_all: list,
names: list,
metrics: list = ['vrec5_per_cls', 'nrec5_per_cls'],
cls_types: list = ['verb', 'noun'],
show_topn: int = 10,
xticks_rotation: float = 0,
show_subset: callable = None,
outfpath: str = 'figs/improved/'):
"""
Args:
run_infos_all: [[(cfg, sweep_id), (cfg, sweep_id)...],
[(cfg, sweep_id), (cfg, sweep_id)...], ...]
names: The name for each run_info group
metrics: There will be 1 graph for each
"""
assert len(run_infos_all) == len(names)
assert len(metrics) == len(cls_types)
final_accs = {cls_type: [] for cls_type in cls_types}
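    # Gather per-class accuracies for every run of every method into flat
    # records, keyed by class type.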
for i, run_infos in enumerate(tqdm(run_infos_all, desc='Reading acc')):
for run_id, run_info in enumerate(run_infos):
cfg_fpath, sweep_id = run_info
all_accuracies, _, dataset = get_epic_marginalize_verb_noun(
(cfg_fpath, sweep_id))
for metric, cls_type in zip(metrics, cls_types):
accuracies = all_accuracies[metric]
assert isinstance(accuracies,
dict), 'Supports per-class for now'
classes = operator.attrgetter(f'{cls_type}_classes')(dataset)
cls_id_to_name = {v: k for k, v in classes.items()}
for cls_id, score in accuracies.items():
                    final_accs[cls_type].append({
                        'method': names[i],
                        'run_id': run_id,
                        'cls_name': cls_id_to_name[cls_id],
                        'accuracy': score,
                    })
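    # One figure per class type, comparing methods on the classes where the
    # last method gains the most over the first.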
for cls_type in final_accs:
accs = pd.DataFrame(final_accs[cls_type])
# Print logs
for method in names:
for run_id in accs.run_id.unique():
                this_acc = accs[(accs.method == method)
                                & (accs.run_id == run_id)].accuracy.mean()
print(f'Check {method} {run_id}: {this_acc}')
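        # Mean accuracy per (method, class), averaged over runs.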
mean_acc_by_cls = accs.groupby(['method',
'cls_name']).mean().reset_index()
first_col = mean_acc_by_cls[mean_acc_by_cls.method == names[0]]
last_col = mean_acc_by_cls[mean_acc_by_cls.method == names[-1]]
        merged = first_col[['cls_name', 'accuracy']].merge(
            last_col[['cls_name', 'accuracy']],
            on='cls_name',
            how='outer',
            suffixes=['_first', '_last'])
        # Get the largest gains (drop NaNs from the outer merge so classes
        # missing from either method don't land among the top gains)
        gains = (merged['accuracy_last'] -
                 merged['accuracy_first']).dropna().sort_values()
gained_labels = merged.loc[gains.index].cls_name.tolist()
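        # Optionally filter class names, then keep the show_topn classes with
        # the largest gains.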
if show_subset is not None:
gained_labels = [el for el in gained_labels if show_subset(el)]
gained_labels = gained_labels[-show_topn:]
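        # Bar plot over the selected classes, ordered by increasing gain;
        # bars show per-method means with error bars across runs.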
accs_largegains = accs[accs.cls_name.isin(gained_labels)]
fig = plt.figure(num=None,
figsize=(2 * len(gained_labels), 4),
dpi=300)
ax = sns.barplot(x='cls_name',
y='accuracy',
hue='method',
data=accs_largegains,
order=gained_labels,
errwidth=1.0)
ax.set_xlabel('Classes')
ax.set_ylabel('Recall @ 5')
ax.set_xticklabels(ax.get_xticklabels(),
rotation=xticks_rotation,
ha='center')
plt.show()
save_graph(fig, os.path.join(outfpath, cls_type + '.pdf'))
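

# Minimal usage sketch (hypothetical config paths and sweep ids, for
# illustration only; assumes get_epic_marginalize_verb_noun can resolve them):
# plot_per_cls_perf(
#     run_infos_all=[
#         [('expts/01_baseline.txt', '')],   # baseline runs
#         [('expts/02_improved.txt', '')],   # improved runs
#     ],
#     names=['Baseline', 'Improved'],
#     show_topn=10,
#     xticks_rotation=45,
#     outfpath='figs/improved/')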