in vae.py [0:0]
def forward(self, activations, get_latents=False):
    """Run every decoder block in order; return the final output and per-block stats.

    The running state is a dict built from ``self.bias_xs``. Each entry is keyed
    by its ``shape[2]`` (presumably the spatial resolution of NCHW tensors; TODO
    confirm). Each block in ``self.dec_blocks`` is called as
    ``block(state, activations, get_latents=get_latents)`` and must return a pair
    ``(state, block_stats)``. The returned state is used for the next block.

    Args:
        activations: Passed unchanged to every decoder block. These are presumably
            encoder activations; verify against the caller.
        get_latents: Forwarded to every block as a keyword argument.

    Returns:
        A tuple ``(out, stats)``:

        * ``out`` is ``self.final_fn`` applied to the state entry keyed by
          ``self.H.image_size``. It is also written back into the state.
        * ``stats`` is a list with one item per decoder block, in call order.

    Raises:
        KeyError: If the state has no entry for ``self.H.image_size`` after the
            blocks have run.
    """
    # Seed the per-key state from the learned bias entries (later duplicates win).
    state = dict((bias.shape[2], bias) for bias in self.bias_xs)

    per_block_stats = []
    for dec_block in self.dec_blocks:
        # Blocks may mutate the dict in place and/or return a new one; always
        # continue with whatever they hand back.
        state, block_stats = dec_block(state, activations, get_latents=get_latents)
        per_block_stats.append(block_stats)

    top_key = self.H.image_size
    # Store the post-processed output back into the state before returning it.
    state[top_key] = self.final_fn(state[top_key])
    return state[top_key], per_block_stats