def log_prob()

in densities.py [0:0]


    def log_prob(self, xs):
        """Return the log-density of the uniform distribution on the unit sphere.

        The density is constant and equal to ``1 / |S^{D-1}|``. The surface
        area of the unit sphere in ``R^D`` is
        ``|S^{D-1}| = 2 * pi**(D/2) / Gamma(D/2)``, which gives 2*pi for D=2
        and 4*pi for D=3.

        NOTE(review): this assumes ``self.manifold.D`` is the ambient
        dimension. That matches the previously hard-coded cases
        (D=2 -> 2*pi, D=3 -> 4*pi).

        Args:
            xs: 2-D array whose second axis has size ``self.manifold.D``. Only
                its shape is used; the point values are not read or checked.

        Returns:
            ``jnp`` array of shape ``(n_batch,)`` filled with ``-log|S^{D-1}|``.

        Raises:
            AssertionError: if ``xs`` is not 2-D or its second axis does not
                equal ``self.manifold.D``.
            NotImplementedError: if ``D < 2``, because S^0 is a discrete
                two-point set with no continuous uniform density.
        """
        # Import locally so the method does not depend on a module-level import.
        import math

        assert xs.ndim == 2
        n_batch, D = xs.shape
        assert D == self.manifold.D

        if D < 2:
            raise NotImplementedError()

        # Compute log(2 * pi^(D/2) / Gamma(D/2)) in log space so it stays
        # finite for large D; Gamma(D/2) alone would overflow.
        log_sa = math.log(2.0) + 0.5 * D * math.log(math.pi) - math.lgamma(0.5 * D)
        return jnp.full([n_batch], -log_sa)