def __init__()

in main.py [0:0]


    def __init__(self, cfg):
        """Set up the manifold, densities, flow, optimizer and eval data.

        Args:
            cfg: Configuration object. This constructor reads its
                ``manifold``, ``base``, ``target``, ``seed``, ``flow``,
                ``batch_size``, ``eval_samples``, ``loss`` and ``optim``
                entries.
        """
        self.cfg = cfg

        self.work_dir = os.getcwd()
        print(f'workspace: {self.work_dir}')

        # The manifold is shared by both densities and by the flow.
        manifold = hydra.utils.instantiate(cfg.manifold)
        self.manifold = manifold
        self.base = densities.get(manifold, cfg.base)
        self.target = densities.get(manifold, cfg.target)

        root_key = jax.random.PRNGKey(cfg.seed)

        self.flow = hydra.utils.instantiate(cfg.flow, manifold=manifold)

        # One carried-over key plus one sub-key per random draw below.
        # The fifth entry is not used here; the split stays at 6 so the
        # other derived keys are the same as before.
        (self.key, batch_key, init_key, eval_key,
         _spare_key, target_key) = jax.random.split(root_key, 6)

        # A batch drawn from the base density is passed to the flow's init.
        init_batch = self.base.sample(batch_key, cfg.batch_size)
        init_params = self.flow.init(init_key, init_batch)

        # Fixed base samples (and their log-probs) stored for evaluation.
        n_eval = cfg.eval_samples
        self.base_samples = self.base.sample(eval_key, n_eval)
        self.base_log_probs = self.base.log_prob(self.base_samples)
        # Target samples are only drawn for the 'likelihood' loss; for any
        # other loss the attribute is left unset.
        if cfg.loss == 'likelihood':
            self.eval_target_samples = self.target.sample(target_key, n_eval)

        self.optimizer = hydra.utils.instantiate(cfg.optim).create(init_params)

        self.iter = 0