def __call__()

in flows.py [0:0]


    def __call__(self, xs):
        """Evaluate the layered potential at the point(s) ``xs``.

        Each layer ``(mu, alpha, w)`` contributes
        ``layer(x) = -min_j(cost(x, mu_j) + alpha_j)``, optionally smoothed
        with temperature ``self.cost_gamma``. Layers are blended as::

            F <- w * relu(F) + (1 - w) * layer(x),   w = exp(-w_param**2)[0]

        so each mixing weight lies in (0, 1]. If ``self.min_zero_gamma`` is
        positive, the result is finally passed through a smooth
        ``min(F, 0)``.

        Args:
            xs: a single point (1-D) or a batch of points (2-D) whose last
                dimension equals ``self.manifold.D``.

        Returns:
            One potential value per point. This assumes
            ``self.manifold.cost(xs, mu)`` returns an ``(N, K)`` matrix of
            point-to-center costs (TODO confirm against the manifold
            implementations). A 1-D input returns a scalar.

        Raises:
            AssertionError: if ``xs`` is not 1-D/2-D, or its last dimension
                does not match ``self.manifold.D``.
        """
        single = xs.ndim == 1
        if single:
            xs = jnp.expand_dims(xs, 0)

        assert xs.ndim == 2, f"xs must be 1-D or 2-D, got ndim={xs.ndim}"
        assert xs.shape[1] == self.manifold.D, (
            f"xs has dimension {xs.shape[1]}, expected {self.manifold.D}")

        # Per-point zero vector rather than a Python scalar. With no layers
        # the output still has one entry per point, and the single-point
        # squeeze below stays valid. relu(0) == 0, so the first layer's
        # value is unaffected.
        F = jnp.zeros(xs.shape[0], dtype=xs.dtype)
        for mu, alpha, w in zip(self.mus, self.alphas, self.ws):
            # Round-trip through the manifold projection. The transposes
            # mirror the layout projx expects (presumably points as columns
            # there; TODO confirm).
            mu = self.manifold.projx(mu.T).T

            costs = self.manifold.cost(xs, mu) + alpha

            # Squash the raw mixing parameter into (0, 1].
            w = jnp.exp(-w**2)[0]

            if self.cost_gamma is not None and self.cost_gamma > 0.:
                # gamma * logsumexp(-c / gamma) is a smooth upper bound of
                # -min(c) that tends to it as gamma -> 0, matching the hard
                # branch below.
                mincosts = self.cost_gamma * logsumexp(
                    -costs / self.cost_gamma, axis=1)
            else:
                mincosts = -jnp.min(costs, 1)

            F = w * nn.relu(F) + (1 - w) * mincosts

        if self.min_zero_gamma is not None and self.min_zero_gamma > 0.:
            # Smooth min(F, 0): softmin_g(a) = -g * logsumexp(-a / g), which
            # tends to min(F, 0) as g -> 0. The leading minus is required.
            # Without it the expression tends to -min(F, 0) and flips the
            # sign of the potential.
            Fz = jnp.stack((F, jnp.zeros_like(F)), axis=-1)
            F = -self.min_zero_gamma * logsumexp(
                -Fz / self.min_zero_gamma, axis=-1)

        if single:
            F = jnp.squeeze(F, 0)
        return F