models/s2s_big_baseline.py [231:389]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        return skips

    def sto_emb(self, x):
        """Feed ``x`` through ``self.sto_emb_net`` and keep every layer's output.

        Each layer consumes the previous layer's output. The returned list has
        one entry per layer, in network order, so its last element is the
        final output of the stack (an empty network yields ``[]``).
        """
        activations = []
        hidden = x
        for module in self.sto_emb_net:
            hidden = module(hidden)
            activations.append(hidden)
        return activations

    def prior(self, x, ctx, q_dists, use_mean=False):
        """Compute the prior distribution for every stochastic branch.

        Branches are visited in descending order of the keys of
        ``self.sto_branches``, and the returned list follows that order.

        Args:
            x: Indexable of per-layer activations. ``x[layer_idx]`` is sliced
                along dim 1 (presumably time) to steps ``n_ctx - 1 .. T - 2``,
                i.e. one step behind the slice the posterior uses.
            ctx: Indexable of per-layer activations. The first ``self.n_ctx``
                steps along dim 1 of ``ctx[layer_idx]`` provide the initial
                state of each ConvLSTM in the branch.
            q_dists: Not used by this method; kept for interface compatibility.
            use_mean: Forwarded to ``flows.gaussian_rsample``.

        Returns:
            list: One ``[mean, logvar, z, z, None]`` entry per branch.
            ``mean`` and the raw variance are the two halves of the branch
            output along dim 2, ``logvar = log(softplus(raw))``, and ``z`` is
            the sample returned by ``flows.gaussian_rsample`` (stored twice).
        """
        dists = []

        for layer_idx in sorted(self.sto_branches.keys(), reverse=True):
            key = 'layer_{}'.format(layer_idx)

            # Find the corresponding activations.
            out = x[layer_idx][:, self.n_ctx - 1: -1].contiguous()
            cur_ctx = ctx[layer_idx][:, :self.n_ctx].contiguous()

            # Process the current branch.
            for layer in self.prior_branches[key]:
                if isinstance(layer, layers.ConvLSTM):
                    # Fold every dim between the first and the last two into a
                    # single channel axis, add a length-1 axis at dim 1, and map
                    # it through the init net to obtain the initial LSTM state.
                    # NOTE(review): cur_ctx is overwritten here, so a second
                    # ConvLSTM in the same branch would feed the previous
                    # init-net output back through the init net — confirm
                    # branches contain at most one ConvLSTM.
                    cur_ctx = cur_ctx.view(cur_ctx.shape[0], -1, cur_ctx.shape[-2], cur_ctx.shape[-1])
                    cur_ctx = self.prior_init_nets[key](cur_ctx.unsqueeze(1)).squeeze(1)
                    out = layer(out, torch.chunk(cur_ctx, 2, 1))
                else:
                    out = layer(out)

            # Distribution stats: first half of dim 2 is the mean, second half
            # is the raw (pre-softplus) variance.
            mean, var = torch.chunk(out, 2, 2)

            # softplus underflows to exactly 0 for very negative inputs, which
            # makes log() return -inf (and its gradient inf/NaN). Clamp to the
            # smallest positive normal value of the dtype; every value that
            # was already representable passes through unchanged.
            variance = F.softplus(var)
            logvar = torch.clamp(variance, min=torch.finfo(variance.dtype).tiny).log()

            # Generate a sample from this distribution.
            z0 = flows.gaussian_rsample(mean, logvar, use_mean=use_mean)

            # The 5th slot is always None here. NOTE(review): presumably a
            # placeholder kept for compatibility with other dist producers.
            dists.append([mean, logvar, z0, z0, None])

        return dists


    def posterior(self, x, ctx, use_mean=False):
        """Compute the posterior distribution for every stochastic branch.

        Branches are visited in descending order of the keys of
        ``self.sto_branches``, and the returned list follows that order.

        Args:
            x: Indexable of per-layer activations. ``x[layer_idx]`` is sliced
                along dim 1 (presumably time) to steps ``n_ctx`` onwards.
            ctx: Indexable of per-layer activations. The first ``self.n_ctx``
                steps along dim 1 of ``ctx[layer_idx]`` provide the initial
                state of each ConvLSTM in the branch.
            use_mean: Forwarded to ``flows.gaussian_rsample``.

        Returns:
            list: One ``[mean, logvar, z, z, None]`` entry per branch.
            ``mean`` and the raw variance are the two halves of the branch
            output along dim 2, ``logvar = log(softplus(raw))``, and ``z`` is
            the sample returned by ``flows.gaussian_rsample`` (stored twice).
        """
        dists = []

        for layer_idx in sorted(self.sto_branches.keys(), reverse=True):
            key = 'layer_{}'.format(layer_idx)

            # Find the corresponding activations.
            out = x[layer_idx][:, self.n_ctx:].contiguous()
            cur_ctx = ctx[layer_idx][:, :self.n_ctx].contiguous()

            # Process the current branch.
            for layer in self.posterior_branches[key]:
                if isinstance(layer, layers.ConvLSTM):
                    # Fold every dim between the first and the last two into a
                    # single channel axis, add a length-1 axis at dim 1, and map
                    # it through the init net to obtain the initial LSTM state.
                    # NOTE(review): cur_ctx is overwritten here, so a second
                    # ConvLSTM in the same branch would feed the previous
                    # init-net output back through the init net — confirm
                    # branches contain at most one ConvLSTM.
                    cur_ctx = cur_ctx.view(cur_ctx.shape[0], -1, cur_ctx.shape[-2], cur_ctx.shape[-1])
                    cur_ctx = self.posterior_init_nets[key](cur_ctx.unsqueeze(1)).squeeze(1)
                    out = layer(out, torch.chunk(cur_ctx, 2, 1))
                else:
                    out = layer(out)

            # Distribution stats: first half of dim 2 is the mean, second half
            # is the raw (pre-softplus) variance.
            mean, var = torch.chunk(out, 2, 2)

            # softplus underflows to exactly 0 for very negative inputs, which
            # makes log() return -inf (and its gradient inf/NaN). Clamp to the
            # smallest positive normal value of the dtype; every value that
            # was already representable passes through unchanged.
            variance = F.softplus(var)
            logvar = torch.clamp(variance, min=torch.finfo(variance.dtype).tiny).log()

            # Generate a sample from this distribution.
            z0 = flows.gaussian_rsample(mean, logvar, use_mean=use_mean)

            # The 5th slot is always None here. NOTE(review): presumably a
            # placeholder kept for compatibility with other dist producers.
            dists.append([mean, logvar, z0, z0, None])

        return dists

    def render(self, x, zs, ctx):
        """Decode ``x`` through ``self.render_net`` and squash with a sigmoid.

        Args:
            x: 5-D tensor whose shape is unpacked as ``(b, t, c, h, w)``; only
                the leading (batch) size is used directly.
            zs: List of latent samples. Before any layer index listed in
                ``self.rend_sto_branches``, ``zs[self.rend_sto_branches[idx]]``
                is concatenated onto the running activation along dim 2.
            ctx: Per-layer activations. For each ConvLSTM layer, the first
                ``self.n_ctx`` steps of ``ctx[self.det_init_connections[idx]]``
                are mapped by the matching ``self.det_init_nets`` entry into
                the LSTM's initial state.

        Returns:
            ``torch.sigmoid`` applied to the last layer's output.
        """
        batch, _, _, _, _ = x.shape

        hidden = x
        for idx, module in enumerate(self.render_net):
            # Inject the latent sample configured for this depth, if any.
            if idx in self.rend_sto_branches:
                hidden = torch.cat([hidden, zs[self.rend_sto_branches[idx]]], 2)

            if not isinstance(module, layers.ConvLSTM):
                hidden = module(hidden)
                continue

            # Build the ConvLSTM initial state from the connected context layer.
            src_idx = self.det_init_connections[idx]
            context = ctx[src_idx][:, :self.n_ctx].contiguous()
            init_net = self.det_init_nets['layer_{}'.format(src_idx)]
            context = context.view(batch, -1, context.shape[-2], context.shape[-1])
            init_state = init_net(context.unsqueeze(1)).squeeze(1)

            hidden = module(hidden, torch.chunk(init_state, 2, 1))

        return torch.sigmoid(hidden)

    def forward(self, frames, config, use_prior, use_mean=False):
        """Encode ``frames``, build prior/posterior latents and render predictions.

        Args:
            frames: Input given to both ``self.sto_emb`` and ``self.det_emb``.
            config: Mapping that must contain ``'n_steps'`` and ``'n_ctx'``.
                ``'n_steps'`` is read but not otherwise used.
            use_prior: When truthy the renderer gets the prior samples (slot 2
                of each prior entry); otherwise the posterior samples (slot 3
                of each posterior entry).
            use_mean: Forwarded to ``self.posterior`` and ``self.prior``.

        Returns:
            ``((preds, p_dists, q_dists), stored_vars)``; ``stored_vars`` is
            always an empty list.
        """
        stored_vars = []
        n_steps = config['n_steps']  # read but currently unused
        n_ctx = config['n_ctx']
        # NOTE(review): prior/posterior/render slice with self.n_ctx while the
        # renderer input below is sliced with config['n_ctx']; these presumably
        # have to agree — confirm.

        sto_feats = self.sto_emb(frames)
        det_feats = self.det_emb(frames)

        # Posterior first, then prior (the prior receives the posterior list).
        q_dists = self.posterior(sto_feats, sto_feats, use_mean=use_mean)
        p_dists = self.prior(sto_feats, sto_feats, q_dists, use_mean=use_mean)

        if use_prior:
            zs = [z0 for (_, _, z0, _, _) in p_dists]
        else:
            zs = [zk for (_, _, _, zk, _) in q_dists]

        preds = self.render(det_feats[-1][:, n_ctx - 1:-1], zs, det_feats)
        return (preds, p_dists, q_dists), stored_vars


        
