in models/vrnn_hier.py [0:0]
def forward(self, frames, config, use_prior, use_mean=False, scale_var=1.):
    """Embed ``frames``, build prior/posterior latents, and render predictions.

    Args:
        frames: Input passed unchanged to ``self.get_emb``.
        config: Mapping from which only ``'n_ctx'`` is read. It sets where the
            slice of the last embedding passed to ``self.render`` starts.
        use_prior: If truthy, render from the third element of every prior
            entry. Otherwise render from the fourth element of every
            posterior entry.
        use_mean: Forwarded as a keyword to ``self.posterior`` and
            ``self.prior``.
        scale_var: Forwarded as a keyword to ``self.posterior`` and
            ``self.prior``.

    Returns:
        ``((preds, p_dists, q_dists), stored_vars)``. ``preds`` is whatever
        ``self.render`` returns. ``p_dists`` and ``q_dists`` are returned as
        produced. ``stored_vars`` is always an empty list here; it is
        presumably kept so callers expecting that slot keep working.

    Raises:
        KeyError: If ``config`` has no ``'n_ctx'`` entry.
        ValueError: If any prior/posterior entry is not a 5-element sequence.
    """
    stored_vars = []
    # NOTE: only 'n_ctx' is needed here; 'n_steps' is not used by this method.
    n_ctx = config['n_ctx']

    # Encode frames once; the embedding feeds both latent inference and rendering.
    emb = self.get_emb(frames)

    # The prior is conditioned on the posterior output as well as the embedding.
    q_dists = self.posterior(emb, use_mean=use_mean, scale_var=scale_var)
    p_dists = self.prior(emb, q_dists, use_mean=use_mean, scale_var=scale_var)

    # Each entry is unpacked as a 5-tuple. The names z0/zk suggest a base sample
    # vs. a transformed sample (e.g. after flow steps) -- TODO confirm against
    # prior/posterior implementations.
    if use_prior:
        zs = [z0 for (_, _, z0, _, _) in p_dists]
    else:
        zs = [zk for (_, _, _, zk, _) in q_dists]

    # Slice of the last embedding from index n_ctx - 1 up to (excluding) the
    # final one along axis 1 -- presumably the time axis, i.e. the context
    # frames that precede each predicted step; verify against render().
    preds = self.render(emb[-1][:, n_ctx - 1:-1], zs, emb)
    return (preds, p_dists, q_dists), stored_vars