def __post_init__()

in densities.py [0:0]


    def __post_init__(self):
        """Build the fixed set of mixture components for this density.

        Presumably invoked by a dataclass-generated ``__init__`` (this is the
        ``dataclasses`` post-init hook) — confirm the enclosing class is a
        ``@dataclass``.

        Creates four ``WrappedNormal`` components on ``self.manifold``. Their
        locations are the points ``(1, 1, 1)``, ``(-1, -1, -1)``,
        ``(1, 1, -1)`` and ``(-1, -1, 1)``, each passed through
        ``self.manifold.projx``. Every component uses the same scale vector
        of length ``self.manifold.D - 1``, filled with ``0.3``.

        Sets:
            self.modes: reset to an empty list (its use is outside this
                method — presumably filled elsewhere; verify).
            self.dists: list of the four ``WrappedNormal`` components, in the
                location order given above.
        """
        self.modes = []
        one = jnp.ones(3)
        # `one` with its last coordinate negated. JAX arrays are immutable, so
        # `.at[...].set(...)` returns a new array and leaves `one` unchanged.
        # It replaces `jax.ops.index_update`, which JAX deprecated and later
        # removed; calling that on current releases raises AttributeError.
        oned = one.at[2].set(-1.)
        # Four raw locations arranged as two antipodal pairs: +/-one and +/-oned.
        locs = [one, -one, oned, -oned]
        # NOTE(review): projx presumably maps each raw point onto the manifold
        # (e.g. normalising onto a sphere) — confirm against the manifold class.
        locs = [self.manifold.projx(loc) for loc in locs]
        # The scale has D - 1 entries, one fewer than manifold.D; presumably
        # D is the ambient dimension — TODO confirm.
        scale = jnp.full(self.manifold.D - 1, .3)
        self.dists = [
            WrappedNormal(manifold=self.manifold, loc=loc, scale=scale)
            for loc in locs
        ]