in vae.py [0:0]
def sample(self, x, acts):
    """Draw a latent from the posterior and compare it against the prior.

    ``self.enc`` receives ``x`` and ``acts`` concatenated along dim 1. Its
    output is split into two chunks along dim 1, which become the posterior
    parameters. ``self.prior(x)`` produces features that are split along
    dim 1 into three parts:

    * the first ``self.zdim`` channels (prior first parameter);
    * the next ``self.zdim`` channels (prior second parameter);
    * the remaining channels, which are added residually to ``x``.

    Args:
        x: Input activations fed to both ``self.enc`` (after concatenation)
            and ``self.prior``. Presumably laid out as
            (batch, channels, ...) with channels on dim 1 — TODO confirm.
        acts: Activations concatenated with ``x`` on dim 1 before encoding.

    Returns:
        A tuple ``(z, x_updated, kl)``:

        * ``z`` is the result of
          ``draw_gaussian_diag_samples(post_m, post_v)``;
        * ``x_updated`` is ``x`` plus the prior's residual channels;
        * ``kl`` is ``gaussian_analytical_kl(post_m, prior_m, post_v, prior_v)``.
    """
    # NOTE(review): given the helper names, the two parameters per
    # distribution are presumably a mean and a (log-)scale — confirm against
    # draw_gaussian_diag_samples / gaussian_analytical_kl.
    posterior_out = self.enc(torch.cat((x, acts), dim=1))
    post_m, post_v = torch.chunk(posterior_out, 2, dim=1)

    prior_out = self.prior(x)
    zd = self.zdim
    prior_m = prior_out[:, :zd]
    prior_v = prior_out[:, zd:2 * zd]
    residual = prior_out[:, 2 * zd:]
    # Out-of-place add: the caller's tensor is left unmodified.
    x_updated = x + residual

    z = draw_gaussian_diag_samples(post_m, post_v)
    kl = gaussian_analytical_kl(post_m, prior_m, post_v, prior_v)
    return z, x_updated, kl