def __post_init__()

in densities.py [0:0]


    def __post_init__(self):
        """Populate the mixture components after dataclass initialisation.

        Resets ``self.modes`` to an empty list. It then builds four fixed
        centres and projects each onto ``self.manifold`` via ``projx``.
        Finally, ``self.dists`` is set to one ``WrappedNormal`` per projected
        centre, in the same order. All components share a single scale vector
        of length ``self.manifold.D - 1`` filled with 0.3.
        """
        self.modes = []
        manifold = self.manifold

        # First coordinate is fixed at 0.3; the remaining two run through
        # (+,+), (-,+), (+,-), (-,-) in that order.
        sign_pairs = ((1., 1.), (-1., 1.), (1., -1.), (-1., -1.))
        raw_centres = [jnp.array([0.3, y, z]) for y, z in sign_pairs]
        projected = [manifold.projx(centre) for centre in raw_centres]

        # One shared scale array; its length D - 1 presumably matches the
        # tangent-space dimension of the manifold — NOTE(review): confirm.
        shared_scale = jnp.full(manifold.D - 1, 0.3)

        self.dists = [
            WrappedNormal(manifold=manifold, loc=centre, scale=shared_scale)
            for centre in projected
        ]