in vae.py [0:0]
def build(self):
    """Create the sub-modules and learnable parameters configured by ``self.H``.

    Reads ``width``, ``custom_width_str``, ``dec_blocks`` and
    ``no_bias_above`` from ``self.H`` and sets the following attributes,
    in this order:

    * ``widths``      -- result of ``get_width_settings``; indexed by
      resolution below to obtain a channel count.
    * ``resolutions`` -- the distinct resolutions found in the parsed
      ``dec_blocks`` spec, sorted ascending.
    * ``dec_blocks``  -- ``nn.ModuleList`` with one ``DecBlock`` per parsed
      ``(res, mixin)`` entry, in spec order.
    * ``bias_xs``     -- ``nn.ParameterList`` of zero tensors shaped
      ``(1, widths[res], res, res)``, one per resolution that is
      ``<= no_bias_above``.
    * ``out_net``     -- a ``DmolNet`` built from ``self.H``.
    * ``gain`` / ``bias`` -- parameters shaped ``(1, width, 1, 1)``
      initialised to ones / zeros.
    * ``final_fn``    -- callable computing ``x * gain + bias``.
    """
    hparams = self.H
    self.widths = get_width_settings(hparams.width, hparams.custom_width_str)

    layer_specs = parse_layer_string(hparams.dec_blocks)
    total_blocks = len(layer_specs)
    # Each block receives the total number of blocks through ``n_blocks``.
    stack = [
        DecBlock(hparams, res, mixin, n_blocks=total_blocks)
        for res, mixin in layer_specs
    ]
    self.resolutions = sorted({res for res, _ in layer_specs})
    self.dec_blocks = nn.ModuleList(stack)

    # NOTE: ``bias_xs`` is positional (ascending resolution), not keyed by
    # resolution; resolutions above ``no_bias_above`` get no entry at all.
    bias_params = []
    for res in self.resolutions:
        if res <= hparams.no_bias_above:
            bias_shape = (1, self.widths[res], res, res)
            bias_params.append(nn.Parameter(torch.zeros(*bias_shape)))
    self.bias_xs = nn.ParameterList(bias_params)

    self.out_net = DmolNet(hparams)

    affine_shape = (1, hparams.width, 1, 1)
    self.gain = nn.Parameter(torch.ones(*affine_shape))
    self.bias = nn.Parameter(torch.zeros(*affine_shape))
    # The closure looks up ``self.gain`` / ``self.bias`` at call time, so it
    # always uses their current values.
    # Presumably ``x`` has ``hparams.width`` channels on dim 1 -- confirm.
    self.final_fn = lambda x: x * self.gain + self.bias