def train_likelihood()

in main.py [0:0]


    def train_likelihood(self):
        """Fit the flow by maximum likelihood on samples drawn from ``self.target``.

        Each iteration draws ``self.cfg.batch_size`` target samples and takes one
        optimizer step on their mean negative log-likelihood. Every
        ``self.cfg.log_frequency`` iterations, the loss on
        ``self.eval_target_samples`` is printed and written to the CSV log, model
        plots are saved, and a ``'latest'`` checkpoint is written.

        Training continues while ``self.iter < self.cfg.iterations``.
        ``self.iter`` is not reset here, so a run presumably resumes from a
        restored counter (TODO confirm against checkpoint loading).
        """
        @jax.jit
        def logprob(params, target_samples, t=1):
            # Change of variables: base log-density of the flow output plus
            # ldjs (presumably the log|det Jacobian| of the flow map).
            # `t` is forwarded to the flow; the evolution plots below sweep it.
            zs, ldjs = self.flow.apply(params, target_samples, t=t)
            return ldjs + self.base.log_prob(zs)

        @jax.jit
        def loss(params, target_samples):
            """Mean negative log-likelihood of ``target_samples``."""
            return -logprob(params, target_samples).mean()

        @jax.jit
        def update(optimizer, target_samples):
            """Take one gradient step; return (loss before the step, new optimizer)."""
            l, grads = jax.value_and_grad(loss)(
                optimizer.target, target_samples)
            optimizer = optimizer.apply_gradient(grads)
            return l, optimizer

        # Argument 1 (the sample count) must be static so it can set array shapes.
        target_sample_jit = jax.jit(self.target.sample, static_argnums=(1,))
        base_sample_jit = jax.jit(self.base.sample, static_argnums=(1,))

        # Flow times swept by the "evolution" density plots: t = 0.1, 0.2, ..., 1.0.
        # The initial and periodic plots share this grid, so a frame index `_i`
        # in a file name always refers to the same t.
        evol_ts = jnp.linspace(0.1, 1, 10)

        def plot_model():
            """Save sample and density plots of the current model, tagged with self.iter."""
            # The eval target samples pushed through the flow (the same zs that
            # logprob scores under the base distribution).
            model_samples, _ = self.flow.apply(
                self.optimizer.target, self.eval_target_samples)
            self.manifold.plot_samples(
                model_samples, save=f'samples_{self.iter:06d}.png')
            self.manifold.plot_density(
                functools.partial(logprob, self.optimizer.target),
                save=f'density_{self.iter:06d}.png')
            if not self.cfg.disable_evol_plots:
                for i, t in enumerate(evol_ts):
                    self.manifold.plot_density(
                        functools.partial(logprob, self.optimizer.target, t=t),
                        save=f'density_{self.iter:06d}_{i}.png')

        logf, writer = self._init_logging()

        times = []

        if self.iter == 0 and not self.cfg.disable_init_plots:
            try:
                self.manifold.plot_density(
                    self.target.log_prob, save='target_density.png')
            except Exception as err:
                # Best-effort plot: training does not depend on it. Catch
                # Exception rather than using a bare `except` so Ctrl-C and
                # SystemExit still propagate, and report the failure instead
                # of hiding it.
                print(f'Skipping target density plot: {err!r}')
            self.manifold.plot_samples(
                self.eval_target_samples, save='target_samples.png')
            self.manifold.plot_samples(
                base_sample_jit(self.key, self.cfg.eval_samples),
                save='base_samples.png')
            self.manifold.plot_density(
                self.base.log_prob, save='base_density.png')
            plot_model()

        while self.iter < self.cfg.iterations:
            start = time.time()
            self.key, subkey = jax.random.split(self.key)
            target_samples = target_sample_jit(subkey, self.cfg.batch_size)
            l, self.optimizer = update(self.optimizer, target_samples)
            # JAX dispatches work asynchronously. Wait for this step to finish
            # so the recorded time covers the computation, not just the dispatch.
            l.block_until_ready()
            times.append(time.time() - start)
            self.iter += 1

            if self.iter % self.cfg.log_frequency == 0:
                # Report the loss on the fixed eval set rather than the last batch.
                l = loss(self.optimizer.target, self.eval_target_samples)
                plot_model()

                msg = "Iter {} | Loss {:.3f} | {:.2e}s/it"
                print(msg.format(
                    self.iter, l, sum(times) / len(times)))
                writer.writerow({
                    'iter': self.iter, 'loss': l,
                })
                logf.flush()
                self.save('latest')
                times = []