def get()

in densities.py [0:0]


def _load_named_class(name, **kwargs):
    """Instantiate the class called ``name`` defined in this module.

    Looks ``name`` up on this module, calls it with ``kwargs`` and returns the
    result. On any failure (unknown name, bad constructor arguments, or an
    error raised inside the constructor) a diagnostic is printed and the
    original exception is re-raised unchanged.
    """
    try:
        return getattr(sys.modules[__name__], name)(**kwargs)
    except Exception:  # not bare: let KeyboardInterrupt/SystemExit pass through untouched
        print(f"Error loading data class {name}")
        raise


def get(manifold, name):
    """Build the density named ``name`` on ``manifold``.

    Args:
        manifold: The manifold the density lives on. For the sphere presets
            ('SphereBaseWrappedNormal', 'LouSphereSingleMode') it must be a
            ``Sphere``.
        name: One of
            * 'SphereBaseWrappedNormal' - WrappedNormal centred at
              ``manifold.zero()`` with per-dimension scale 0.3;
            * 'LouSphereSingleMode' - WrappedNormal centred at the projection
              of the all-minus-ones vector, per-dimension scale 0.3;
            * any name containing 'Earth', of the form '<ClassName>_<year>' -
              the module-level class ``ClassName`` is called with
              ``manifold=manifold, year=<year>`` (``year`` is passed as the
              string suffix, not converted to an int);
            * otherwise, the name of a module-level class, called with
              ``manifold=manifold``.

    Returns:
        The constructed density object.

    Raises:
        AssertionError: a sphere preset was requested on a non-``Sphere``.
        ValueError: an Earth name is not of the form '<ClassName>_<year>'.
        Exception: whatever the class lookup / constructor raises (e.g.
            ``AttributeError`` for an unknown class name) is re-raised after
            a diagnostic is printed.
    """
    if name in ('SphereBaseWrappedNormal', 'LouSphereSingleMode'):
        assert isinstance(manifold, Sphere), f"{name} requires a Sphere manifold"
        if name == 'SphereBaseWrappedNormal':
            loc = manifold.zero()
        else:
            loc = manifold.projx(-jnp.ones(manifold.D))
        # Tangent space of the sphere embedded in R^D has D - 1 dimensions.
        scale = jnp.full(manifold.D - 1, .3)
        return WrappedNormal(manifold=manifold, loc=loc, scale=scale)

    if 'Earth' in name:
        parts = name.split('_')
        if len(parts) != 2:
            print(f"Error loading data class {name}")
            raise ValueError(
                f"Earth dataset name must have the form '<ClassName>_<year>', got {name!r}"
            )
        cls_name, year = parts
        return _load_named_class(cls_name, manifold=manifold, year=year)

    return _load_named_class(name, manifold=manifold)