in flows.py [0:0]
    def __call__(self, xs, t=1):
        """Push a batch of points along the Riemannian gradient flow of the potential.

        Each point ``x`` (a row of ``xs``) is mapped to
        ``z = manifold.exponential_map(x, t * grad_R F(x))``. Here ``grad_R F`` is
        the Euclidean derivative of ``self.potential`` (via ``jax.jacfwd``),
        projected with ``manifold.tangent_projection``. Alongside ``z`` it
        returns ``0.5 * log|det((J E)^T (J E))|``. ``J`` is the Jacobian of the
        map at ``x``. ``E`` is the basis returned by
        ``manifold.tangent_orthonormal_basis``. This presumably serves as the
        log volume change of the flow restricted to the tangent space; confirm
        against the caller.

        Args:
            xs: 2-D array whose rows are the input points (asserted ``ndim == 2``).
            t: Step size multiplying the Riemannian gradient before the
                exponential map.

        Returns:
            A tuple ``(z, logdet, sign)``:
              * ``z``: the mapped points, one row per input row.
              * ``logdet``: half the log-abs-determinant of ``(J E)^T (J E)``
                per point.
              * ``sign``: the sign from ``jnp.linalg.slogdet`` of that matrix.

        Raises:
            AssertionError: if ``xs`` is not 2-D.
        """
        assert xs.ndim == 2

        def riemannian_grad(x):
            # Euclidean derivative of the potential, projected onto the tangent
            # space at x.
            assert x.ndim == 1
            euclidean = jax.jacfwd(self.potential)(x)
            return self.manifold.tangent_projection(x, euclidean)

        def flow_with_aux(x):
            # Differentiated output is z; (z, dF) are also returned as aux so a
            # single jacfwd pass yields the Jacobian, the image and the gradient
            # without re-evaluating the (already differentiated) potential.
            assert x.ndim == 1
            dF = riemannian_grad(x)
            z = self.manifold.exponential_map(x, t * dF)
            return z, (z, dF)

        def flow_and_jac(x):
            J, (z, dF) = jax.jacfwd(flow_with_aux, has_aux=True)(x)
            return z, dF, J

        z, dF, J = jax.vmap(flow_and_jac)(xs)

        # Orthonormal tangent basis at each point. Presumably shaped
        # (batch, ambient_dim, intrinsic_dim); TODO confirm against the manifold.
        E = self.manifold.tangent_orthonormal_basis(xs, dF)
        JE = jnp.matmul(J, E)
        # Per-point Gram matrix (J E)^T (J E).
        JETJE = jnp.einsum('nji,njk->nik', JE, JE)
        sign, logdet = jnp.linalg.slogdet(JETJE)
        # log sqrt(det(G)) = 0.5 * log det(G).
        logdet = 0.5 * logdet
        return z, logdet, sign