in scripts/make_figures.py [0:0]
def eta_histogram(results_path, save_path, prefix, train, labels):
    """
    Histogram of samples by individual sample eta.

    Loads the per-sample eta values stored under the ``"etas"`` key of
    ``{prefix}_pca20.json``. Draws one semi-transparent histogram per class
    (80 bins), overlaid, with a dashed vertical line at each class mean.
    Saves the figure as ``{prefix}_eta_hist`` inside ``save_path``.

    Args:
        results_path: Directory passed to ``load_results`` to locate
            ``{prefix}_pca20.json``.
        save_path: Directory the figure is written into via
            ``plotting.savefig``.
        prefix: Prefix used in both the input and output filenames.
        train: Mapping whose ``"targets"`` entry exposes ``.numpy()``
            (presumably a torch tensor of integer class ids -- confirm).
            Assumed to be index-aligned with the stored etas -- verify
            against the code that produced the results file.
        labels: Class names. ``labels[c]`` names target value ``c``, and
            ``len(labels)`` sets how many classes are plotted.
    """
    plt.clf()
    n_class = len(labels)
    colors = sns.cubehelix_palette(n_class, start=2, rot=0, dark=0, light=.5)
    fig = plt.figure(figsize=(6, 6))
    all_etas = np.array(load_results(
        results_path, f"{prefix}_pca20.json")["etas"])
    targets = train["targets"].numpy()
    for c, (label, color) in enumerate(zip(labels, colors)):
        class_etas = all_etas[targets == c]
        plt.hist(
            class_etas, bins=80, color=color, alpha=0.5, label=f"{label}")
        # An empty class would make .mean() return NaN (with a
        # RuntimeWarning); skip the mean marker instead of drawing at NaN.
        if class_etas.size:
            plt.axvline(
                class_etas.mean(), color=color, linestyle='dashed',
                linewidth=2)
    # Raw string: "\e" is not a valid escape sequence and warns on recent
    # Python versions; the rendered label text is unchanged.
    plt.xlabel(r"Per sample $\eta$", fontsize=30)
    plt.xticks(fontsize=26)
    plt.ylabel("Number of samples", fontsize=30)
    plt.yticks(fontsize=26)
    plt.legend()
    plotting.savefig(os.path.join(save_path, f"{prefix}_eta_hist"))
    # Release the figure so repeated calls don't accumulate open figures.
    plt.close(fig)