in flows.py [0:0]
def __call__(self, xs):
    """Evaluate the mixture-of-exponentials scalar function at ``xs``.

    For each point ``x`` this returns

        F(x) = sum_k (alpha_k / beta_k) * exp(beta_k * (<x, mu_k> - 1))

    with ``beta = softplus(self.betas)`` (strictly positive),
    ``alpha = softmax(self.alphas)`` and ``mu_k`` the k-th column of
    ``self.mus`` rescaled to unit Euclidean norm. ``self.mus`` is therefore
    (D, K), since it is matrix-multiplied against points of length D.
    NOTE(review): ``self.alphas`` / ``self.betas`` are presumably shape
    (K,), and ``xs`` presumably lies on the unit sphere. This looks like an
    exponential-map sphere-flow potential; confirm both assumptions.

    Args:
        xs: Array whose last axis has length ``self.manifold.D``. A 1-D
            array is a single point. Any leading axes are batch axes;
            previously only exactly one batch axis was accepted.

    Returns:
        Array of shape ``xs.shape[:-1]``, or a 0-d array for a single point.

    Raises:
        ValueError: If ``xs`` is 0-d or its last axis does not have length
            ``self.manifold.D``.
    """
    # These are explicit raises rather than ``assert`` so the checks survive
    # ``python -O``. ndim/shape are static Python values while JAX traces,
    # so raising here is safe inside jit/vmap.
    if xs.ndim < 1:
        raise ValueError(
            f"xs must have at least one axis, got shape {xs.shape}"
        )
    if xs.shape[-1] != self.manifold.D:
        raise ValueError(
            f"xs last axis must have length {self.manifold.D}, "
            f"got shape {xs.shape}"
        )
    single = xs.ndim == 1
    if single:
        # Give a lone point a batch axis so parameter broadcasting matches
        # the batched path exactly. The axis is removed again below.
        xs = jnp.expand_dims(xs, 0)
    betas = nn.softplus(self.betas)
    # Unit-normalise each column of mus. Axis 0 is the D axis.
    mus = self.mus / jnp.linalg.norm(self.mus, axis=0, keepdims=True)
    alphas = nn.softmax(self.alphas)
    # (..., D) @ (D, K) -> (..., K): one inner product per point and component.
    exponents = betas * (jnp.matmul(xs, mus) - 1)
    F = jnp.sum((alphas / betas) * jnp.exp(exponents), axis=-1)
    if single:
        F = jnp.squeeze(F, 0)
    return F