def plot_density()

in manifolds.py [0:0]


    def plot_density(self, log_prob_fn, save='t.png'):
        """Scatter-plot the density given by ``log_prob_fn`` on a 3-D torus and save it.

        The columns ``self.tp[:, 0]`` and ``self.tp[:, 1]`` are treated as
        angles. Each one is embedded on the unit circle as (cos, sin), and the
        two embeddings are concatenated into an (N, 4) array that is passed to
        ``log_prob_fn``. The exponentiated output colours a scatter plot of the
        same points mapped to 3-D by ``utils.productS1toTorus``.

        NOTE(review): presumably ``self.tp`` is an (N, 2) grid of angles on
        S^1 x S^1 -- confirm against where ``tp`` is built.

        Args:
            log_prob_fn: callable applied to the (N, 4) concatenated circle
                embedding; its output is exponentiated and used as the scatter
                colour (assumed one log-density per point -- verify with callers).
            save: path of the image file to write.
        """
        theta, phi = self.tp[:, 0], self.tp[:, 1]
        euc1 = jnp.stack((jnp.cos(theta), jnp.sin(theta)), 1)
        euc2 = jnp.stack((jnp.cos(phi), jnp.sin(phi)), 1)
        prod_euc = jnp.concatenate((euc1, euc2), 1)

        density = jnp.exp(log_prob_fn(prod_euc))

        x_grid, y_grid, z_grid = utils.productS1toTorus(theta, phi)

        fig = plt.figure()
        # add_subplot(projection='3d') replaces the deprecated Axes3D(fig);
        # since Matplotlib 3.6 Axes3D(fig) no longer attaches itself to the
        # figure, which leaves the saved image blank.
        ax = fig.add_subplot(projection='3d')
        # TODO: fix this - the coordinates are negated because the mode is at
        # the bottom of the torus in the unimodal density.
        ax.scatter(-x_grid, -y_grid, -z_grid, alpha=0.2, c=density)
        ax.set_xlim(-1, 1)
        ax.set_ylim(-1, 1)
        ax.set_zlim(-1, 1)
        ax.set_axis_off()

        fig.savefig(save)
        # Release the figure: pyplot otherwise keeps every created figure
        # alive, which leaks memory when this is called repeatedly.
        plt.close(fig)