in densities.py [0:0]
def log_prob(self, xs):
    """Return the log-density of the uniform distribution on the unit sphere.

    The manifold is treated as the unit sphere S^{D-1} embedded in R^D, with
    ``D = self.manifold.D``. D=2 is the circle (length 2*pi) and D=3 is the
    2-sphere (area 4*pi). The uniform density is the reciprocal of the
    sphere's surface area,

        |S^{D-1}| = 2 * pi^(D/2) / Gamma(D/2),

    so every input point receives the same value ``-log|S^{D-1}|``.

    Args:
        xs: 2-D array of shape ``(n_batch, D)``. Presumably these are points
            on the sphere; only the shape is used, and the values are not
            checked.

    Returns:
        Array of shape ``(n_batch,)`` filled with the constant log-density.

    Raises:
        AssertionError: if ``xs`` is not 2-D, or its second dimension differs
            from ``self.manifold.D``.
        NotImplementedError: if ``D < 1`` (no sphere to define a density on).
    """
    # Local import so the file's top-level imports need not change. D is a
    # static Python int (it comes from xs.shape), so scalar math suffices.
    import math

    assert xs.ndim == 2, f"expected xs to be 2-D, got ndim={xs.ndim}"
    n_batch, D = xs.shape
    assert D == self.manifold.D, (
        f"expected xs.shape[1] == manifold.D == {self.manifold.D}, got {D}"
    )
    if D < 1:
        raise NotImplementedError()

    # log|S^{D-1}| = log 2 + (D/2) log(pi) - log Gamma(D/2).
    # Evaluated in log space so large D does not overflow.
    log_surface_area = math.log(2.0) + 0.5 * D * math.log(math.pi) - math.lgamma(0.5 * D)
    return jnp.full([n_batch], -log_surface_area)