in main.py [0:0]
def train_kl(self):
    """Train the flow by stochastic minimization of the reverse KL divergence.

    Each iteration draws a fresh batch from the base distribution, pushes it
    through the flow, and takes one gradient step on
    ``E_base[log q(x) - log p(x)]`` where ``log q`` is obtained via the
    change-of-variables formula (base log-prob minus log-det-Jacobian).
    Every ``cfg.log_frequency`` iterations it evaluates on the held-out
    ``self.base_samples``, plots samples, logs loss/KL/ESS to CSV, and
    checkpoints via ``self.save('latest')``.
    """
    @jax.jit
    def kl_loss(params, base_samples, base_log_probs):
        # z are the pushed-forward samples; ldjs the log-det-Jacobians.
        z, ldjs = self.flow.apply(params, base_samples)
        # Reverse-KL Monte-Carlo estimate:
        #   log q(z) = base_log_prob - ldj  (change of variables),
        #   loss     = E[log q(z) - log p(z)].
        return (base_log_probs - ldjs - self.target.log_prob(z)).mean()

    @jax.jit
    def update(optimizer, base_samples, base_log_probs):
        # One optimization step on the KL loss.
        loss_val, grads = jax.value_and_grad(kl_loss)(
            optimizer.target, base_samples, base_log_probs)
        optimizer = optimizer.apply_gradient(grads)
        return loss_val, optimizer

    logf, writer = self._init_logging()
    times = []
    if self.iter == 0:
        # Plot the initial (untrained) model samples and the target density
        # once, before any training.
        model_samples, ldjs = self.flow.apply(
            self.optimizer.target, self.base_samples)
        self.manifold.plot_samples(
            model_samples, save=f'{self.iter:06d}.png')
        self.manifold.plot_density(self.target.log_prob, 'target.png')

    while self.iter < self.cfg.iterations:
        start = time.time()
        self.key, subkey = jax.random.split(self.key)
        base_samples = self.base.sample(subkey, self.cfg.batch_size)
        base_log_probs = self.base.log_prob(base_samples)
        loss_val, self.optimizer = update(
            self.optimizer, base_samples, base_log_probs)
        times.append(time.time() - start)
        self.iter += 1

        if self.iter % self.cfg.log_frequency == 0:
            # Re-evaluate loss on the fixed held-out evaluation samples.
            loss_val = kl_loss(self.optimizer.target,
                               self.base_samples, self.base_log_probs)
            model_samples, ldjs = self.flow.apply(
                self.optimizer.target, self.base_samples)
            self.manifold.plot_samples(
                model_samples, save=f'{self.iter:06d}.png')
            if not self.cfg.disable_evol_plots:
                # Visualize the flow at intermediate "times" t in (0, 1].
                # NOTE(review): loop-local names here so the KL/ESS metrics
                # below use the full-flow samples/ldjs computed above; the
                # previous code clobbered model_samples/ldjs, making the
                # reported metrics depend on this plotting flag.
                for i, t in enumerate(jnp.linspace(0.1, 1, 11)):
                    evol_samples, _ = self.flow.apply(
                        self.optimizer.target, self.base_samples, t=t)
                    self.manifold.plot_samples(
                        evol_samples,
                        save=f'{self.iter:06d}_{i}.png')
            # Model log-probs via change of variables; KL and effective
            # sample size measured against the target.
            log_prob = self.base_log_probs - ldjs
            _, kl, ess = kl_ess(
                log_prob, self.target.log_prob(model_samples))
            ess = ess / self.cfg.eval_samples * 100  # as a percentage
            msg = "Iter {} | Loss {:.3f} | KL {:.3f} | ESS {:.2f}% | {:.2e}s/it"
            print(msg.format(
                self.iter, loss_val, kl, ess, jnp.mean(jnp.array(times))))
            writer.writerow({
                'iter': self.iter, 'loss': loss_val, 'kl': kl, 'ess': ess
            })
            logf.flush()
            self.save('latest')
            times = []