in flows.py [0:0]
def __call__(self, orig_xs, debug=False, t=1):
    """Push ``orig_xs`` through every transform in ``self.transforms`` in order.

    Each transform is called as ``transform(x, t=t)`` and must return a
    3-tuple ``(new_x, ldj, ldj_sign)``. The ``ldj`` values of all
    transforms are summed. (Presumably ``ldj`` is a log-|det Jacobian| term,
    judging by the name -- confirm against the transform classes.)

    Args:
        orig_xs: Input fed to the first transform. It is also passed to every
            transform's ``potential`` when ``debug`` is true.
        debug: If true, record every intermediate result and return the
            extended tuple described below.
        t: Forwarded unchanged, as the keyword ``t``, to every transform call.

    Returns:
        When ``debug`` is false: ``(final_xs, summed_ldj)``.
        When ``debug`` is true: ``(stacked_xs, stacked_ldjs, stacked_signs,
        potentials, summed_ldj)``. The first three are ``jnp.stack`` of the
        per-transform outputs. ``potentials`` is a plain list of
        ``transform.potential(orig_xs)`` values, one per transform.

        With no transforms, the non-debug result is ``(orig_xs, 0.)``.
    """
    summed_ldj = 0.
    trace_xs, trace_ldjs, trace_signs, potentials = [], [], [], []

    current = orig_xs
    for layer in self.transforms:
        current, layer_ldj, layer_sign = layer(current, t=t)
        if debug:
            # Every layer's potential is evaluated on the ORIGINAL input,
            # not on this layer's own input, so that all potentials are
            # compared over the same points.
            potentials.append(layer.potential(orig_xs))
            trace_xs.append(current)
            trace_ldjs.append(layer_ldj)
            trace_signs.append(layer_sign)
        # Kept as augmented assignment so accumulation semantics (e.g. for
        # in-place-capable array types) match the original exactly.
        summed_ldj += layer_ldj

    if not debug:
        return current, summed_ldj

    return (
        jnp.stack(trace_xs),
        jnp.stack(trace_ldjs),
        jnp.stack(trace_signs),
        potentials,
        summed_ldj,
    )