in model.py [0:0]
def codec(hps):
    """Build a multi-level encoder/decoder pair over shared flow layers.

    The encoder and decoder are closures over ``hps``. They address the same
    named layers: ``revnet2d`` scopes ``str(i)`` and split scopes
    ``"pool" + str(i)``. This lets the decoder run the encoder's layers in
    reverse.

    Args:
        hps: hyperparameter object. This function reads ``hps.n_levels``
            directly and forwards ``hps`` to ``revnet2d``. The other fields
            it needs are determined by ``revnet2d``; TODO confirm there.

    Returns:
        Tuple ``(encoder, decoder)``.
    """

    def encoder(z, objective):
        """Run the forward pass through all ``hps.n_levels`` levels.

        Each level applies ``revnet2d``. Every level except the last is then
        followed by ``split2d``, whose third return value is collected.

        Args:
            z: input tensor. Its shape is whatever ``revnet2d`` expects
                (not visible here).
            objective: running objective value, threaded through every layer.

        Returns:
            ``(z, objective, eps)``. ``eps`` is a list of ``hps.n_levels - 1``
            split outputs, ordered from level 0 upward.
        """
        eps = []
        for i in range(hps.n_levels):
            z, objective = revnet2d(str(i), z, objective, hps)
            # The last level has no split; its output is the final z.
            if i < hps.n_levels - 1:
                z, objective, _eps = split2d("pool" + str(i), z, objective=objective)
                eps.append(_eps)
        return z, objective, eps

    def decoder(z, eps=None, eps_std=None):
        """Invert ``encoder``: walk the levels from last to first.

        Args:
            z: top-level value to decode.
            eps: per-split values, indexed by level, passed to
                ``split2d_reverse``. Only indices ``0 .. hps.n_levels - 2``
                are read, so the list returned by ``encoder`` can be passed
                straight through. ``None`` (the default) means ``None`` for
                every level.
            eps_std: forwarded unchanged to every ``split2d_reverse`` call.

        Returns:
            The decoded ``z``.
        """
        # Build a fresh list on each call. The previous default,
        # ``[None]*hps.n_levels``, was created once per codec() call and
        # shared by all decoder calls; it also made ``eps=None`` crash.
        if eps is None:
            eps = [None] * hps.n_levels
        for i in reversed(range(hps.n_levels)):
            # Undo the split applied after level i (the last level has none).
            if i < hps.n_levels - 1:
                z = split2d_reverse("pool" + str(i), z, eps=eps[i], eps_std=eps_std)
            # The objective is not tracked on the reverse pass: pass 0, drop it.
            z, _ = revnet2d(str(i), z, 0, hps, reverse=True)
        return z

    return encoder, decoder