# def plot_rotations_translations()
#
# in plot.py [0:0]


def plot_rotations_translations(X, model, n_transformations, n_rot, n_x, n_y, save_name=None):
    """Plot each sample beside reconstructions under random transformations.

    One figure row is drawn per sample in ``X``. Column 0 shows the input
    sample; each of the remaining ``n_transformations + 1`` columns shows
    ``model.decoder(model.transform(model.encoder(x), shifts))`` for one
    randomly chosen transformation parameter, titled with that parameter's
    angle and x/y shifts.

    Args:
        X: Iterable of samples exposing ``.shape[0]`` and ``len()``; each
            sample must support ``.squeeze()`` and ``.to(device)``
            (presumably a torch tensor batch -- confirm against callers).
        model: Object providing ``device``, ``encoder``, ``decoder``,
            ``return_shifts`` and ``transform``.
        n_transformations: Controls the grid width: the figure has
            ``n_transformations + 2`` columns (original plus
            ``n_transformations + 1`` reconstructions).
        n_rot, n_x, n_y: Forwarded unchanged to
            ``transformations.get_transform_params`` (presumably the number
            of rotations and of x/y shifts -- confirm in that module).
        save_name: If truthy, the figure is saved to this path (300 dpi,
            tight bounding box) and closed; otherwise it is shown.

    Raises:
        ValueError: If ``transformations.get_transform_params`` yields no
            parameters to draw from.
    """
    degree_sign = "\N{DEGREE SIGN}"
    n_samples = X.shape[0]
    n_panels = n_transformations + 1  # reconstruction columns per row

    # The candidate parameters do not depend on the sample, so build them once.
    # NOTE(review): assumes get_transform_params is deterministic -- confirm.
    transformation_params = list(
        transformations.get_transform_params(n_rot, n_x, n_y, (1.0, ))
    )
    n_params = len(transformation_params)
    if n_params == 0:
        raise ValueError("get_transform_params returned no transformations to plot")

    # Avoid showing the same transformation twice in a row of panels; fall
    # back to sampling with replacement only when there are too few candidates.
    allow_repeats = n_params < n_panels

    # squeeze=False keeps `axs` two-dimensional even for a single sample, so
    # the `axs[row, col]` indexing below is always valid.
    fig, axs = plt.subplots(
        n_samples,
        n_panels + 1,
        figsize=(16, int(12/5.*n_samples)),
        squeeze=False,
    )

    for sample_i, x1 in enumerate(X):
        original_ax = axs[sample_i, 0]
        original_ax.imshow(x1.squeeze())
        original_ax.set_title("original", fontsize=16)
        original_ax.set_xticks([])
        original_ax.set_yticks([])

        x1 = x1.to(model.device)
        z = model.encoder(x1)

        chosen = np.random.choice(n_params, size=n_panels, replace=allow_repeats)
        for col, j in enumerate(chosen, start=1):
            param = transformation_params[int(j)]
            shifts = model.return_shifts([param])
            z_transformed = model.transform(z, shifts)
            x2_reconstruction = model.decoder(z_transformed).detach().cpu().numpy()

            ax = axs[sample_i, col]
            ax.imshow(x2_reconstruction.squeeze())
            ax.set_title(f"{param.angle:0.0f}{degree_sign}\n{param.shift_x:0.0f},{param.shift_y:0.0f}", fontsize=16)
            ax.set_xticks([])
            ax.set_yticks([])

    if save_name:
        fig.savefig(save_name, bbox_inches="tight", dpi=300)
        plt.close(fig)
    else:
        plt.show()