in models/s2s_convlstm.py [0:0]
def forward(self, frames, config, use_prior, use_mean=False):
    """Predict the last ``n_steps`` frames from sampled latents.

    The frames are embedded twice (``self.latent_emb`` for the latent
    networks, ``self.det_emb`` for the renderer), the prior and posterior
    are evaluated on shifted slices of the latent embedding, and
    ``self.render`` turns the chosen latent samples into predictions.

    Args:
        frames: Array-like with exactly five dimensions; axis 1 is the one
            sliced with ``n_ctx`` / ``n_steps``.
            NOTE(review): presumably (batch, time, channels, height,
            width) -- confirm against the caller.
        config: Mapping providing ``'n_steps'`` (number of trailing steps
            to keep) and ``'n_ctx'`` (number of leading context steps).
        use_prior: If truthy, the renderer receives the ``z0`` entries of
            the prior; otherwise the ``zk`` entries of the posterior.
        use_mean: Forwarded unchanged to ``self.prior`` and
            ``self.posterior``.

    Returns:
        ``((preds, p_dists, q_dists), stored_vars)`` where ``p_dists`` and
        ``q_dists`` are lists of ``[means, logvars, z0, zk, ladj]`` trimmed
        to the last ``n_steps`` steps (``ladj`` may be ``None``) and
        ``stored_vars`` is always an empty list.

    Raises:
        ValueError: If ``frames`` does not have exactly five dimensions.
    """
    # Explicit form of the rank check that the original five-way shape
    # unpacking performed implicitly (same exception type).
    if len(frames.shape) != 5:
        raise ValueError(
            'frames must have 5 dimensions, got {}'.format(len(frames.shape))
        )

    n_steps = config['n_steps']
    n_ctx = config['n_ctx']

    def _keep_last_steps(dists):
        """Trim every tensor of every distribution entry to the last n_steps."""
        # NOTE(review): with n_steps == 0 the slice ``-0:`` keeps the whole
        # sequence rather than nothing -- confirm callers never pass 0.
        trimmed = []
        for means, logvars, z0, zk, ladj in dists:
            entry = [
                t[:, -n_steps:].contiguous()
                for t in (means, logvars, z0, zk)
            ]
            # ladj is optional and is passed through untouched when absent.
            entry.append(
                None if ladj is None else ladj[:, -n_steps:].contiguous()
            )
            trimmed.append(entry)
        return trimmed

    # Encode frames for latents and renderer
    pq_emb = self.latent_emb(frames)
    det_emb = self.det_emb(frames)

    # Prior and posterior share the same context embedding.
    p_ctx = pq_emb[:, :n_ctx].contiguous()
    q_ctx = p_ctx

    # The prior sees steps n_ctx-1 .. T-2, the posterior sees n_ctx .. T-1.
    p_dists = self.prior(
        pq_emb[:, n_ctx - 1:-1].contiguous(), p_ctx, use_mean=use_mean
    )
    q_dists = self.posterior(
        pq_emb[:, n_ctx:].contiguous(), q_ctx, use_mean=use_mean
    )
    p_dists = _keep_last_steps(p_dists)
    q_dists = _keep_last_steps(q_dists)

    # Latent samples.
    # NOTE(review): the prior branch takes z0 while the posterior branch
    # takes zk -- confirm this asymmetry is intended.
    if use_prior:
        zs = [z0 for _, _, z0, _, _ in p_dists]
    else:
        zs = [zk for _, _, _, zk, _ in q_dists]

    # Render frames
    preds = self.render(det_emb[-1][:, n_ctx - 1:-1], zs, det_emb)
    preds = preds[:, -n_steps:]

    # Nothing is recorded here; the empty list keeps the return shape stable.
    stored_vars = []
    return (preds, p_dists, q_dists), stored_vars