in flows.py [0:0]
def __call__(self, xs):
    """Evaluate the layered potential at one point or a batch of points.

    Each layer ``(mu, alpha, w)`` drawn in lockstep from ``self.mus``,
    ``self.alphas`` and ``self.ws`` contributes the negated (optionally
    smoothed) minimum over axis 1 of ``self.manifold.cost(xs, mu) + alpha``.
    Layers are combined recursively as::

        F <- w * relu(F) + (1 - w) * mincosts

    where the mixing weight is ``exp(-w**2)[0]`` and therefore lies in
    ``(0, 1]`` for real ``w``.

    Args:
        xs: Array of ndim 1 (a single point) or ndim 2 (a batch of points).
            After promotion to 2-D its second axis must have length
            ``self.manifold.D``.

    Returns:
        The accumulated value ``F``. When ``xs`` was 1-D, the leading batch
        axis is squeezed away again before returning.

    Raises:
        AssertionError: If ``xs`` is neither 1-D nor 2-D, or if its second
            axis does not match ``self.manifold.D``.
    """
    single = xs.ndim == 1
    if single:
        # Promote a lone point to a batch of one; undone before returning.
        xs = jnp.expand_dims(xs, 0)
    assert xs.ndim == 2, (
        f"xs must be 1-D or 2-D, got an array with ndim={xs.ndim}")
    assert xs.shape[1] == self.manifold.D, (
        f"xs has {xs.shape[1]} columns, expected {self.manifold.D}")

    # The smoothing choice does not depend on the layer, so decide it once.
    cost_gamma = self.cost_gamma
    smooth_min = cost_gamma is not None and cost_gamma > 0.

    F = 0.
    for mu, alpha, w in zip(self.mus, self.alphas, self.ws):
        # projx is applied to the transposed parameter and the result is
        # transposed back (presumably a projection onto the manifold --
        # confirm against the manifold implementation).
        mu = self.manifold.projx(mu.T).T
        costs = self.manifold.cost(xs, mu) + alpha
        w = jnp.exp(-w**2)[0]
        if smooth_min:
            # gamma * logsumexp(-c / gamma) is a smooth stand-in for
            # max(-c) == -min(c), matching the hard branch below.
            mincosts = cost_gamma * logsumexp(-costs / cost_gamma, axis=1)
        else:
            mincosts = -jnp.min(costs, 1)
        F = w * nn.relu(F) + (1 - w) * mincosts

    if self.min_zero_gamma is not None and self.min_zero_gamma > 0.:
        # Soft comparison of F against zero:
        #   gamma * logsumexp([-F, 0] / gamma) ~= max(-F, 0).
        # NOTE(review): the result is non-negative and sign-flipped relative
        # to F; confirm this is intended given the "min_zero" name.
        Fz = jnp.stack((F, jnp.zeros_like(F)), axis=-1)
        F = self.min_zero_gamma * logsumexp(
            -Fz / self.min_zero_gamma, axis=-1)

    if single:
        F = jnp.squeeze(F, 0)

    return F