in vae.py [0:0]
def forward(self, x, x_target):
    """Encode ``x``, decode it, and return the per-dimension loss terms.

    The encoder's activations for ``x`` are fed to the decoder, which yields
    ``px_z`` (passed to the output net) and ``stats``, a sequence of per-layer
    dicts each carrying a ``'kl'`` tensor. The loss is split into:

    * distortion -- ``self.decoder.out_net.nll(px_z, x_target)``; assumed to
      be a per-example tensor of shape ``(N,)`` (TODO confirm against
      ``out_net.nll``).
    * rate -- every layer's ``'kl'`` summed over all non-batch axes and over
      layers, then divided by the number of elements in one example of
      ``x`` (``prod(x.shape[1:])``).

    Args:
        x: Input batch given to the encoder; axis 0 is treated as the batch
            axis.
        x_target: Target given to ``out_net.nll`` for the distortion term.

    Returns:
        dict with scalar tensors:
        ``elbo`` is the batch mean of distortion + rate.
        ``distortion`` and ``rate`` are the batch means of each term.
        Because distortion comes from an ``nll`` and rate from a KL,
        ``elbo`` is presumably the *negative* ELBO, i.e. a loss to minimise.
        Confirm this with callers.
    """
    # NOTE(review): if encoder/decoder are nn.Modules, calling ``.forward``
    # directly skips registered hooks; kept as-is to preserve behaviour.
    activations = self.encoder.forward(x)
    px_z, stats = self.decoder.forward(activations)
    distortion_per_pixel = self.decoder.out_net.nll(px_z, x_target)

    # Accumulate the KL of every latent layer into a per-example tensor.
    rate_per_pixel = torch.zeros_like(distortion_per_pixel)
    for statdict in stats:
        kl = statdict['kl']
        # Reduce over every non-batch axis, whatever the latent's rank.
        # 1-D (already per-example) tensors are left alone.
        # PyTorch has historically treated an empty ``dim`` tuple as
        # "reduce over all axes", which would also collapse the batch axis.
        if kl.dim() > 1:
            kl = kl.sum(dim=tuple(range(1, kl.dim())))
        rate_per_pixel += kl

    # NOTE(review): normalised by the element count of ``x`` rather than
    # ``x_target``; identical when both share a shape -- confirm intent.
    ndims = np.prod(x.shape[1:])
    rate_per_pixel /= ndims

    elbo = (distortion_per_pixel + rate_per_pixel).mean()
    return dict(elbo=elbo, distortion=distortion_per_pixel.mean(), rate=rate_per_pixel.mean())