in plot.py [0:0]
def plot_x1_reconstructions(pairs, model, indices, train_set, save_name):
    """Plot x1 samples side by side with the model's reconstructions of them.

    Each requested sample gets one row of the figure: the left column shows
    the original x1 image and the right column shows ``model(x1)`` reshaped
    to a square ``n_pixels x n_pixels`` image.

    Args:
        pairs (datasets.Pairs): indexable dataset whose items unpack as
            ``(x1, x2, params)``; only x1 is used here.
        model (function): callable f(x1) = x1_reconstruction. It is called
            on a batch of one (``x1.unsqueeze(0)``) and its output must
            support ``.cpu().detach().numpy()``.
        indices (list of ints): indices of the samples to plot, one row each.
        train_set (bool): if true the title reads "Training Reconstructions",
            otherwise "Test Reconstructions".
        save_name (str): path prefix for the saved image (".png" is
            appended). If falsy, the figure is shown interactively instead.
    """
    title = "Training Reconstructions" if train_set else "Test Reconstructions"
    # squeeze=False keeps axs 2-D even when only one index is requested, so
    # axs[row][col] indexing works for any number of rows.
    fig, axs = plt.subplots(len(indices), 2, figsize=(5, 12), squeeze=False)
    fig.suptitle(title, fontsize=16)
    for row, sample_idx in enumerate(indices):
        x1, _x2, _params = pairs[sample_idx]
        # NOTE(review): assumes x1 is (channels, H, W) with H == W, since
        # shape[1] is used as the side length below — confirm with dataset.
        n_pixels = x1.shape[1]
        x1_reconstruction = model(x1.unsqueeze(0)).cpu().detach().numpy()
        axs[row][0].imshow(x1.squeeze())
        axs[row][0].set_title("x1")
        axs[row][1].imshow(x1_reconstruction.reshape(n_pixels, n_pixels))
        axs[row][1].set_title("x1 reconstruction")
    if save_name:
        # Save/close this specific figure rather than whatever pyplot
        # currently considers the "current" figure.
        fig.savefig(f"{save_name}.png", dpi=300, bbox_inches="tight")
        plt.close(fig)
    else:
        plt.show()