def view_images()

in scripts/make_figures.py [0:0]


def view_images(train, results_path, save_path, prefix):
    """
    Show the most and least leaked images in a two-row figure.

    Reads the "etas" entry of ``{prefix}_pca20.json`` (via ``load_results``)
    and ranks the examples by eta, largest first. The top row shows the 8
    examples with the largest eta; the bottom row shows the 8 with the
    smallest eta, starting from the very smallest on the left. Every panel
    is titled with its eta value. The figure is written through
    ``plotting.savefig`` as ``{prefix}_images`` under ``save_path`` and then
    closed.

    Args:
        train: mapping whose "features" entry holds the images; after
            ``squeeze()`` it is indexed by example along its first axis.
        results_path: location handed to ``load_results``.
        save_path: directory in which the figure is saved.
        prefix: name prefix shared by the results file and the saved figure.
    """
    n_cols = 8

    results = load_results(results_path, f"{prefix}_pca20.json")
    # (example index, eta) pairs ordered by eta value, largest first.
    ranked = sorted(
        enumerate(results["etas"]), key=lambda pair: pair[1], reverse=True)

    # NOTE(review): squeeze() without arguments drops every size-1 axis, so a
    # dataset holding a single image would lose its example axis -- confirm
    # that case cannot occur.
    images = train["features"].squeeze()

    fig, axes = plt.subplots(2, n_cols, figsize=(7, 2.2))
    fig.subplots_adjust(wspace=0.05)

    for row in range(2):
        for col in range(n_cols):
            # Row 0 walks down from the largest eta; row 1 walks up from the
            # smallest eta.
            position = col if row == 0 else -(col + 1)
            example_idx, eta = ranked[position]

            picture = images[example_idx, ...]
            if picture.ndim == 3:
                # Move the first axis to the end (presumably channels-first
                # to channels-last for imshow). Relies on a .permute method,
                # i.e. a torch-like tensor -- TODO confirm.
                picture = picture.permute(1, 2, 0)

            panel = axes[row, col]
            panel.imshow(picture, cmap='gray')
            panel.axis("off")
            panel.set_title("{:.1e}".format(eta), fontsize=14)
            for axis in (panel.get_xaxis(), panel.get_yaxis()):
                axis.set_visible(False)

    out_file = os.path.join(save_path, f"{prefix}_images")
    plotting.savefig(out_file)
    plt.close(fig)