def vrnn_arch()

in models/vrnn_hier.py [0:0]


    def vrnn_arch(self, n_hid, n_z, enc_dim, img_ch):
        """Describe the per-layer configuration of every VRNN component.

        Args:
            n_hid: Base channel width; each hidden width below is an integer
                multiple of it.
            n_z: Output channel count of the first (coarsest) latent.
            enc_dim: Accepted for interface compatibility; not read here.
            img_ch: Channel count of the frames fed to the first encoder stage.

        Returns:
            dict with keys 'frame_emb', 'renderer' and 'latent' (in that
            order), each mapping to a dict of per-layer lists.
        """
        def widths(multipliers):
            # Scale every multiplier by the base width n_hid (fresh list each call).
            return [n_hid * m for m in multipliers]

        # Frame embedding / encoder: six stages. Each stage's input width is
        # the previous stage's output width; the first one takes img_ch.
        enc_out = widths([1, 2, 4, 8, 8, 8])
        frame_emb = {
            'in_ch': [img_ch] + enc_out[:-1],
            'out_ch': enc_out,
            'first_conv': [True] + [False] * 5,
            'pool_ksize': [None] + [2] * 4 + [4],
            'pool_stride': [None] + [2] * 4 + [1],
        }

        # Renderer / likelihood model: six stages; 'latent_idx' marks which
        # stages receive a latent (0, 1, 2) and which do not (None).
        renderer = {
            'in_ch': widths([8, 8, 8, 4, 2, 1]),
            'hid_ch': widths([8, 8, 8, 4, 2, 1]),
            'out_ch': widths([8, 8, 4, 2, 1, 1]),
            'ksize': [4] * 6,
            'stride': [1] + [2] * 5,
            'padding': [0] + [1] * 5,
            'upsample': [True] * 5 + [False],
            'latent_idx': [0, None, 1, None, 2, None],
        }

        # Mirror each latent-carrying renderer position i to (last - i),
        # listed in ascending renderer position. The resulting indices pick
        # frame_emb stages whose out_ch matches latent['in_ch'] below
        # (presumably the encoder context for each latent — confirm at caller).
        last = len(renderer['latent_idx']) - 1
        ctx_idx = [
            last - pos
            for pos, slot in enumerate(renderer['latent_idx'])
            if slot is not None
        ]

        # Prior / posterior networks, one entry per latent.
        latent = {
            'in_ch': widths([8, 8, 2]),
            'hid_ch': widths([8, 8, 2]),
            'out_ch': [n_z] + widths([8, 2]),
            'ctx_idx': ctx_idx,
            'resolution': [1, 8, 32],
        }

        return {
            'frame_emb': frame_emb,
            'renderer': renderer,
            'latent': latent,
        }