in densities.py [0:0]
def __post_init__(self):
    """Build one WrappedNormal component at each of four fixed locations.

    The locations are the 3-vectors (1, 1, 1), (-1, -1, -1), (1, 1, -1)
    and (-1, -1, 1), each projected onto ``self.manifold`` via ``projx``.
    Every component shares the same scale vector of length
    ``self.manifold.D - 1``, filled with 0.3.

    Sets:
        self.modes: reset to an empty list (not populated here; presumably
            filled or read elsewhere in the class -- TODO confirm).
        self.dists: list of four ``WrappedNormal`` distributions, in the
            location order listed above.
    """
    self.modes = []
    one = jnp.ones(3)
    # Functional indexed update: `jax.ops.index_update` / `jax.ops.index`
    # were removed in JAX 0.3.2, so use the `.at[...]` API. JAX arrays
    # are immutable, so `one` itself is not modified.
    oned = one.at[2].set(-1.0)
    locs = [one, -one, oned, -oned]
    locs = [self.manifold.projx(loc) for loc in locs]
    # NOTE(review): the hard-coded 3 for the location vectors presumes the
    # manifold's ambient dimension is 3 (i.e. manifold.D == 3) -- confirm.
    scale = jnp.full(self.manifold.D - 1, 0.3)
    self.dists = [
        WrappedNormal(manifold=self.manifold, loc=loc, scale=scale)
        for loc in locs
    ]