def train_kl()

in main.py [0:0]


    def train_kl(self):
        """Train the flow by minimizing a reverse-KL objective against the target.

        Each iteration draws ``self.cfg.batch_size`` samples from ``self.base``,
        pushes them through ``self.flow`` and takes one optimizer step on the
        Monte-Carlo objective ``mean(log p_base(x) - ldj(x) - log p_target(f(x)))``.
        Training resumes from ``self.iter`` and runs until it reaches
        ``self.cfg.iterations``.

        Every ``self.cfg.log_frequency`` iterations, the fixed evaluation set
        ``self.base_samples`` / ``self.base_log_probs`` is used to:
        - plot the model samples;
        - optionally plot intermediate-time samples;
        - compute loss, KL and ESS via ``kl_ess``;
        - print and write a CSV row;
        - checkpoint with ``self.save('latest')``.
        """
        @jax.jit
        def loss(params, base_samples, base_log_probs):
            # Model log-density of a pushed-forward sample is
            # log p_base(x) - ldj(x); subtract the target log-density.
            z, ldjs = self.flow.apply(params, base_samples)
            return (base_log_probs - ldjs -
                    self.target.log_prob(z)).mean()

        @jax.jit
        def update(optimizer, base_samples, base_log_probs):
            l, grads = jax.value_and_grad(loss)(
                optimizer.target, base_samples, base_log_probs)
            optimizer = optimizer.apply_gradient(grads)
            return l, optimizer

        logf, writer = self._init_logging()

        times = []
        if self.iter == 0:
            # Plot the untrained model and the target once at the start.
            model_samples, _ = self.flow.apply(
                self.optimizer.target, self.base_samples)
            self.manifold.plot_samples(
                model_samples, save=f'{self.iter:06d}.png')

            self.manifold.plot_density(self.target.log_prob, 'target.png')

        while self.iter < self.cfg.iterations:
            start = time.time()
            self.key, subkey = jax.random.split(self.key)
            base_samples = self.base.sample(subkey, self.cfg.batch_size)
            base_log_probs = self.base.log_prob(base_samples)
            l, self.optimizer = update(
                self.optimizer, base_samples, base_log_probs)
            # JAX dispatches asynchronously: wait for the step to finish so
            # the recorded duration reflects compute, not dispatch latency.
            l.block_until_ready()

            times.append(time.time() - start)
            self.iter += 1
            if self.iter % self.cfg.log_frequency == 0:
                # Report the loss on the fixed evaluation set rather than
                # on the last training batch.
                l = loss(self.optimizer.target,
                         self.base_samples, self.base_log_probs)

                model_samples, ldjs = self.flow.apply(
                    self.optimizer.target, self.base_samples)
                self.manifold.plot_samples(
                    model_samples, save=f'{self.iter:06d}.png')
                if not self.cfg.disable_evol_plots:
                    # Separate names keep the intermediate-time outputs from
                    # overwriting `model_samples`/`ldjs`, which the KL/ESS
                    # evaluation below must compute from the default apply.
                    for i, t in enumerate(jnp.linspace(0.1, 1, 11)):
                        evol_samples, _ = self.flow.apply(
                            self.optimizer.target, self.base_samples, t=t)
                        self.manifold.plot_samples(
                            evol_samples,
                            save=f'{self.iter:06d}_{i}.png')

                log_prob = self.base_log_probs - ldjs
                _, kl, ess = kl_ess(
                    log_prob, self.target.log_prob(model_samples))
                # NOTE(review): assumes self.base_samples holds
                # cfg.eval_samples points — confirm where it is built.
                ess = ess / self.cfg.eval_samples * 100
                # `times` is non-empty here: an entry is appended every
                # iteration before this check.
                mean_time = sum(times) / len(times)
                msg = "Iter {} | Loss {:.3f} | KL {:.3f} | ESS {:.2f}% | {:.2e}s/it"
                print(msg.format(self.iter, l, kl, ess, mean_time))
                writer.writerow({
                    'iter': self.iter, 'loss': l, 'kl': kl, 'ess': ess
                })
                logf.flush()
                self.save('latest')

                times = []