in scripts/make_figures.py [0:0]
def view_images(train, results_path, save_path, prefix):
    """
    Plot the training images with the largest and smallest eta values.

    Reads the "etas" entry of ``{prefix}_pca20.json`` (via ``load_results``)
    from ``results_path``, ranks the samples by eta in descending order and
    draws a 2 x 8 grid: the top row holds the eight samples with the largest
    eta, the bottom row the eight with the smallest eta (smallest first).
    Each panel is titled with its eta in scientific notation. The figure is
    written with ``plotting.savefig`` to ``save_path`` under the name
    ``{prefix}_images`` and then closed.

    Args:
        train: mapping whose "features" entry holds the images, indexed by
            sample along the first axis after ``squeeze()``.
            NOTE(review): presumably a torch tensor (``permute`` is called
            on 3-D samples) -- confirm against the caller.
        results_path: location passed to ``load_results``.
        save_path: directory in which the figure is saved.
        prefix: prefix shared by the results file and the output figure.
    """
    eta_values = load_results(results_path, f"{prefix}_pca20.json")["etas"]

    # Pair each eta with its sample index, then rank by eta (largest first)
    # so the original sample index survives the sort.
    ranked = [(eta, sample) for sample, eta in enumerate(eta_values)]
    ranked.sort(key=lambda pair: pair[0], reverse=True)

    images = train["features"].squeeze()
    n_ims = 8
    fig, axes = plt.subplots(2, n_ims, figsize=(7, 2.2))
    fig.subplots_adjust(wspace=0.05)

    # Row 0 walks the ranking from the front (largest eta), row 1 from the
    # back (smallest eta).
    for from_end in (False, True):
        row = int(from_end)
        for col in range(n_ims):
            if from_end:
                position = -(col + 1)
            else:
                position = col
            eta, sample = ranked[position]

            picture = images[sample, ...]
            if picture.ndim == 3:
                # Move the leading axis last before handing it to imshow
                # (looks like channel-first -> channel-last; TODO confirm).
                picture = picture.permute(1, 2, 0)

            panel = axes[row, col]
            panel.imshow(picture, cmap="gray")
            panel.axis("off")
            panel.set_title("{:.1e}".format(eta), fontsize=14)
            panel.get_xaxis().set_visible(False)
            panel.get_yaxis().set_visible(False)

    plotting.savefig(os.path.join(save_path, f"{prefix}_images"))
    plt.close(fig)