in vae.py [0:0]
def __init__(self, H, res, mixin, n_blocks):
    """Build the sub-networks for one block operating at resolution ``res``.

    Args:
        H: hyperparameter object; this constructor reads ``H.width``,
            ``H.custom_width_str``, ``H.bottleneck_multiple`` and ``H.zdim``.
        res: key used to look up this block's channel width in the table
            returned by ``get_width_settings``; ``use_3x3=True`` is passed to
            every ``Block`` only when ``res > 2``.
        mixin: stored unchanged on ``self.mixin``; how it is used is decided
            by code outside this constructor.
        n_blocks: count used to scale the freshly initialised weights of
            ``z_proj`` and ``resnet.c4`` by ``sqrt(1 / n_blocks)``.

    NOTE: submodules are created in a fixed order (enc, prior, z_proj,
    resnet). Keep that order: it fixes parameter registration order and,
    presumably, how the RNG stream is consumed by weight initialisation.
    """
    super().__init__()
    self.base = res
    self.mixin = mixin
    self.H = H
    self.widths = get_width_settings(H.width, H.custom_width_str)

    channels = self.widths[res]
    three_by_three = res > 2
    hidden = int(channels * H.bottleneck_multiple)
    latent = H.zdim
    self.zdim = latent

    # Encoder-side head: 2*channels in, 2*zdim out (presumably concatenated
    # activations in, posterior statistics out -- NOTE(review): confirm).
    self.enc = Block(
        channels * 2, hidden, latent * 2,
        residual=False, use_3x3=three_by_three,
    )
    # Prior head: 2*zdim + channels out; zero_last=True is forwarded to Block
    # (presumably zero-initialises its final layer).
    self.prior = Block(
        channels, hidden, latent * 2 + channels,
        residual=False, use_3x3=three_by_three, zero_last=True,
    )
    # Projection from zdim back to `channels` (presumably a 1x1 conv).
    self.z_proj = get_1x1(latent, channels)
    # Depth-dependent down-scaling of the initial weights; presumably keeps
    # the summed residual contributions bounded as n_blocks grows.
    depth_scale = np.sqrt(1 / n_blocks)
    self.z_proj.weight.data *= depth_scale
    self.resnet = Block(
        channels, hidden, channels,
        residual=True, use_3x3=three_by_three,
    )
    self.resnet.c4.weight.data *= depth_scale

    def _project_z(z):
        # Thin closure so callers can invoke the projection through z_fn.
        return self.z_proj(z)

    self.z_fn = _project_z