in vae.py [0:0]
def forward(self, xs, activations, get_latents=False):
    """Run this block's top-down step and record the result in ``xs``.

    Steps, in the order they execute:
      1. ``self.get_inputs(xs, activations)`` yields the working features and
         the activations that are handed to the sampler.
      2. When ``self.mixin`` is not None, ``xs[self.mixin]`` is truncated
         along dim 1 to the working features' size on that dim, upsampled by
         ``self.base // self.mixin`` via ``F.interpolate`` (no ``mode`` given,
         so the library default is used), and added to the features.
         ``F`` is presumably ``torch.nn.functional``, imported elsewhere.
      3. ``self.sample(features, acts)`` returns ``(z, features, kl)``.
         ``self.z_fn(z)`` is added to the new features, and the sum is
         passed through ``self.resnet``.
      4. The output is stored in ``xs[self.base]``.

    Args:
        xs: Dict of feature maps. It is read at ``self.mixin`` and written at
            ``self.base``. The keys are presumably spatial resolutions
            (TODO confirm against the caller).
        activations: Forwarded unchanged to ``self.get_inputs``. Its structure
            is not visible here.
        get_latents: If truthy, the returned stats also carry the sampled
            latent, detached from the autograd graph.

    Returns:
        A tuple ``(xs, stats)``. ``xs`` is the same dict object that was passed
        in, now mutated. ``stats`` is ``{'kl': kl}``, or
        ``{'z': z.detach(), 'kl': kl}`` when ``get_latents`` is truthy.
    """
    feats, sampler_acts = self.get_inputs(xs, activations)

    if self.mixin is not None:
        # Keep only as many dim-1 slices as ``feats`` has (presumably
        # channels) so the upsampled map can be added elementwise.
        width = feats.shape[1]
        coarse = xs[self.mixin][:, :width, ...]
        # NOTE(review): floor division assumes self.base is an exact multiple
        # of self.mixin; otherwise the upsample factor is silently truncated.
        feats = feats + F.interpolate(coarse, scale_factor=self.base // self.mixin)

    latent, feats, kl = self.sample(feats, sampler_acts)
    feats = self.resnet(feats + self.z_fn(latent))

    # Publish this block's output under its own key. Callers get the same
    # (mutated) dict object back.
    xs[self.base] = feats

    if not get_latents:
        return xs, {'kl': kl}
    return xs, {'z': latent.detach(), 'kl': kl}