def plot_per_cls_perf()

in notebooks/utils.py


import operator
import os

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from tqdm import tqdm

# NOTE: get_epic_marginalize_verb_noun and save_graph are assumed to be
# defined or imported elsewhere in notebooks/utils.py.


def plot_per_cls_perf(run_infos_all: list,
                      names: list,
                      metrics: list = ['vrec5_per_cls', 'nrec5_per_cls'],
                      cls_types: list = ['verb', 'noun'],
                      show_topn: int = 10,
                      xticks_rotation: float = 0,
                      show_subset: callable = None,
                      outfpath: str = 'figs/improved/'):
    """
    Args:
        run_infos_all: [[(cfg, sweep_id), (cfg, sweep_id)...],
                        [(cfg, sweep_id), (cfg, sweep_id)...], ...]
        names: The name for each run_info group
        metrics: There will be 1 graph for each
    """
    assert len(run_infos_all) == len(names)
    assert len(metrics) == len(cls_types)
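    # Accumulate one record per (method, run, class): the per-class accuracy
    # for each metric, keyed by its class type.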
    final_accs = {cls_type: [] for cls_type in cls_types}
    for i, run_infos in enumerate(tqdm(run_infos_all, desc='Reading acc')):
        for run_id, run_info in enumerate(run_infos):
            cfg_fpath, sweep_id = run_info
            all_accuracies, _, dataset = get_epic_marginalize_verb_noun(
                (cfg_fpath, sweep_id))
            for metric, cls_type in zip(metrics, cls_types):
                accuracies = all_accuracies[metric]
                assert isinstance(accuracies,
                                  dict), 'Supports per-class for now'
                classes = operator.attrgetter(f'{cls_type}_classes')(dataset)
                cls_id_to_name = {v: k for k, v in classes.items()}
                for cls_id, score in accuracies.items():
                    final_accs[cls_type].append({
                        'method': names[i],
                        'run_id': run_id,
                        'cls_name': cls_id_to_name[cls_id],
                        'accuracy': score,
                    })
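    # For each class type: build a DataFrame, find the classes whose accuracy
    # improves the most between the first and the last method, and plot a
    # grouped bar chart over the top-N most improved classes.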
    for cls_type in final_accs:
        accs = pd.DataFrame(final_accs[cls_type])
        # Sanity check: print the mean accuracy for each method and run
        for method in names:
            for run_id in accs.run_id.unique():
                this_acc = accs[(accs.method == method)
                                & (accs.run_id == run_id)].accuracy.mean()
                print(f'Check {method} {run_id}: {this_acc}')
        mean_acc_by_cls = accs.groupby(['method',
                                        'cls_name']).mean().reset_index()
        first_col = mean_acc_by_cls[mean_acc_by_cls.method == names[0]]
        last_col = mean_acc_by_cls[mean_acc_by_cls.method == names[-1]]
        merged = first_col[['cls_name', 'accuracy']].merge(
            last_col[['cls_name', 'accuracy']],
            on='cls_name',
            how='outer',
            suffixes=['_first', '_last'])
        # get the largest gains
        gains = (merged['accuracy_last'] -
                 merged['accuracy_first']).sort_values()
        gained_labels = merged.loc[gains.index].cls_name.tolist()
        if show_subset is not None:
            gained_labels = [el for el in gained_labels if show_subset(el)]
        gained_labels = gained_labels[-show_topn:]
        accs_largegains = accs[accs.cls_name.isin(gained_labels)]
        fig = plt.figure(num=None,
                         figsize=(2 * len(gained_labels), 4),
                         dpi=300)
        ax = sns.barplot(x='cls_name',
                         y='accuracy',
                         hue='method',
                         data=accs_largegains,
                         order=gained_labels,
                         errwidth=1.0)
        ax.set_xlabel('Classes')
        ax.set_ylabel('Recall @ 5')
        ax.set_xticklabels(ax.get_xticklabels(),
                           rotation=xticks_rotation,
                           ha='center')
        plt.show()
        save_graph(fig, os.path.join(outfpath, cls_type + '.pdf'))
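

# Example usage (a sketch): the config paths and sweep ids below are
# placeholders; the real values depend on how the experiment configs are
# laid out in this repo.
baseline_runs = [('expts/baseline.txt', None)]
improved_runs = [('expts/improved.txt', None)]
plot_per_cls_perf([baseline_runs, improved_runs],
                  names=['Baseline', 'Improved'],
                  show_topn=10,
                  xticks_rotation=45,
                  outfpath='figs/improved/')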