def __call__()

in flows.py [0:0]


    def __call__(self, orig_xs, debug=False, t=1):
        """Run ``orig_xs`` through every transform in ``self.transforms``, in order.

        Each transform is invoked as ``transform(xs, t=t)`` and must return a
        3-tuple ``(new_xs, ldj, ldj_sign)``. Its ``new_xs`` becomes the input of
        the next transform, and every ``ldj`` is added to a running total that
        starts at the float ``0.``.

        Args:
            orig_xs: Input handed to the first transform.
            debug: When true, also record every intermediate output, ``ldj``,
                ``ldj_sign`` and ``transform.potential(orig_xs)`` value.
            t: Forwarded unchanged to each transform as the ``t`` keyword.

        Returns:
            With ``debug`` false: ``(xs, ldjs)``. ``xs`` is the output of the
            last transform (``orig_xs`` itself if there are none). ``ldjs`` is
            the summed ``ldj``. The per-layer ``ldj_sign`` values are discarded
            in this mode.

            With ``debug`` true: ``(all_xs, all_ldjs, all_ldj_signs, Fs, ldjs)``.
            The first three are ``jnp.stack``-ed over the transforms. ``Fs`` is
            a plain Python list of potentials. ``ldjs`` is the summed ``ldj``.
            Stacking requires at least one transform.
        """
        current = orig_xs
        total_ldj = 0.
        trace_xs, trace_ldj, trace_sign, potentials = [], [], [], []

        for layer in self.transforms:
            current, layer_ldj, layer_sign = layer(current, t=t)
            if debug:
                # NOTE(review): the potential is evaluated on the *original*
                # input, not on this layer's input. Confirm this is intended
                # (e.g. plotting every potential on one fixed grid).
                potential = layer.potential(orig_xs)
                trace_xs.append(current)
                trace_ldj.append(layer_ldj)
                trace_sign.append(layer_sign)
                potentials.append(potential)
            total_ldj += layer_ldj

        if not debug:
            return current, total_ldj

        return (
            jnp.stack(trace_xs),
            jnp.stack(trace_ldj),
            jnp.stack(trace_sign),
            potentials,
            total_ldj,
        )