def eta_histogram()

in scripts/make_figures.py [0:0]


def eta_histogram(results_path, save_path, prefix, train, labels):
    """
    Histogram of samples by individual sample eta.

    Loads the per-sample etas stored under the "etas" key of
    ``{prefix}_pca20.json`` (read via ``load_results(results_path, ...)``),
    splits them by the class targets in ``train["targets"]``, and draws one
    overlaid 80-bin histogram per class, plus a dashed vertical line at that
    class's mean eta. The figure is written via ``plotting.savefig`` to
    ``os.path.join(save_path, f"{prefix}_eta_hist")`` and then closed.

    Args:
        results_path: directory passed to ``load_results``.
        save_path: directory the figure is saved into.
        prefix: file-name prefix used for both the results file and the figure.
        train: mapping whose "targets" entry exposes ``.numpy()``
            (presumably a torch tensor of integer class ids -- confirm).
        labels: per-class display names. ``labels[c]`` is the legend entry for
            class ``c``, and ``len(labels)`` sets how many classes are plotted.
    """
    n_class = len(labels)
    colors = sns.cubehelix_palette(n_class, start=2, rot=0, dark=0, light=.5)

    # Create the figure directly. A preceding plt.clf() was redundant and, when
    # no figure was open, created an extra blank figure that was never closed.
    fig = plt.figure(figsize=(6, 6))
    all_etas = np.array(load_results(
        results_path, f"{prefix}_pca20.json")["etas"])
    # NOTE(review): assumes etas are stored in the same sample order as the
    # training targets, so the boolean masks below pick matching samples.
    targets = train["targets"].numpy()

    for c in range(n_class):
        class_etas = all_etas[targets == c]
        plt.hist(
            class_etas, bins=80, color=colors[c], alpha=0.5,
            label=f"{labels[c]}")
        # Skip the mean marker for a class with no samples: mean() of an empty
        # array is nan and emits a RuntimeWarning.
        if class_etas.size:
            plt.axvline(
                class_etas.mean(), color=colors[c], linestyle='dashed',
                linewidth=2)
    # Raw string: "\e" is not a valid escape sequence (SyntaxWarning on
    # Python 3.12+). The label text rendered is unchanged.
    plt.xlabel(r"Per sample $\eta$", fontsize=30)
    plt.xticks(fontsize=26)
    plt.ylabel("Number of samples", fontsize=30)
    plt.yticks(fontsize=26)
    plt.legend()
    plotting.savefig(os.path.join(save_path, f"{prefix}_eta_hist"))
    # Release the figure so repeated calls don't accumulate open figures.
    # Closing a figure that savefig already closed is a no-op.
    plt.close(fig)