def __call__()

in flows.py [0:0]


    def __call__(self, xs):
        """Evaluate the mixture-of-exponentials function at ``xs``.

        Computes

            F(x) = sum_k (alpha_k / beta_k) * exp(beta_k * (<x, mu_k> - 1))

        where ``alpha = softmax(self.alphas)``, ``beta = softplus(self.betas)``
        and ``mu_k`` is the k-th column of ``self.mus`` rescaled to unit
        Euclidean norm.

        Args:
            xs: Array whose last axis has length ``self.manifold.D``. A single
                point of shape ``(D,)`` yields a scalar; a batch of shape
                ``(..., D)`` yields an array of shape ``(...)``. Presumably the
                points lie on the unit sphere -- TODO confirm against the
                manifold definition.

        Returns:
            ``F`` evaluated at every point, with the trailing ``D`` axis
            reduced away.

        Raises:
            ValueError: if ``xs`` is 0-dimensional or its last axis does not
                have length ``self.manifold.D``.
        """
        xs = jnp.asarray(xs)
        # Explicit check rather than ``assert`` so validation is not stripped
        # under ``python -O``; shapes are static under ``jax.jit``, so raising
        # here is trace-safe.
        if xs.ndim < 1 or xs.shape[-1] != self.manifold.D:
            raise ValueError(
                f"expected xs with shape (..., {self.manifold.D}), "
                f"got {xs.shape}"
            )

        # softplus keeps the concentrations positive so alphas / betas is defined.
        betas = nn.softplus(self.betas)
        # Rescale every column (axis 0) of mus to unit length.
        mus = self.mus / jnp.linalg.norm(self.mus, axis=0, keepdims=True)
        # Mixture weights; softmax makes them non-negative and sum to 1.
        alphas = nn.softmax(self.alphas)

        # (..., D) @ (D, K) -> (..., K). matmul broadcasts over any leading
        # batch axes and a 1-D input gives (K,), so no expand/squeeze is needed.
        dots = jnp.matmul(xs, mus)
        return jnp.sum((alphas / betas) * jnp.exp(betas * (dots - 1)), axis=-1)