in main.py [0:0]
def train_likelihood(self):
    """Fit the flow to target samples by maximum likelihood.

    The per-sample log-density is ``ldjs + self.base.log_prob(zs)`` with
    ``zs, ldjs = self.flow.apply(params, x, t=t)``; the training loss is the
    negative mean of that quantity over a batch. Training runs until
    ``self.iter`` reaches ``self.cfg.iterations``, drawing a fresh batch of
    ``self.cfg.batch_size`` target samples each step and advancing
    ``self.key`` / ``self.optimizer`` / ``self.iter`` in place.

    Every ``self.cfg.log_frequency`` iterations the loss on
    ``self.eval_target_samples`` is printed and written as a row via the
    ``writer`` returned by ``self._init_logging()``, sample/density plots are
    saved, and ``self.save('latest')`` is called. When starting at iteration 0
    (and ``self.cfg.disable_init_plots`` is false) the target, the base and
    the initial model are plotted first.
    """
    # Number of t-values in the density "evolution" plots. Shared by the
    # initial and periodic plots so that density_<iter>_<i>.png always
    # corresponds to the same t (0.1, 0.2, ..., 1.0).
    n_evol_frames = 10

    @jax.jit
    def logprob(params, target_samples, t=1):
        # Change of variables. NOTE(review): assumes flow.apply maps data
        # into base space and returns per-sample log|det J| — confirm.
        zs, ldjs = self.flow.apply(params, target_samples, t=t)
        return ldjs + self.base.log_prob(zs)

    @jax.jit
    def loss(params, target_samples):
        return -logprob(params, target_samples).mean()

    @jax.jit
    def update(optimizer, target_samples):
        l, grads = jax.value_and_grad(loss)(
            optimizer.target, target_samples)
        optimizer = optimizer.apply_gradient(grads)
        return l, optimizer

    def plot_model():
        """Save model sample/density plots tagged with the current iter."""
        tag = f'{self.iter:06d}'
        params = self.optimizer.target
        model_samples, _ = self.flow.apply(params, self.eval_target_samples)
        self.manifold.plot_samples(model_samples, save=f'samples_{tag}.png')
        self.manifold.plot_density(
            functools.partial(logprob, params), save=f'density_{tag}.png')
        if not self.cfg.disable_evol_plots:
            ts = jnp.linspace(0.1, 1, n_evol_frames)
            for i, t in enumerate(ts):
                self.manifold.plot_density(
                    functools.partial(logprob, params, t=t),
                    save=f'density_{tag}_{i}.png')

    # The sample count (argument 1) determines output shapes, so it must be
    # static under jit.
    target_sample_jit = jax.jit(self.target.sample, static_argnums=(1,))
    base_sample_jit = jax.jit(self.base.sample, static_argnums=(1,))
    logf, writer = self._init_logging()
    times = []

    if self.iter == 0 and not self.cfg.disable_init_plots:
        try:
            self.manifold.plot_density(
                self.target.log_prob, save=f'target_density.png')
        except Exception as exc:
            # Best-effort: presumably not every target exposes a plottable
            # log_prob. Report the skip instead of hiding it.
            print(f'Skipping target density plot: {exc!r}')
        self.manifold.plot_samples(
            self.eval_target_samples, save=f'target_samples.png')
        # NOTE(review): reuses self.key without splitting; only affects
        # this plot, but splitting first would be the usual JAX practice.
        self.manifold.plot_samples(
            base_sample_jit(self.key, self.cfg.eval_samples),
            save=f'base_samples.png')
        self.manifold.plot_density(
            self.base.log_prob, save=f'base_density.png')
        plot_model()

    while self.iter < self.cfg.iterations:
        start = time.time()
        self.key, subkey = jax.random.split(self.key)
        target_samples = target_sample_jit(subkey, self.cfg.batch_size)
        l, self.optimizer = update(self.optimizer, target_samples)
        # JAX dispatches asynchronously; wait for the step to finish so the
        # measured time reflects compute, not just dispatch.
        l.block_until_ready()
        times.append(time.time() - start)
        self.iter += 1

        if self.iter % self.cfg.log_frequency == 0:
            l = loss(self.optimizer.target, self.eval_target_samples)
            plot_model()
            msg = "Iter {} | Loss {:.3f} | {:.2e}s/it"
            print(msg.format(self.iter, l, sum(times) / len(times)))
            writer.writerow({
                'iter': self.iter, 'loss': l,
            })
            logf.flush()
            self.save('latest')
            times = []