def plot_x1_reconstructions()

in plot.py [0:0]


def plot_x1_reconstructions(pairs, model, indices, train_set, save_name):
    """
    Plots sample x1 reconstructions based on indices.

    Each requested sample gets one row with two panels: the original x1
    image on the left and the model's reconstruction of it on the right.

    Args:
        pairs (datasets.Pairs): contains x1, x2, and params; indexing it with
            a sample index yields an ``(x1, x2, params)`` tuple.
        model (function): callable f(x1) = x1_reconstruction. It is called
            with a batch of one sample (``x1.unsqueeze(0)``) and its output
            must support ``.cpu().detach().numpy()``.
        indices (list of ints): indices for samples to plot (one row each);
            must be non-empty.
        train_set (bool): if true title is plotted with train otherwise test.
        save_name (str): path prefix where the figure is saved as
            ``f"{save_name}.png"``. If falsy, the figure is shown with
            ``plt.show()`` instead.
    """
    title = "Training Reconstructions" if train_set else "Test Reconstructions"
    # squeeze=False keeps `axs` 2-D even when a single index is requested;
    # without it `axs` is 1-D for one row and `axs[row][col]` would fail.
    fig, axs = plt.subplots(len(indices), 2, figsize=(5, 12), squeeze=False)
    fig.suptitle(title, fontsize=16)

    for row, sample_idx in enumerate(indices):
        x1, _x2, _params = pairs[sample_idx]
        # Drop singleton dims (e.g. a leading channel axis) for display.
        x1_image = x1.squeeze()

        x1_reconstruction = model(x1.unsqueeze(0)).cpu().detach().numpy()

        axs[row][0].imshow(x1_image)
        axs[row][0].set_title("x1")

        # Show the reconstruction with the same shape as the displayed x1.
        # This matches the old square reshape for square images and also
        # works for non-square ones.
        axs[row][1].imshow(x1_reconstruction.reshape(tuple(x1_image.shape)))
        axs[row][1].set_title("x1 reconstruction")

    if save_name:
        fig.savefig(f"{save_name}.png", dpi=300, bbox_inches="tight")
        # Release the figure so repeated calls don't accumulate open figures.
        plt.close(fig)
    else:
        plt.show()