in vae.py [0:0]
def forward_uncond(self, xs, t=None, lvs=None):
    """Run this block without posterior input and store its output in ``xs``.

    Args:
        xs: dict of activations, mutated in place. ``self.base`` is used both
            as this block's key and as the height/width of the zero tensor
            built when that key is missing, so keys appear to be spatial
            resolutions.
        t: forwarded unchanged to ``self.sample_uncond`` (presumably a sampling
            temperature -- confirm against ``sample_uncond``).
        lvs: forwarded unchanged to ``self.sample_uncond`` as ``lvs=``
            (presumably externally supplied latents -- confirm).

    Returns:
        The same ``xs`` object, with ``xs[self.base]`` set to this block's output.

    Raises:
        IndexError: if ``self.base`` is absent and ``xs`` is empty, because no
            reference tensor exists to borrow batch size / dtype / device from.
    """
    res = self.base
    try:
        h = xs[res]
    except KeyError:
        # Nothing stored for this key yet: start from zeros whose batch size,
        # dtype and device come from the first entry of ``xs``.
        proto = xs[list(xs)[0]]
        shape = (proto.shape[0], self.widths[res], res, res)
        h = torch.zeros(shape, dtype=proto.dtype, device=proto.device)

    mix = self.mixin
    if mix is not None:
        # Add the activation stored under ``self.mixin``. It is cut down to at
        # most this block's channel count, then upsampled by the integer
        # factor res // mix using F.interpolate's default mode.
        coarse = xs[mix][:, : h.shape[1], ...]
        h = h + F.interpolate(coarse, scale_factor=res // mix)

    # sample_uncond returns (latent, updated activation). The latent is mapped
    # back through z_fn and added residually before the resnet.
    z, h = self.sample_uncond(h, t, lvs=lvs)
    h = h + self.z_fn(z)
    xs[res] = self.resnet(h)
    return xs