in flows.py [0:0]
def __call__(self, xs):
    """Return the negated (optionally smoothed) minimum of shifted costs.

    Computes ``-min(self.manifold.cost(xs, mus) + self.alphas, axis=1)``,
    where ``mus`` are the projected centers. When ``self.cost_gamma > 0``, a
    log-sum-exp surrogate of that hard minimum is used instead. When
    ``self.min_zero_gamma > 0``, the result is then soft-combined with zero.

    Args:
        xs: A single point of shape ``(D,)`` or a batch of shape
            ``(n_batch, D)``, with ``D == self.manifold.D``.

    Returns:
        One value per row of ``xs`` (shape ``(n_batch,)``, assuming
        ``self.manifold.cost`` returns a 2-D ``(n_batch, n_centers)`` array),
        or a 0-d array when ``xs`` was a single 1-D point.

    Raises:
        AssertionError: if ``xs`` is not 1-D or 2-D, or if its second axis
            does not have length ``self.manifold.D``.
    """
    # Promote a single point to a batch of one so the rest is batch-only.
    single = xs.ndim == 1
    if single:
        xs = jnp.expand_dims(xs, 0)
    assert xs.ndim == 2, f"expected xs with 1 or 2 dims, got shape {xs.shape}"
    assert xs.shape[1] == self.manifold.D, (
        f"expected xs.shape[1] == {self.manifold.D}, got shape {xs.shape}")

    # projx is applied to the transpose and the result transposed back.
    # NOTE(review): presumably self.mus stores one center per column and
    # projx works row-wise -- TODO confirm against the constructor/manifold.
    mus = self.manifold.projx(self.mus.T).T

    # Reduced over axis 1 below, so presumably (n_batch, n_centers) with
    # self.alphas broadcasting as one offset per center -- confirm.
    costs = self.manifold.cost(xs, mus) + self.alphas

    gamma = self.cost_gamma
    if gamma is not None and gamma > 0.:
        # gamma * logsumexp(-c / gamma) -> max(-c) == -min(c) as gamma -> 0+,
        # so this is a smooth surrogate of the hard branch below.
        F = gamma * logsumexp(-costs / gamma, axis=1)
    else:
        F = -jnp.min(costs, axis=1)

    zgamma = self.min_zero_gamma
    if zgamma is not None and zgamma > 0.:
        # Soft combination of F with 0 along a new trailing axis.
        # NOTE(review): zgamma * logsumexp(-[F, 0] / zgamma) -> max(-F, 0) as
        # zgamma -> 0+, i.e. the negation of a soft min(F, 0). That flips the
        # sign of F relative to the branch above -- confirm this is intended.
        Fz = jnp.stack((F, jnp.zeros_like(F)), axis=-1)
        F = zgamma * logsumexp(-Fz / zgamma, axis=-1)

    # Undo the batch-of-one promotion for single-point input.
    if single:
        F = jnp.squeeze(F, 0)
    return F