in manifolds.py [0:0]
def plot_samples(self, model_samples, save='t.png', *, bw_method=0.2):
    """Plot a Gaussian KDE of ``model_samples`` evaluated on the grid ``self.tp``.

    The samples are mapped to 3-D points with ``utils.productS1toTorus``, a
    Gaussian KDE is fitted to those points, and the KDE is evaluated at the
    3-D image of every grid point in ``self.tp``. The grid is drawn as a 3-D
    scatter coloured by the estimated density and written to ``save``.

    Args:
        model_samples: 2-D array of samples. Columns ``[:, :2]`` and
            ``[:, 2:]`` are each converted to an angle with
            ``utils.S1euclideantospherical`` (presumably (N, 4) Euclidean
            coordinates of two circle points -- TODO confirm).
        save: Output path passed to ``Figure.savefig``.
        bw_method: Bandwidth argument forwarded to ``gaussian_kde``. It
            defaults to 0.2, the value that used to be hard-coded.
    """
    theta1 = utils.S1euclideantospherical(model_samples[:, :2])
    theta2 = utils.S1euclideantospherical(model_samples[:, 2:])
    x, y, z = utils.productS1toTorus(theta1, theta2)
    data = jnp.stack((x, y, z), 1)
    # gaussian_kde takes data as (n_dims, n_points), hence the transpose.
    estimated_density = gaussian_kde(data.T, bw_method)

    # Columns 0 and 1 of self.tp feed productS1toTorus the same way theta1 and
    # theta2 do above; assumed to be an (n_grid, 2) array of angles -- confirm.
    x_grid, y_grid, z_grid = utils.productS1toTorus(self.tp[:, 0], self.tp[:, 1])
    grid = jnp.stack((x_grid, y_grid, z_grid), 1)
    probas_grid = estimated_density(grid.T)

    fig = plt.figure()
    try:
        # Axes3D(fig) no longer attaches itself to the figure (matplotlib >= 3.4),
        # which leaves the saved image empty; add_subplot registers it properly.
        ax = fig.add_subplot(projection='3d')
        # TODO: fix this - the coordinates are negated because the mode of the
        # unimodal density otherwise ends up at the bottom of the torus.
        ax.scatter(-x_grid, -y_grid, -z_grid, alpha=0.2, c=probas_grid)
        ax.set_xlim(-1, 1)
        ax.set_ylim(-1, 1)
        ax.set_zlim(-1, 1)
        # Hide the axes on this Axes object directly; plt.axis acts on gca(),
        # which is not guaranteed to be the 3-D axes created here.
        ax.set_axis_off()
        fig.savefig(save)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)