if __name__ == '__main__':
    # No command-line entry point is implemented for this module.
    ...
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



models/s2s_convlstm_baseline.py [162:320]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        return skips

    def sto_emb(self, x):
        """Feed ``x`` through ``self.sto_emb_net`` and keep every layer's output.

        Each layer consumes the previous layer's output. The returned list has
        one entry per layer, in network order, so its last element is the
        final output of the stack (an empty network yields ``[]``).
        """
        activations = []
        hidden = x
        for module in self.sto_emb_net:
            hidden = module(hidden)
            activations.append(hidden)
        return activations

    def prior(self, x, ctx, q_dists, use_mean=False):
        """Compute the prior distribution for every stochastic branch.

        Branches are visited in descending order of the keys of
        ``self.sto_branches``, and the returned list follows that order.

        Args:
            x: Indexable of per-layer activations. ``x[layer_idx]`` is sliced
                along dim 1 (presumably time) to steps ``n_ctx - 1 .. T - 2``,
                i.e. one step behind the slice the posterior uses.
            ctx: Indexable of per-layer activations. The first ``self.n_ctx``
                steps along dim 1 of ``ctx[layer_idx]`` provide the initial
                state of each ConvLSTM in the branch.
            q_dists: Not used by this method; kept for interface compatibility.
            use_mean: Forwarded to ``flows.gaussian_rsample``.

        Returns:
            list: One ``[mean, logvar, z, z, None]`` entry per branch.
            ``mean`` and the raw variance are the two halves of the branch
            output along dim 2, ``logvar = log(softplus(raw))``, and ``z`` is
            the sample returned by ``flows.gaussian_rsample`` (stored twice).
        """
        dists = []

        for layer_idx in sorted(self.sto_branches.keys(), reverse=True):
            key = 'layer_{}'.format(layer_idx)

            # Find the corresponding activations.
            out = x[layer_idx][:, self.n_ctx - 1: -1].contiguous()
            cur_ctx = ctx[layer_idx][:, :self.n_ctx].contiguous()

            # Process the current branch.
            for layer in self.prior_branches[key]:
                if isinstance(layer, layers.ConvLSTM):
                    # Fold every dim between the first and the last two into a
                    # single channel axis, add a length-1 axis at dim 1, and map
                    # it through the init net to obtain the initial LSTM state.
                    # NOTE(review): cur_ctx is overwritten here, so a second
                    # ConvLSTM in the same branch would feed the previous
                    # init-net output back through the init net — confirm
                    # branches contain at most one ConvLSTM.
                    cur_ctx = cur_ctx.view(cur_ctx.shape[0], -1, cur_ctx.shape[-2], cur_ctx.shape[-1])
                    cur_ctx = self.prior_init_nets[key](cur_ctx.unsqueeze(1)).squeeze(1)
                    out = layer(out, torch.chunk(cur_ctx, 2, 1))
                else:
                    out = layer(out)

            # Distribution stats: first half of dim 2 is the mean, second half
            # is the raw (pre-softplus) variance.
            mean, var = torch.chunk(out, 2, 2)

            # softplus underflows to exactly 0 for very negative inputs, which
            # makes log() return -inf (and its gradient inf/NaN). Clamp to the
            # smallest positive normal value of the dtype; every value that
            # was already representable passes through unchanged.
            variance = F.softplus(var)
            logvar = torch.clamp(variance, min=torch.finfo(variance.dtype).tiny).log()

            # Generate a sample from this distribution.
            z0 = flows.gaussian_rsample(mean, logvar, use_mean=use_mean)

            # The 5th slot is always None here. NOTE(review): presumably a
            # placeholder kept for compatibility with other dist producers.
            dists.append([mean, logvar, z0, z0, None])

        return dists


    def posterior(self, x, ctx, use_mean=False):
        """Compute the posterior distribution for every stochastic branch.

        Branches are visited in descending order of the keys of
        ``self.sto_branches``, and the returned list follows that order.

        Args:
            x: Indexable of per-layer activations. ``x[layer_idx]`` is sliced
                along dim 1 (presumably time) to steps ``n_ctx`` onwards.
            ctx: Indexable of per-layer activations. The first ``self.n_ctx``
                steps along dim 1 of ``ctx[layer_idx]`` provide the initial
                state of each ConvLSTM in the branch.
            use_mean: Forwarded to ``flows.gaussian_rsample``.

        Returns:
            list: One ``[mean, logvar, z, z, None]`` entry per branch.
            ``mean`` and the raw variance are the two halves of the branch
            output along dim 2, ``logvar = log(softplus(raw))``, and ``z`` is
            the sample returned by ``flows.gaussian_rsample`` (stored twice).
        """
        dists = []

        for layer_idx in sorted(self.sto_branches.keys(), reverse=True):
            key = 'layer_{}'.format(layer_idx)

            # Find the corresponding activations.
            out = x[layer_idx][:, self.n_ctx:].contiguous()
            cur_ctx = ctx[layer_idx][:, :self.n_ctx].contiguous()

            # Process the current branch.
            for layer in self.posterior_branches[key]:
                if isinstance(layer, layers.ConvLSTM):
                    # Fold every dim between the first and the last two into a
                    # single channel axis, add a length-1 axis at dim 1, and map
                    # it through the init net to obtain the initial LSTM state.
                    # NOTE(review): cur_ctx is overwritten here, so a second
                    # ConvLSTM in the same branch would feed the previous
                    # init-net output back through the init net — confirm
                    # branches contain at most one ConvLSTM.
                    cur_ctx = cur_ctx.view(cur_ctx.shape[0], -1, cur_ctx.shape[-2], cur_ctx.shape[-1])
                    cur_ctx = self.posterior_init_nets[key](cur_ctx.unsqueeze(1)).squeeze(1)
                    out = layer(out, torch.chunk(cur_ctx, 2, 1))
                else:
                    out = layer(out)

            # Distribution stats: first half of dim 2 is the mean, second half
            # is the raw (pre-softplus) variance.
            mean, var = torch.chunk(out, 2, 2)

            # softplus underflows to exactly 0 for very negative inputs, which
            # makes log() return -inf (and its gradient inf/NaN). Clamp to the
            # smallest positive normal value of the dtype; every value that
            # was already representable passes through unchanged.
            variance = F.softplus(var)
            logvar = torch.clamp(variance, min=torch.finfo(variance.dtype).tiny).log()

            # Generate a sample from this distribution.
            z0 = flows.gaussian_rsample(mean, logvar, use_mean=use_mean)

            # The 5th slot is always None here. NOTE(review): presumably a
            # placeholder kept for compatibility with other dist producers.
            dists.append([mean, logvar, z0, z0, None])

        return dists

    def render(self, x, zs, ctx):
        """Decode ``x`` through ``self.render_net`` and squash with a sigmoid.

        Args:
            x: 5-D tensor whose shape is unpacked as ``(b, t, c, h, w)``; only
                the leading (batch) size is used directly.
            zs: List of latent samples. Before any layer index listed in
                ``self.rend_sto_branches``, ``zs[self.rend_sto_branches[idx]]``
                is concatenated onto the running activation along dim 2.
            ctx: Per-layer activations. For each ConvLSTM layer, the first
                ``self.n_ctx`` steps of ``ctx[self.det_init_connections[idx]]``
                are mapped by the matching ``self.det_init_nets`` entry into
                the LSTM's initial state.

        Returns:
            ``torch.sigmoid`` applied to the last layer's output.
        """
        batch, _, _, _, _ = x.shape

        hidden = x
        for idx, module in enumerate(self.render_net):
            # Inject the latent sample configured for this depth, if any.
            if idx in self.rend_sto_branches:
                hidden = torch.cat([hidden, zs[self.rend_sto_branches[idx]]], 2)

            if not isinstance(module, layers.ConvLSTM):
                hidden = module(hidden)
                continue

            # Build the ConvLSTM initial state from the connected context layer.
            src_idx = self.det_init_connections[idx]
            context = ctx[src_idx][:, :self.n_ctx].contiguous()
            init_net = self.det_init_nets['layer_{}'.format(src_idx)]
            context = context.view(batch, -1, context.shape[-2], context.shape[-1])
            init_state = init_net(context.unsqueeze(1)).squeeze(1)

            hidden = module(hidden, torch.chunk(init_state, 2, 1))

        return torch.sigmoid(hidden)

    def forward(self, frames, config, use_prior, use_mean=False):
        """Encode ``frames``, build prior/posterior latents and render predictions.

        Args:
            frames: Input given to both ``self.sto_emb`` and ``self.det_emb``.
            config: Mapping that must contain ``'n_steps'`` and ``'n_ctx'``.
                ``'n_steps'`` is read but not otherwise used.
            use_prior: When truthy the renderer gets the prior samples (slot 2
                of each prior entry); otherwise the posterior samples (slot 3
                of each posterior entry).
            use_mean: Forwarded to ``self.posterior`` and ``self.prior``.

        Returns:
            ``((preds, p_dists, q_dists), stored_vars)``; ``stored_vars`` is
            always an empty list.
        """
        stored_vars = []
        n_steps = config['n_steps']  # read but currently unused
        n_ctx = config['n_ctx']
        # NOTE(review): prior/posterior/render slice with self.n_ctx while the
        # renderer input below is sliced with config['n_ctx']; these presumably
        # have to agree — confirm.

        sto_feats = self.sto_emb(frames)
        det_feats = self.det_emb(frames)

        # Posterior first, then prior (the prior receives the posterior list).
        q_dists = self.posterior(sto_feats, sto_feats, use_mean=use_mean)
        p_dists = self.prior(sto_feats, sto_feats, q_dists, use_mean=use_mean)

        if use_prior:
            zs = [z0 for (_, _, z0, _, _) in p_dists]
        else:
            zs = [zk for (_, _, _, zk, _) in q_dists]

        preds = self.render(det_feats[-1][:, n_ctx - 1:-1], zs, det_feats)
        return (preds, p_dists, q_dists), stored_vars


        
if __name__ == '__main__':
    # No command-line entry point is implemented for this module.
    ...
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



