in models/s2s_big_hier_v4.py [0:0]
def __init__(self, img_ch, n_ctx,
             n_hid=64,
             n_z=10,
             enc_dim=512,
             share_prior_enc=False,
             reverse_post=False,
             ):
    """Build the embedding, rendering, and prior/posterior sub-networks.

    Args:
        img_ch: Channel count fed to the first embedding conv and produced
            by the last rendering conv.
        n_ctx: Multiplier applied to the input width of every ``*_init``
            head (presumably the number of context frames; confirm
            against the caller).
        n_hid: Base channel width; all layer widths are multiples of it.
        n_z: Not referenced in this part of the constructor.
        enc_dim: Stored on the instance and used to size the norm of
            ``det_init_net``.
        share_prior_enc: Not referenced in this part of the constructor.
        reverse_post: Not referenced in this part of the constructor.
    """
    super().__init__()
    self.n_ctx = n_ctx
    self.enc_dim = enc_dim

    # Channel widths used throughout, as multiples of the base width.
    c1 = n_hid
    c2 = n_hid * 2
    c4 = n_hid * 4
    c8 = n_hid * 8

    def context_head(width, groups, dc_args=(1,)):
        # DcConv -> TemporalConv2d -> TemporalNorm2d head; the input width
        # is `width` scaled by n_ctx and the output width is 2 * width.
        stacked = width * self.n_ctx
        return nn.Sequential(
            layers.DcConv(stacked, stacked, *dc_args),
            layers.TemporalConv2d(stacked, width * 2, 1),
            layers.TemporalNorm2d(groups, width * 2),
        )

    def latent_branch(width, groups, extra_in):
        # conv -> norm -> ConvLSTM -> conv -> norm stack. `extra_in` is the
        # number of additional input channels of the second conv, whose
        # output width is 2 * width.
        return nn.ModuleList([
            layers.TemporalConv2d(width, width, 1),
            layers.TemporalNorm2d(groups, width),
            layers.ConvLSTM(width, width, norm=True),
            layers.TemporalConv2d(width + extra_in, width * 2, 1),
            layers.TemporalNorm2d(groups, width * 2),
        ])

    # Embedding stack: 1x1 conv, then ResnetBlocks interleaved with pooling.
    self.emb_net = nn.ModuleList([
        nn.Conv2d(img_ch, c1, 1),
        ResnetBlock(c1, c1),
        nn.MaxPool2d(2, 2),
        ResnetBlock(c1, c2),
        ResnetBlock(c2, c2),
        nn.MaxPool2d(2, 2),
        ResnetBlock(c2, c4),
        ResnetBlock(c4, c4),
        nn.MaxPool2d(2, 2),
        ResnetBlock(c4, c4),
        ResnetBlock(c4, c8),
        nn.MaxPool2d(2, 2),
        ResnetBlock(c8, c8),
        ResnetBlock(c8, c8),
        nn.MaxPool2d(4, 1),
        ResnetBlock(c8, c8, norm_ch=1),
        ResnetBlock(c8, c8, norm_ch=1),
    ])

    # Rendering stack: ConvLSTMs alternating with up-convolutions, closed
    # by two plain convolutions back to img_ch channels.
    mult = 1
    self.render_net = nn.ModuleList([
        layers.ConvLSTM(c8 + c8, c8),
        layers.DcUpConv(c8, c8, 4, 1, 0),
        layers.ConvLSTM(c8, c8, norm=True),
        layers.DcUpConv(c8 * mult, c8, 4, 2, 1),
        layers.ConvLSTM(c8 + c8, c8, norm=True),
        layers.DcUpConv(c8 * mult, c4, 4, 2, 1),
        layers.ConvLSTM(c4, c4, norm=True),
        layers.DcUpConv(c4 * mult, c2, 4, 2, 1),
        layers.ConvLSTM(c2 + c2, c2, norm=True),
        layers.DcUpConv(c2 * mult, c1, 4, 2, 1),
        layers.ConvLSTM(c1, c1, norm=True),
        layers.DcConv(c1, c1, 3, 1, 1),
        layers.TemporalConv2d(c1, img_ch, 3, 1, 1),
    ])

    # NOTE(review): the norm below is sized with 2*enc_dim while the conv
    # before it emits 2*n_hid*8 channels. The two agree for the defaults
    # (enc_dim=512, n_hid=64); confirm for other settings, assuming the
    # second TemporalNorm2d argument is a channel count.
    det_in = 2 * c8 * self.n_ctx
    self.det_init_net = nn.Sequential(
        layers.DcConv(det_in, det_in, 1),
        layers.TemporalConv2d(det_in, 2 * c8, 1),
        layers.TemporalNorm2d(1, 2 * enc_dim),
    )

    self.prior_init_nets = nn.ModuleDict({
        'layer_16': context_head(c8, 1),
        'layer_10': context_head(c8, 16),
        'layer_4': context_head(c2, 16),
    })
    self.posterior_init_nets = nn.ModuleDict({
        'layer_16': context_head(c8, 1),
        'layer_10': context_head(c8, 16),
        'layer_4': context_head(c2, 16),
    })

    self.posterior_branches = nn.ModuleDict({
        'layer_4': latent_branch(c2, 16, c8 + c8),
        'layer_10': latent_branch(c8, 16, c8),
        'layer_16': latent_branch(c8, 1, 0),
    })
    self.prior_branches = nn.ModuleDict({
        'layer_4': latent_branch(c2, 16, c8 + c8),
        'layer_10': latent_branch(c8, 16, c8),
        'layer_16': latent_branch(c8, 1, 0),
    })

    # Prior/Posterior branches norm init: zero the weight and draw a tiny
    # random bias for the last module of every branch.
    for branch_set in (self.posterior_branches, self.prior_branches):
        for key in ('layer_4', 'layer_10', 'layer_16'):
            final_norm = branch_set[key][-1].model
            nn.init.constant_(final_norm.weight, 0)
            nn.init.normal_(final_norm.bias, std=1e-3)

    # Connection list
    self.det_init_connections = {
        0: 16,
        2: 13,
        4: 10,
        6: 7,
        8: 4,
        10: 1,
    }

    # Connection branches
    self.det_init_nets = nn.ModuleDict({
        'layer_16': context_head(c8, 1),
        'layer_13': context_head(c8, 16, dc_args=(3, 1, 1)),
        'layer_10': context_head(c8, 16),
        'layer_7': context_head(c4, 16),
        'layer_4': context_head(c2, 16),
        'layer_1': context_head(c1, 16),
    })

    # Stochastic connection list
    # encoder -> renderer
    self.sto_branches = {
        16: 0,
        10: 4,
        4: 8,
    }
    # renderer -> encoder
    self.rend_sto_branches = {
        0: 0,
        4: 1,
        8: 2,
    }