in conv_lstm_models.py [0:0]
def __init__(self, args,
             midconv_kw=3,
             midconv_stride=2,
             midconv_depth=2,
             n_lvls=2,
             upsample='bilinear',
             use_mid_rnns_in_encoding=True,
             **kwargs):
    """Build a multi-level ConvLSTM model.

    Args:
        args: forwarded verbatim to the parent constructor.
        midconv_kw: kernel width of the mid-level conv nets (odd values
            keep spatial size via ``(midconv_kw - 1) // 2`` padding).
        midconv_stride: stride of the downsampling conv appended to each
            mid-level net.
        midconv_depth: total conv depth per mid-level net; the final
            downsampling conv accounts for one layer, so the plain
            convnet gets ``midconv_depth - 1`` layers.
        n_lvls: number of mid-level encoder stages; must be >= 1.
        upsample: ``'bilinear'`` or ``'nearest'`` — selects the
            functional upsampling op stored in ``self.upsample``.
        use_mid_rnns_in_encoding: stored flag; presumably consumed by the
            forward pass (not used here) — TODO confirm against forward().
        **kwargs: forwarded to the parent constructor; ``'model_name'``
            defaults to ``'multilvl_lstm'``.
    """
    assert n_lvls > 0, "n_lvls must be at least 1"
    self.n_lvls = n_lvls
    kwargs.setdefault('model_name', 'multilvl_lstm')
    super(multilvl_lstm, self).__init__(args, **kwargs)
    # Same-size padding for the mid-level convs (exact for odd kernels).
    midconv_padding = (midconv_kw - 1) // 2
    # Overrides
    self.encoder = None  # TODO I hope this garbage collects.
    # TODO actually do it with the nn.Module corresponding to midnets+midrnn
    # TODO skip connections from midnets to corresponding decoder level.
    # NOTE(review): decoder() receives dec_convsize twice — looks like
    # (kernel, kernel) or (kernel, padding source); verify the decoder factory.
    self.decoder = decoder(self.dec_convsize, self.dec_convsize)(
        conv       =self.convmod,
        non_lin    =self.nonlin,
        # Decoder sees the raw input channels, the LSTM hidden state, one
        # embedding per pyramid level, and optionally z when it is injected
        # after the LSTM.
        input_size =self.nchannel + self.hid_dim + self.enc_embsize * self.n_lvls + (self.zsize if self.z_after_lstm else 0),
        interm_size=self.dec_embsize,
        output_size=self.dec_embsize,
        depth      =self.dec_depth,
    )  # should be depth=1, the lstm should do the work.
    self.rnn_input_size = self.inp_embsize + (0 if not self.bypass_encoder else self.enc_embsize) + (0 if not self.with_z or self.z_after_lstm else self.zsize)
    self.rnn = nn.LSTM(self.rnn_input_size, self.hid_dim, self.lstm_num_layers, dropout=self.lstm_dropout)
    # Modules
    self.use_mid_rnns_in_encoding = use_mid_rnns_in_encoding
    self.midnets = nn.ModuleList()
    self.midrnn = nn.ModuleList()
    for i in range(n_lvls):
        # Level 0 consumes the input embedding; deeper levels consume the
        # previous level's encoder embedding.
        isize = self.enc_embsize
        osize = self.enc_embsize
        if i == 0:
            isize = self.inp_embsize
        self.midnets.append(nn.Sequential(
            simple_convnet(
                convsize   =midconv_kw,
                padding    =midconv_padding,
                conv       =self.convmod,
                non_lin    =self.nonlin,
                input_size =isize,
                output_size=osize,
                depth      =midconv_depth - 1,
                stride     =1
            ),
            # Downsampling conv between levels. Fix: use the configured
            # kernel/stride/padding instead of hardcoded (3, 2, padding=1);
            # defaults reproduce the old behavior exactly.
            self.convmod(isize, osize, midconv_kw, midconv_stride, padding=midconv_padding)
        ))
        # NOTE(review): input size is `isize` (pre-midnet channels); if the
        # RNN runs on midnet *output* this should be `osize` — for i == 0
        # they differ. Confirm against the forward pass before changing.
        self.midrnn.append(nn.LSTM(isize, osize, 1, dropout=self.lstm_dropout))
    # Functional upsampling op used by the decoder path (legacy F.upsample_*
    # API; modern torch prefers F.interpolate but the file uses this style).
    self.upsample = {
        'bilinear': F.upsample_bilinear,
        'nearest': F.upsample_nearest,
    }[upsample]