in plot.py [0:0]
def _hide_ticks(ax):
    """Remove the x and y tick marks from a single matplotlib Axes."""
    ax.set_xticks([])
    ax.set_yticks([])


def plot_rotations_translations(X, model, n_transformations, n_rot, n_x, n_y, save_name=None):
    """Plot each input next to reconstructions under randomly sampled transformations.

    Row ``k`` of the figure shows ``X[k]`` in column 0 (titled "original"),
    followed by ``n_transformations + 1`` reconstructions.

    For each reconstruction:

    * A parameter set is drawn uniformly at random, with replacement, using
      the global ``np.random`` state. The pool is
      ``transformations.get_transform_params(n_rot, n_x, n_y, (1.0,))``.
    * The sample is encoded with ``model.encoder``.
    * The code is shifted with ``model.transform(z, model.return_shifts([param]))``.
    * The result is decoded with ``model.decoder``.

    Each reconstruction is titled with the parameter's angle (followed by a
    degree sign) and its x/y shift.

    Args:
        X: Batch of inputs, iterated along its first axis (one figure row per
            sample). Each sample must support ``.squeeze()`` and
            ``.to(device)`` -- presumably a torch tensor (TODO confirm).
        model: Object exposing ``device``, ``encoder``, ``decoder``,
            ``transform`` and ``return_shifts``.
        n_transformations: The figure gets ``n_transformations + 2`` columns:
            the original plus ``n_transformations + 1`` reconstructions.
        n_rot, n_x, n_y: Forwarded to ``transformations.get_transform_params``
            to build the pool of candidate parameters.
        save_name: If truthy, save the figure to this path (300 dpi, tight
            bbox) and close it; otherwise display it with ``plt.show()``.

    Raises:
        ValueError: If ``X`` is empty or the candidate parameter pool is empty.
    """
    degree_sign = "\N{DEGREE SIGN}"
    n_samples = X.shape[0]
    if n_samples == 0:
        raise ValueError("X must contain at least one sample")
    n_cols = n_transformations + 2

    # The candidate pool does not depend on the sample, so build it once.
    transformation_params = list(
        transformations.get_transform_params(n_rot, n_x, n_y, (1.0,))
    )
    if not transformation_params:
        raise ValueError("get_transform_params returned no transformations to sample from")

    # squeeze=False keeps `axs` 2-D even for a single sample (one row). With
    # the default squeeze=True, `axs[row, col]` would raise IndexError.
    _, axs = plt.subplots(
        n_samples, n_cols, figsize=(16, int(12 / 5.0 * n_samples)), squeeze=False
    )
    for row, x1 in enumerate(X):
        ax = axs[row, 0]
        ax.imshow(x1.squeeze())
        ax.set_title("original", fontsize=16)
        _hide_ticks(ax)

        # Encode once per sample; every reconstruction in this row reuses z.
        z = model.encoder(x1.to(model.device))
        for col in range(1, n_cols):
            # Independent draw per column: the same parameters may repeat.
            param = transformation_params[np.random.randint(len(transformation_params))]
            shifts = model.return_shifts([param])
            z_transformed = model.transform(z, shifts)
            x2_reconstruction = model.decoder(z_transformed).detach().cpu().numpy()

            ax = axs[row, col]
            ax.imshow(x2_reconstruction.squeeze())
            ax.set_title(
                f"{param.angle:0.0f}{degree_sign}\n{param.shift_x:0.0f},{param.shift_y:0.0f}",
                fontsize=16,
            )
            _hide_ticks(ax)

    if save_name:
        plt.savefig(save_name, bbox_inches="tight", dpi=300)
        plt.close()
    else:
        plt.show()