in models/s2s_convlstm_baseline.py [0:0]
def forward(self, frames, config, use_prior, use_mean=False):
    """Encode ``frames``, draw latents, and render predictions.

    Pipeline:
      1. ``self.sto_emb(frames)`` and ``self.det_emb(frames)`` produce two
         encodings of the input. The names suggest stochastic and
         deterministic branches (unverified).
      2. ``self.posterior`` and then ``self.prior`` are evaluated on the
         ``sto_emb`` output. The prior also receives the posterior's output.
      3. Latents are taken from the prior (``use_prior=True``) or the
         posterior (otherwise) and passed to ``self.render``.

    Args:
        frames: input forwarded unchanged to ``sto_emb`` and ``det_emb``.
        config: mapping that must contain ``'n_steps'`` and ``'n_ctx'``.
            Both are looked up, but only ``'n_ctx'`` is used here.
        use_prior: truthy -> use prior latents; falsy -> posterior latents.
        use_mean: forwarded as ``use_mean=`` to both posterior and prior.

    Returns:
        ``((preds, p_dists, q_dists), stored_vars)``. ``stored_vars`` is
        always a fresh empty list in this method.

    Raises:
        KeyError: if ``config`` lacks ``'n_steps'`` or ``'n_ctx'``.
    """
    # Looked up for its KeyError side effect only; the value is unused here.
    _n_steps = config['n_steps']
    context_len = config['n_ctx']

    # Encode the same frames twice: once for the latent models, once for the renderer.
    latent_feats = self.sto_emb(frames)
    render_feats = self.det_emb(frames)

    # The posterior must run first, because its output is an input to the prior.
    posterior_out = self.posterior(latent_feats, latent_feats, use_mean=use_mean)
    prior_out = self.prior(latent_feats, latent_feats, posterior_out, use_mean=use_mean)

    # Each entry is unpacked as a 5-tuple, so any other arity raises ValueError.
    # The prior contributes its 3rd field (named z0) and the posterior its 4th (named zk).
    if use_prior:
        latents = [z0 for (_, _, z0, _, _) in prior_out]
    else:
        latents = [zk for (_, _, _, zk, _) in posterior_out]

    # Slice the last render-encoding along axis 1, from index context_len - 1
    # up to (excluding) the final position.
    # NOTE(review): presumably axis 1 is time -- confirm against det_emb.
    render_input = render_feats[-1][:, context_len - 1:-1]
    predictions = self.render(render_input, latents, render_feats)

    # Nothing is recorded here; the list keeps the (outputs, stored_vars) return shape.
    recorded = []
    return (predictions, prior_out, posterior_out), recorded