def __call__()

in flows.py [0:0]


    def __call__(self, xs, t=1):
        """Map a batch of points through one gradient-flow step and its log-volume change.

        Each row ``x`` of ``xs`` is mapped to
        ``self.manifold.exponential_map(x, t * dF(x))``. Here ``dF(x)`` is the
        Euclidean derivative of ``self.potential`` at ``x`` (via ``jax.jacfwd``),
        passed through ``self.manifold.tangent_projection``.

        Args:
            xs: 2-D array with one point per row (batch along axis 0).
            t: Scalar multiplying the projected gradient before the exponential map.

        Returns:
            Tuple ``(z, logdet, sign)``:
              * ``z``: the mapped points, batched along axis 0 like ``xs``.
              * ``logdet``: per-point ``0.5 * log|det((J E)^T (J E))|``.
                ``J`` is the Jacobian of the map at ``x``. ``E`` is the matrix from
                ``self.manifold.tangent_orthonormal_basis``. Presumably this
                is the log-volume change of the map restricted to span(E);
                confirm E has orthonormal columns.
              * ``sign``: per-point sign from ``jnp.linalg.slogdet`` of that Gram
                matrix. ``(J E)^T (J E)`` is positive semi-definite, so this is
                expected to be 1, or 0 when the matrix is singular.
        """
        assert xs.ndim == 2

        def dF_riemannian(x):
            # Derivative of the potential at a single point, then projected by
            # the manifold (presumably onto the tangent space at x; verify).
            assert x.ndim == 1
            dF = jax.jacfwd(self.potential)(x)
            return self.manifold.tangent_projection(x, dF)

        def flow_from_grad(x, dF):
            # Shared step so callers that already hold dF need not recompute it.
            return self.manifold.exponential_map(x, t * dF)

        def flow(x):
            assert x.ndim == 1
            return flow_from_grad(x, dF_riemannian(x))

        def flow_and_jac(x):
            # Compute the projected gradient once and reuse it for both z and the
            # returned dF. The Jacobian still needs its own traced pass through
            # `flow`, because jacfwd differentiates the whole map.
            dF = dF_riemannian(x)
            z = flow_from_grad(x, dF)
            J = jax.jacfwd(flow)(x)
            return z, dF, J

        z, dF, J = jax.vmap(flow_and_jac)(xs)

        E = self.manifold.tangent_orthonormal_basis(xs, dF)
        JE = jnp.matmul(J, E)
        # Batched Gram matrix (J E)^T (J E): contract over the ambient axis j.
        gram = jnp.einsum('nji,njk->nik', JE, JE)

        sign, logdet = jnp.linalg.slogdet(gram)
        # sqrt(det(gram)) is the volume factor, so halve the log-determinant.
        logdet = 0.5 * logdet

        return z, logdet, sign