in manifolds.py [0:0]
def plot_density(self, log_prob_fn, save='t.png'):
    """Scatter-plot ``exp(log_prob_fn(...))`` over ``self.tp`` on a 3-D torus and save it.

    Columns 0 and 1 of ``self.tp`` are read as two angles (assumed shape
    ``(N, >=2)`` — TODO confirm). Each angle is mapped to the unit circle via
    ``(cos, sin)``. The two circle embeddings are concatenated into an
    ``(N, 4)`` array ``[cos t0, sin t0, cos t1, sin t1]`` and passed to
    ``log_prob_fn``. The exponentiated result colours a scatter of the same
    angles embedded in 3-D by ``utils.productS1toTorus``.

    Args:
        log_prob_fn: Callable mapping the ``(N, 4)`` product-of-circles
            embedding to log-densities (presumably shape ``(N,)`` — TODO
            confirm against callers).
        save: Output path handed to ``savefig``.
    """
    angles0 = self.tp[:, 0]
    angles1 = self.tp[:, 1]

    # Embed each angle on S^1 and concatenate into the product-space coordinates.
    euc1 = jnp.stack((jnp.cos(angles0), jnp.sin(angles0)), 1)
    euc2 = jnp.stack((jnp.cos(angles1), jnp.sin(angles1)), 1)
    prod_euc = jnp.concatenate((euc1, euc2), 1)
    density = jnp.exp(log_prob_fn(prod_euc))

    x_grid, y_grid, z_grid = utils.productS1toTorus(angles0, angles1)

    fig = plt.figure()
    # Axes3D(fig) adding itself to the figure was deprecated in matplotlib 3.4
    # and no longer happens from 3.6 (the saved image came out blank);
    # add_subplot with projection='3d' attaches the axes on all versions.
    ax = fig.add_subplot(111, projection='3d')
    # TODO: fix this - coordinates are negated because the mode is at the
    # bottom of the torus for a unimodal density.
    ax.scatter(-x_grid, -y_grid, -z_grid, alpha=0.2, c=density)
    ax.set_xlim(-1, 1)
    ax.set_ylim(-1, 1)
    ax.set_zlim(-1, 1)
    ax.set_axis_off()
    # Save once, after drawing (previously an empty figure was also saved first).
    fig.savefig(save)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)