def __call__()

in flows.py [0:0]


    def __call__(self, xs):
        """Evaluate the potential at one point or at a batch of points.

        For every input point, the costs to each (projected) centre in
        ``self.mus`` are computed with ``self.manifold.cost`` and shifted by
        ``self.alphas``. They are then reduced along axis 1:

        * hard reduction: ``F = -min(costs, axis=1)``;
        * soft reduction (``cost_gamma > 0``):
          ``F = cost_gamma * logsumexp(-costs / cost_gamma, axis=1)``.
          This tends to the hard value as ``cost_gamma -> 0``.

        If ``min_zero_gamma > 0``, the result is then replaced by
        ``min_zero_gamma * logsumexp(-[F, 0] / min_zero_gamma)``, a smooth
        approximation of ``max(-F, 0)``.

        Args:
            xs: A single point (1-D array) or a batch of points (2-D array,
                one point per row). The last dimension must equal
                ``self.manifold.D``.

        Returns:
            Array with one value per row of ``xs``. When ``xs`` is 1-D, the
            batch axis is squeezed away and a 0-d array is returned.

        Raises:
            AssertionError: If ``xs`` is not 1-D or 2-D, or if its last
                dimension does not match ``self.manifold.D``.
        """
        single = xs.ndim == 1
        if single:
            # Promote a lone point to a batch of one; undone before returning.
            xs = jnp.expand_dims(xs, 0)

        assert xs.ndim == 2, f"expected a 1-D or 2-D input, got ndim={xs.ndim}"
        assert xs.shape[1] == self.manifold.D, (
            f"expected last dimension {self.manifold.D}, got {xs.shape[1]}")

        # projx is applied to the transposed centres, and the result is
        # transposed back to the layout of self.mus.
        # NOTE(review): presumably projx wants one centre per row — confirm.
        mus = self.manifold.projx(self.mus.T).T

        costs = self.manifold.cost(xs, mus) + self.alphas

        if self.cost_gamma is not None and self.cost_gamma > 0.:
            # Smooth stand-in for -min(costs, axis=1).
            F = self.cost_gamma * logsumexp(-costs / self.cost_gamma, axis=1)
        else:
            F = -jnp.min(costs, 1)

        if self.min_zero_gamma is not None and self.min_zero_gamma > 0.:
            # gamma * logsumexp(-[F, 0] / gamma) smoothly approximates
            # max(-F, 0). NOTE(review): the option name suggests a soft
            # min(F, 0), which would be -gamma * logsumexp(-[F, 0] / gamma).
            # The sign is kept as in the original; confirm which is intended.
            Fz = jnp.stack((F, jnp.zeros_like(F)), axis=-1)
            F = self.min_zero_gamma * logsumexp(
                -Fz / self.min_zero_gamma, axis=-1)

        if single:
            F = jnp.squeeze(F, 0)
        return F