def forward()

in models/s2s_big_baseline.py [0:0]


    def forward(self, frames, config, use_prior, use_mean=False):
        """Embed the frames, build posterior/prior outputs, and render predictions.

        Args:
            frames: Input sequence, passed unchanged to both ``self.sto_emb``
                and ``self.det_emb``. Its shape is not established here;
                TODO confirm against the caller / data loader.
            config: Mapping that must contain ``'n_steps'`` and ``'n_ctx'``.
                Only ``'n_ctx'`` is consumed by this method. ``'n_steps'`` is
                still looked up, so a config without it raises ``KeyError``
                before any sub-network runs.
            use_prior: If truthy, the renderer receives the third field
                (``z0``) of each entry returned by ``self.prior``. Otherwise
                it receives the fourth field (``zk``) of each entry returned
                by ``self.posterior``.
            use_mean: Forwarded as a keyword to both ``self.posterior`` and
                ``self.prior``.

        Returns:
            ``((preds, p_dists, q_dists), stored_vars)``. ``preds`` is
            whatever ``self.render`` returns, and ``p_dists`` / ``q_dists``
            are the raw prior / posterior outputs. ``stored_vars`` is always
            a fresh empty list in this implementation.
        """
        # Both keys are read up-front, in this order, before any network runs.
        # n_steps is intentionally not used below; keeping the lookup
        # preserves the KeyError contract for malformed configs.
        _unused_n_steps, n_ctx = config['n_steps'], config['n_ctx']

        # Two separate encoders: the stochastic embedding drives the latent
        # models, and the deterministic embedding drives the renderer.
        stoch = self.sto_emb(frames)
        deter = self.det_emb(frames)

        # The same stochastic embedding fills both leading arguments. The
        # prior additionally receives the posterior's outputs.
        posterior_out = self.posterior(stoch, stoch, use_mean=use_mean)
        prior_out = self.prior(stoch, stoch, posterior_out, use_mean=use_mean)

        # Every entry is unpacked as a 5-tuple. Presumably z0 is a base sample
        # and zk a transformed sample — NOTE(review): confirm in the
        # posterior/prior modules.
        if use_prior:
            latents = [z0 for (_, _, z0, _, _) in prior_out]
        else:
            latents = [zk for (_, _, _, zk, _) in posterior_out]

        # Conditioning comes from the last element of the deterministic
        # embedding, sliced [n_ctx - 1 : -1] along axis 1. This assumes axis 1
        # is time, which would make this the inputs preceding each predicted
        # frame — TODO confirm.
        cond = deter[-1][:, n_ctx - 1:-1]
        preds = self.render(cond, latents, deter)

        # No intermediate variables are recorded by this model.
        return (preds, prior_out, posterior_out), []