models/s2s_big_hier.py [426:477]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
                        z2 = z2.view(b, t, c, z2.shape[-2], z2.shape[-1])
                        out = torch.cat([out, z1, z2], 2)
                        out = layer(out)

                else:
                    out = layer(out)

            # Compute distribution stats
            mean, var = torch.chunk(out, 2, 2)

            # Softplus var
            logvar = F.softplus(var).log()

            # Generate sample from this distribution
            z0 = flows.gaussian_rsample(mean, logvar, use_mean=use_mean)

            dists.append([mean, logvar, z0, z0, None])

        return dists

    def render(self, x, zs, ctx):
        """Run ``self.render_net`` over ``x`` and squash the result with a sigmoid.

        For every layer, in order:

        * if the layer index is a key of ``self.rend_sto_branches``, the
          tensor ``zs[self.rend_sto_branches[idx]]`` is concatenated onto the
          running activation along dim 2;
        * a ``layers.ConvLSTM`` layer is called with the activation plus a
          pair of tensors derived from ``ctx`` (see below);
        * any other layer is called with the activation alone.

        ConvLSTM extra input: with ``src = self.det_init_connections[idx]``,
        the first ``self.n_ctx`` entries along dim 1 of ``ctx[src]`` are
        reshaped to ``(batch, -1, H, W)``, given a singleton dim 1, passed
        through ``self.det_init_nets['layer_<src>']``, squeezed back and split
        into two chunks along dim 1. (Presumably these are the layer's initial
        hidden/cell states -- confirm against ``layers.ConvLSTM``.)

        Args:
            x: 5-D tensor; the original unpacking names suggest
                ``(batch, time, channels, height, width)`` -- TODO confirm.
                Only its leading size is read, as the batch size for the
                reshape above; any other rank raises ``ValueError``.
            zs: indexable collection of tensors, indexed by the values of
                ``self.rend_sto_branches``.
            ctx: indexable collection of tensors, indexed by the values of
                ``self.det_init_connections``.

        Returns:
            ``torch.sigmoid`` applied to the last layer's output (applied to
            ``x`` itself when ``self.render_net`` is empty).
        """
        # Unpacking into five names doubles as a rank-5 check on ``x``.
        batch_size, _, _, _, _ = x.shape

        hidden = x
        for idx, module in enumerate(self.render_net):
            # Inject the stochastic sample routed to this layer, if any.
            if idx in self.rend_sto_branches:
                sample = zs[self.rend_sto_branches[idx]]
                hidden = torch.cat([hidden, sample], 2)

            if not isinstance(module, layers.ConvLSTM):
                hidden = module(hidden)
                continue

            # Build the ConvLSTM's extra input from the connected context entry.
            src = self.det_init_connections[idx]
            init_in = ctx[src][:, :self.n_ctx].contiguous()
            init_net = self.det_init_nets['layer_{}'.format(src)]
            height, width = init_in.shape[-2], init_in.shape[-1]
            init_in = init_in.view(batch_size, -1, height, width)
            init_out = init_net(init_in.unsqueeze(1)).squeeze(1)
            hidden = module(hidden, torch.chunk(init_out, 2, 1))

        return torch.sigmoid(hidden)

    def forward(self, frames, config, use_prior, use_mean=False):

        stored_vars = []
        n_steps = config['n_steps']
        n_ctx = config['n_ctx']
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



models/s2s_big_hier_128.py [455:506]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
                        z2 = z2.view(b, t, c, z2.shape[-2], z2.shape[-1])
                        out = torch.cat([out, z1, z2], 2)
                        out = layer(out)

                else:
                    out = layer(out)

            # Compute distribution stats
            mean, var = torch.chunk(out, 2, 2)

            # Softplus var
            logvar = F.softplus(var).log()

            # Generate sample from this distribution
            z0 = flows.gaussian_rsample(mean, logvar, use_mean=use_mean)

            dists.append([mean, logvar, z0, z0, None])

        return dists

    def render(self, x, zs, ctx):
        """Run ``self.render_net`` over ``x`` and squash the result with a sigmoid.

        For every layer, in order:

        * if the layer index is a key of ``self.rend_sto_branches``, the
          tensor ``zs[self.rend_sto_branches[idx]]`` is concatenated onto the
          running activation along dim 2;
        * a ``layers.ConvLSTM`` layer is called with the activation plus a
          pair of tensors derived from ``ctx`` (see below);
        * any other layer is called with the activation alone.

        ConvLSTM extra input: with ``src = self.det_init_connections[idx]``,
        the first ``self.n_ctx`` entries along dim 1 of ``ctx[src]`` are
        reshaped to ``(batch, -1, H, W)``, given a singleton dim 1, passed
        through ``self.det_init_nets['layer_<src>']``, squeezed back and split
        into two chunks along dim 1. (Presumably these are the layer's initial
        hidden/cell states -- confirm against ``layers.ConvLSTM``.)

        Args:
            x: 5-D tensor; the original unpacking names suggest
                ``(batch, time, channels, height, width)`` -- TODO confirm.
                Only its leading size is read, as the batch size for the
                reshape above; any other rank raises ``ValueError``.
            zs: indexable collection of tensors, indexed by the values of
                ``self.rend_sto_branches``.
            ctx: indexable collection of tensors, indexed by the values of
                ``self.det_init_connections``.

        Returns:
            ``torch.sigmoid`` applied to the last layer's output (applied to
            ``x`` itself when ``self.render_net`` is empty).
        """
        # Unpacking into five names doubles as a rank-5 check on ``x``.
        batch_size, _, _, _, _ = x.shape

        hidden = x
        for idx, module in enumerate(self.render_net):
            # Inject the stochastic sample routed to this layer, if any.
            if idx in self.rend_sto_branches:
                sample = zs[self.rend_sto_branches[idx]]
                hidden = torch.cat([hidden, sample], 2)

            if not isinstance(module, layers.ConvLSTM):
                hidden = module(hidden)
                continue

            # Build the ConvLSTM's extra input from the connected context entry.
            src = self.det_init_connections[idx]
            init_in = ctx[src][:, :self.n_ctx].contiguous()
            init_net = self.det_init_nets['layer_{}'.format(src)]
            height, width = init_in.shape[-2], init_in.shape[-1]
            init_in = init_in.view(batch_size, -1, height, width)
            init_out = init_net(init_in.unsqueeze(1)).squeeze(1)
            hidden = module(hidden, torch.chunk(init_out, 2, 1))

        return torch.sigmoid(hidden)

    def forward(self, frames, config, use_prior, use_mean=False):

        stored_vars = []
        n_steps = config['n_steps']
        n_ctx = config['n_ctx']
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



