def plot_samples()

in manifolds.py [0:0]


    def plot_samples(self, model_samples, save='t.png', bw_method=0.2):
        """Plot a KDE of torus samples, evaluated on ``self.tp``, and save it.

        The first two columns and the remaining columns of ``model_samples``
        are each converted to angles with ``utils.S1euclideantospherical``.
        The angle pairs are then embedded in 3-D with
        ``utils.productS1toTorus``. A Gaussian KDE is fit to those 3-D points
        and evaluated on the torus embedding of the angle grid ``self.tp``.
        The grid points are drawn as a 3-D scatter coloured by the estimated
        density, and the figure is written to ``save``.

        Args:
            model_samples: array of samples. Columns ``[:2]`` and ``[2:]``
                are each passed to ``utils.S1euclideantospherical``
                (presumably shape ``(n, 4)``: two points on S^1 — TODO
                confirm).
            save: output path for the saved figure.
            bw_method: bandwidth argument forwarded to
                ``scipy.stats.gaussian_kde``. Defaults to 0.2, the value
                previously hard-coded here.
        """
        theta1 = utils.S1euclideantospherical(model_samples[:, :2])
        theta2 = utils.S1euclideantospherical(model_samples[:, 2:])

        x, y, z = utils.productS1toTorus(theta1, theta2)
        data = jnp.stack((x, y, z), 1)
        # gaussian_kde expects (n_dims, n_points), hence the transpose.
        estimated_density = gaussian_kde(data.T, bw_method=bw_method)

        # Assumes self.tp holds (theta1, theta2) grid pairs in its first two
        # columns — TODO confirm against where self.tp is built.
        x_grid, y_grid, z_grid = utils.productS1toTorus(self.tp[:, 0], self.tp[:, 1])
        grid = jnp.stack((x_grid, y_grid, z_grid), 1)
        probas_grid = estimated_density(grid.T)

        fig = plt.figure()
        # Axes3D(fig) no longer attaches itself to the figure in recent
        # matplotlib (deprecated 3.4, removed 3.6), which left the saved
        # image blank. add_subplot registers the 3-D axes with the figure.
        ax = fig.add_subplot(projection='3d')
        # TODO: fix this - the coordinates are negated because the mode of
        # the unimodal density sits at the bottom of the torus.
        ax.scatter(-x_grid, -y_grid, -z_grid, alpha=0.2, c=probas_grid)
        ax.set_xlim(-1, 1)
        ax.set_ylim(-1, 1)
        ax.set_zlim(-1, 1)
        # Act on this axes explicitly rather than pyplot's "current" axes.
        ax.set_axis_off()
        fig.savefig(save)
        # Release the figure; otherwise every call leaves an open figure in
        # pyplot's registry.
        plt.close(fig)