def forward()

in models/s2s_convlstm.py [0:0]


    def forward(self, frames, config, use_prior, use_mean=False):
        """Embed the frames, build prior/posterior latents and render predictions.

        Args:
            frames: Input whose ``shape`` unpacks into exactly five values
                (presumably batch, time, channels, height, width -- confirm
                against the caller).
            config: Mapping that provides ``'n_steps'`` (how many trailing
                steps are kept along dim 1) and ``'n_ctx'`` (how many leading
                steps are used as context).
            use_prior: When truthy, the renderer is fed the ``z0`` entry of
                every prior level; otherwise the ``zk`` entry of every
                posterior level.
            use_mean: Passed through unchanged to ``self.prior`` and
                ``self.posterior``.

        Returns:
            ``((preds, p_dists, q_dists), stored_vars)``. ``preds`` is the
            renderer output cut to its last ``n_steps`` entries along dim 1.
            ``p_dists`` and ``q_dists`` are lists with one
            ``[means, logvars, z0, zk, ladj]`` list per level, each tensor cut
            the same way (``ladj`` stays ``None`` when it was ``None``).
            ``stored_vars`` is always an empty list here.
        """
        # Unpacking into five names rejects inputs that are not rank 5.
        _bs, _ts, _cs, _hs, _ws = frames.shape

        stored_vars = []
        horizon = config['n_steps']
        ctx_len = config['n_ctx']

        def keep_tail(tensor):
            # Last `horizon` entries along dim 1, as a contiguous tensor.
            return tensor[:, -horizon:].contiguous()

        def trim_levels(levels):
            # Apply keep_tail to every member of every level; a missing
            # ladj (None) is passed through untouched.
            trimmed = []
            for mu, logvar, z_first, z_last, log_det in levels:
                level = [
                    keep_tail(mu),
                    keep_tail(logvar),
                    keep_tail(z_first),
                    keep_tail(z_last),
                ]
                level.append(None if log_det is None else keep_tail(log_det))
                trimmed.append(level)
            return trimmed

        # Two embeddings of the same frames: one for the latent networks,
        # one for the renderer.
        latent_feats = self.latent_emb(frames)
        det_emb = self.det_emb(frames)

        # Prior and posterior are conditioned on the very same context slice.
        shared_ctx = latent_feats[:, :ctx_len].contiguous()

        # The prior input is shifted one step back relative to the posterior
        # input (it starts at n_ctx - 1 and drops the final step).
        prior_in = latent_feats[:, ctx_len - 1:-1].contiguous()
        raw_p = self.prior(prior_in, shared_ctx, use_mean=use_mean)
        posterior_in = latent_feats[:, ctx_len:].contiguous()
        raw_q = self.posterior(posterior_in, shared_ctx, use_mean=use_mean)

        p_dists = trim_levels(raw_p)
        q_dists = trim_levels(raw_q)

        # Latent samples handed to the renderer: index 2 is z0, index 3 is zk.
        if use_prior:
            zs = [level[2] for level in p_dists]
        else:
            zs = [level[3] for level in q_dists]

        # The renderer gets the last det_emb entry (shifted like the prior
        # input) plus the full det_emb.
        render_in = det_emb[-1][:, ctx_len - 1:-1]
        preds = self.render(render_in, zs, det_emb)

        return (preds[:, -horizon:], p_dists, q_dists), stored_vars