in densities.py [0:0]
def __post_init__(self):
    """Initialise ``self.modes`` and build the component distributions.

    Presumably runs as the dataclass post-init hook; the enclosing class
    is not visible from this block, so confirm against its definition.

    Side effects:
        * ``self.modes`` is reset to an empty list.
        * ``self.dists`` becomes a list of four ``WrappedNormal`` objects,
          one per fixed centre below, all sharing the same ``scale`` array.
    """
    self.modes = []

    # Fixed centres: first coordinate is 0.3, the remaining two enumerate
    # the four sign combinations of +/-1.
    raw_coords = (
        (0.3, 1., 1.),
        (0.3, -1., 1.),
        (0.3, 1., -1.),
        (0.3, -1., -1.),
    )
    raw_points = [jnp.array(coords) for coords in raw_coords]

    # Pass every centre through the manifold's ``projx`` before use
    # (presumably a projection onto the manifold -- verify against the
    # manifold implementation).
    centres = []
    for point in raw_points:
        centres.append(self.manifold.projx(point))

    # One scale vector of length ``manifold.D - 1`` filled with 0.3,
    # reused by every component.
    shared_scale = jnp.full(self.manifold.D - 1, 0.3)

    components = []
    for centre in centres:
        components.append(
            WrappedNormal(
                manifold=self.manifold,
                loc=centre,
                scale=shared_scale,
            )
        )
    self.dists = components