in models/vrnn_hier.py [0:0]
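# NOTE: this excerpt assumes the file's usual imports, roughly:
#   from itertools import chain
#   import torch.nn as nn
# plus the repo-local ResnetBlock, ConvLSTM, DcConv, DcUpConv,
# TemporalConv2d and TemporalNorm2d building blocks.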
def __init__(self, img_ch, n_ctx,
             n_hid=64,
             n_z=10,
             enc_dim=512,
             share_prior_enc=False,
             reverse_post=False,
             ):
    super().__init__()
    self.n_ctx = n_ctx
    self.enc_dim = enc_dim
    # Get the VRNN architecture specification
    self.arch = self.vrnn_arch(n_hid, n_z, enc_dim, img_ch)
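    # Judging from how it is indexed below, self.arch is a dict with
    # 'frame_emb', 'renderer' and 'latent' entries, each holding
    # per-level lists of layer hyper-parameters.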

    ### Define frame embedding network
    emb_net = []
    arch = self.arch['frame_emb']
    for i, _ in enumerate(arch['in_ch']):
        inc = arch['in_ch'][i]
        outc = arch['out_ch'][i]
        pksize = arch['pool_ksize'][i]
        pstride = arch['pool_stride'][i]
        first_conv = arch['first_conv'][i]
        block = []
        if pksize is not None:
            block += [nn.MaxPool2d(pksize, pstride)]
        if first_conv:
            # Levels flagged `first_conv` change channels with a plain
            # 1x1 conv instead of a ResnetBlock
            block += [nn.Conv2d(inc, outc, 1)]
        else:
            block += [ResnetBlock(inc, outc)]
        block += [ResnetBlock(outc, outc)]
        emb_net += [nn.Sequential(*block)]
    self.emb_net = nn.ModuleList(emb_net)
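    # emb_net forms a per-level feature pyramid over the input frames;
    # the renderer below reads its output channels in reverse order.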

    ### Define rendering network
    render_nets = []
    init_nets = []
    arch = self.arch['renderer']
    for i, _ in enumerate(arch['in_ch']):
        inc = arch['in_ch'][i]
        hidc = arch['hid_ch'][i]
        outc = arch['out_ch'][i]
        ksize = arch['ksize'][i]
        padding = arch['padding'][i]
        stride = arch['stride'][i]
        upsample = arch['upsample'][i]
        # Mirror the embedding pyramid: renderer level i matches
        # frame-embedding level -(i + 1)
        init_inc = self.arch['frame_emb']['out_ch'][-(i + 1)]
        init_outc = hidc
        latent_idx = arch['latent_idx'][i]
        # Widen the ConvLSTM input so it also receives the latent
        # sampled at the hierarchy level tied to this renderer stage
        if latent_idx is not None:
            latent_ch = self.arch['latent']['out_ch'][latent_idx]
            inc += latent_ch
        render_net = [ConvLSTM(inc, hidc, norm=True)]
        if upsample:
            render_net += [DcUpConv(hidc, outc, ksize, stride, padding)]
        else:
            render_net += [DcConv(hidc, outc, 3, 1, 1)]
        # Last layer of the renderer maps hidden features back to
        # image channels
        if i == (len(arch['in_ch']) - 1):
            render_net += [TemporalConv2d(n_hid, img_ch, 3, 1, 1)]
        # The init net consumes the stacked embeddings of the n_ctx
        # context frames
        init_net = [
            DcConv(init_inc*self.n_ctx, init_inc*self.n_ctx, 1),
            TemporalConv2d(init_inc*self.n_ctx, init_outc*2, 1),
            TemporalNorm2d(1, init_outc*2),
        ]
        render_nets.append(nn.ModuleList(render_net))
        init_nets.append(nn.Sequential(*init_net))
    self.render_nets = nn.ModuleList(render_nets)
    self.init_nets = nn.ModuleList(init_nets)
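    # Each init net emits 2 * init_outc channels; presumably downstream
    # code chunks this into the ConvLSTM's initial hidden and cell
    # states, roughly: h0, c0 = init_out.chunk(2, dim=1)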

    ### Define latent net
    prior_init_nets = []
    posterior_init_nets = []
    prior_nets = []
    posterior_nets = []
    arch = self.arch['latent']
    for i, _ in enumerate(arch['in_ch']):
        inc = arch['in_ch'][i]
        hidc = arch['hid_ch'][i]
        outc = arch['out_ch'][i]
        # Total channels of the latents from all higher levels; the
        # posterior at this level is conditioned on them
        prevc = sum(arch['out_ch'][:i])
        # Both heads emit 2*outc channels (presumably a mean and a
        # log-variance per latent channel)
        prior_net = [
            TemporalConv2d(inc, hidc, 1),
            TemporalNorm2d(1, hidc),
            ConvLSTM(hidc, hidc, norm=True),
            TemporalConv2d(hidc, outc*2, 1),
            TemporalNorm2d(1, outc*2),
        ]
        # The posterior additionally concatenates the higher-level
        # latents (prevc channels) before its output head
        posterior_net = [
            TemporalConv2d(inc, hidc, 1),
            TemporalNorm2d(1, hidc),
            ConvLSTM(hidc, hidc, norm=True),
            TemporalConv2d(hidc + prevc, outc*2, 1),
            TemporalNorm2d(1, outc*2),
        ]
        prior_init_net = [
            DcConv(inc*self.n_ctx, inc*self.n_ctx, 1),
            TemporalConv2d(inc*self.n_ctx, hidc*2, 1),
            TemporalNorm2d(1, 2*hidc),
        ]
        posterior_init_net = [
            DcConv(inc*self.n_ctx, inc*self.n_ctx, 1),
            TemporalConv2d(inc*self.n_ctx, hidc*2, 1),
            TemporalNorm2d(1, 2*hidc),
        ]
        # Wrap and collect the per-level nets
        prior_nets.append(nn.ModuleList(prior_net))
        posterior_nets.append(nn.ModuleList(posterior_net))
        prior_init_nets.append(nn.Sequential(*prior_init_net))
        posterior_init_nets.append(nn.Sequential(*posterior_init_net))
    self.prior_nets = nn.ModuleList(prior_nets)
    self.posterior_nets = nn.ModuleList(posterior_nets)
    self.prior_init_nets = nn.ModuleList(prior_init_nets)
    self.posterior_init_nets = nn.ModuleList(posterior_init_nets)
    # Zero-init the affine params of each head's final norm layer so the
    # prior and posterior start out emitting near-zero statistics
    for block in chain(self.prior_nets, self.posterior_nets):
        nn.init.constant_(block[-1].model.weight, 0)
        nn.init.normal_(block[-1].model.bias, std=1e-3)
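    # Not shown in this excerpt: downstream code would typically split
    # the 2*outc head output into distribution parameters and
    # reparameterize, roughly:
    #   mean, logvar = stats.chunk(2, dim=1)
    #   z = mean + (0.5 * logvar).exp() * torch.randn_like(mean)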