in flows.py [0:0]
def __call__(self, xs, t=1):
    """Push a batch of points along the projected gradient of the potential.

    For every row ``x`` of ``xs`` this computes

    * ``dF = manifold.tangent_projection(x, d potential / dx)``,
    * ``z  = manifold.exponential_map(x, t * dF)``,
    * ``J  = dz / dx`` (forward-mode Jacobian of the map ``x -> z``),

    and then ``0.5 * log|det((J E)^T (J E))|`` where ``E`` is whatever
    ``manifold.tangent_orthonormal_basis(xs, dF)`` returns.

    Args:
        xs: 2-D array; axis 0 is the batch axis that is vmapped over.
        t: Multiplier applied to the projected gradient before the
            exponential map. Defaults to 1.

    Returns:
        A tuple ``(z, logdet, sign)``:
        ``z`` is the batch of mapped points, ``logdet`` is half of the
        log-absolute-determinant of ``(J E)^T (J E)`` per batch element,
        and ``sign`` is the determinant sign reported by
        ``jnp.linalg.slogdet`` for that same matrix.
    """
    assert xs.ndim == 2

    def dF_riemannian(x):
        # Derivative of the potential at a single point, projected by the
        # manifold onto what is presumably the tangent space at x
        # (semantics of tangent_projection are defined elsewhere).
        assert x.ndim == 1
        dF = jax.jacfwd(self.potential)(x)
        return self.manifold.tangent_projection(x, dF)

    def flow_with_aux(x):
        # Returns the mapped point as the differentiated output and
        # (z, dF) as auxiliary data, so a single jacfwd trace yields the
        # Jacobian, the image point and the projected gradient together
        # instead of recomputing the gradient for each of them.
        assert x.ndim == 1
        dF = dF_riemannian(x)
        z = self.manifold.exponential_map(x, t * dF)
        return z, (z, dF)

    def flow_and_jac(x):
        J, (z, dF) = jax.jacfwd(flow_with_aux, has_aux=True)(x)
        return z, dF, J

    z, dF, J = jax.vmap(flow_and_jac)(xs)

    # NOTE(review): E is assumed to be a batched matrix whose columns form a
    # basis of the tangent space at each point -- confirm against the
    # manifold implementation.
    E = self.manifold.tangent_orthonormal_basis(xs, dF)
    JE = jnp.matmul(J, E)
    # Per-batch Gram matrix (J E)^T (J E).
    JETJE = jnp.einsum('nji,njk->nik', JE, JE)
    sign, logdet = jnp.linalg.slogdet(JETJE)
    # Half the Gram log-determinant, i.e. log sqrt|det((J E)^T (J E))|.
    logdet = 0.5 * logdet
    return z, logdet, sign