in main.py [0:0]
def __init__(self, cfg):
    """Build the manifold, densities, flow model and optimizer from a config.

    Args:
        cfg: Config object (passed to ``hydra.utils.instantiate``). Fields
            read here: ``manifold``, ``base``, ``target``, ``seed``,
            ``flow``, ``batch_size``, ``eval_samples``, ``loss`` and
            ``optim``.

    Side effects:
        Prints the current working directory.
    """
    self.cfg = cfg
    # Directory the run executes in; presumably Hydra's per-run output
    # directory when launched through Hydra -- confirm against the config.
    self.work_dir = os.getcwd()
    print(f'workspace: {self.work_dir}')

    self.manifold = hydra.utils.instantiate(self.cfg.manifold)
    self.base = densities.get(self.manifold, self.cfg.base)
    self.target = densities.get(self.manifold, self.cfg.target)

    self.key = jax.random.PRNGKey(self.cfg.seed)
    self.flow = hydra.utils.instantiate(
        self.cfg.flow, manifold=self.manifold)

    # Keep the split count at 6: changing it changes every derived key and
    # hence the random stream of existing experiments. The fifth subkey is
    # currently unused.
    (self.key, k_batch, k_init, k_eval, _,
     k_target) = jax.random.split(self.key, 6)

    # A base-density batch used only to initialise the flow parameters.
    batch = self.base.sample(k_batch, self.cfg.batch_size)
    init_params = self.flow.init(k_init, batch)

    # Fixed base samples and their log-densities, computed once up front.
    self.base_samples = self.base.sample(k_eval, self.cfg.eval_samples)
    self.base_log_probs = self.base.log_prob(self.base_samples)

    # Target samples are only drawn for the likelihood loss; otherwise the
    # attribute is set to None so it always exists on the instance.
    if self.cfg.loss == 'likelihood':
        self.eval_target_samples = self.target.sample(
            k_target, self.cfg.eval_samples)
    else:
        self.eval_target_samples = None

    # NOTE(review): assumes a flax.optim-style OptimizerDef whose
    # ``create`` wraps the initial params -- confirm against cfg.optim.
    optimizer_def = hydra.utils.instantiate(self.cfg.optim)
    self.optimizer = optimizer_def.create(init_params)
    # Iteration counter, presumably advanced by the training loop.
    self.iter = 0