in conv_lstm_models.py [0:0]
def __init__(self, args,
             bypass_encoder=False,  # if True, also feed pooled raw features (via sum_pool_embed) to the RNN
             enc_convsize=3,   # NOTE(review): not referenced in this method body — possibly consumed by the base class; verify
             enc_embsize=256,  # channel width of the conv encoder output
             enc_depth=3,      # NOTE(review): not referenced here either (encoder below hardcodes depth=2) — verify
             inp_embsize=256,  # channel width after the initial 1x1 conv
             top_pooling='mean',
             with_z=False,     # enable the latent-z (zfwd/zbwd) machinery
             z_opt=th.optim.SGD,
             z_lr=0.01,
             z_lambda=100,
             z_pred_cut_gradient=False, # Whether to push gradients from the Z loss
             z_after_lstm=False, # Whether to do the z model after the LSTM
             zbwd_init_zfwd=True, # Whether to init zbwd with zfwd (zpred)
             zbwd_to_convergence=True, # Whether to optimize zbwd to convergence of the final loss for each game
             zbwd_single=False, # Whether to have only zbwd(game) instead of zbwd(time,game)
             zfwd_zbwd_ratio=0, # Ratio of how much of zfwd / (zfwd + zbwd) to put as input to the decoder:
                                # 0 means only zbwd, 1 means only zfwd, 0.5 means half each.
             **kwargs):
    """Construct the `simple` conv-LSTM model.

    Builds, in order: the RNN input size (which must be known *before* the
    base-class constructor runs), the base model via ``super().__init__``,
    a 1x1 input projection, a strided conv encoder, the optional latent-z
    components, the deconv decoder, and the optional encoder-bypass pooling
    layers.

    Args:
        args: experiment/config namespace; only ``args.predict_delta`` is
            read directly here, the rest is forwarded to the base class.
        kwargs: forwarded to the base class (``model_name`` defaults to
            ``'simple'``).

    Note: attributes such as ``self.residual``, ``self.convmod``,
    ``self.nonlin``, ``self.hid_dim``, ``self.dec_convsize``,
    ``self.dec_embsize``, ``self.dec_depth``, ``self.nchannel`` and
    ``self.nfeat`` are expected to be set by the base-class constructor —
    not visible in this file chunk.
    """
    kwargs.setdefault('model_name', 'simple')
    self.z_opt = z_opt
    self.enc_embsize = enc_embsize
    self.with_z = with_z
    self.z_after_lstm = z_after_lstm
    # When the encoder is bypassed, the pooled raw-feature embedding is
    # concatenated to the conv embedding, doubling the RNN input width.
    rnn_input_size = (enc_embsize * 2 if bypass_encoder else enc_embsize)
    if self.with_z: # TODO replace by class decorator
        self.zsize = 64  # latent z dimensionality (also read at decoder build time when z_after_lstm)
        if not self.z_after_lstm:
            # z is fed *into* the LSTM, so widen its input accordingly.
            rnn_input_size += self.zsize
    logging.info("rnn input size: {}".format(rnn_input_size))
    # Base class consumes rnn_input_size — it must be fully computed above.
    super(simple, self).__init__(
        args, lstm_nlayers=1,
        rnn_input_size=rnn_input_size, **kwargs
    )
    # Odd kernel size keeps the decoder convolution spatially centered.
    assert (self.dec_convsize % 2) == 1, \
        "ERROR: the size of the decoder convolution is not odd"
    self.bypass_encoder = bypass_encoder
    self.append_to_decoder_input = []
    self.predict_delta = args.predict_delta
    self.top_pooling = top_pooling
    self.inp_embsize = inp_embsize
    # Overrides
    # 1x1 projection from raw input channels to the working embedding size.
    self.conv1x1 = nn.Conv2d(self.nchannel, self.inp_embsize, 1) # TODO do that before trunk?
    if self.residual:
        # Residual connections require matching channel counts end-to-end.
        assert self.inp_embsize == self.enc_embsize, "can't residual from {} to {}".format(self.inp_embsize, self.enc_embsize)
    # Strided conv encoder: one stride-1 3x3 layer, then stride-2 5x5 layers.
    # NOTE(review): ignores the enc_convsize/enc_depth ctor params — confirm intended.
    self.encoder = convnet(
        convsize_0 =3,
        convsize =5,
        padding_0 =1,
        padding =2,
        conv =self.convmod,
        non_lin =self.nonlin,
        input_size =self.inp_embsize,
        interm_size=self.inp_embsize,
        output_size=self.enc_embsize,
        depth =2,
        stride_0 =1,
        stride =2
    )
    self.z_pred_cut_gradient = z_pred_cut_gradient
    if self.with_z: # TODO replace by class decorator
        self.game_name = None  # set externally per-game; keys the self.zs cache
        zlinear = None
        if self.z_after_lstm:
            # zfwd predicted from the LSTM hidden state.
            zlinear = nn.Linear(self.hid_dim, self.zsize)
        else:
            # zfwd predicted from the (possibly doubled) encoder embedding.
            zlinear = nn.Linear(self.enc_embsize + (0 if not self.bypass_encoder else self.enc_embsize), self.zsize)
        self.zpred = zlinear
        # zbwd is a free parameter optimized directly (requires_grad set below).
        # NOTE(review): hard-codes CUDA tensors — this model cannot run on CPU as written.
        self.zbwd = Variable(th.zeros(1,1,self.zsize).type(th.cuda.FloatTensor))
        self.zbwd.requires_grad = True
        self.zs = {} # TODO replace by LookUpTable
        # size_average=True is the deprecated spelling of reduction='mean'.
        self.zlossfn = nn.MSELoss(size_average=True)
        # ^ could also change loss, and make sure the z_lr is small enough!
        self.z_lr = z_lr
        self.z_lambda = z_lambda
        self.zbwd_init_zfwd = zbwd_init_zfwd
        self.zbwd_to_convergence = zbwd_to_convergence
        self.zbwd_single = zbwd_single
        if self.zbwd_single:
            # A single per-game zbwd only makes sense when seeded from zfwd
            # and optimized to convergence.
            assert self.zbwd_init_zfwd
            assert self.zbwd_to_convergence
        self.zfwd_zbwd_ratio = zfwd_zbwd_ratio
        if self.zfwd_zbwd_ratio > 0:
            # Mixing in zfwd requires zbwd to be initialized from it.
            assert self.zbwd_init_zfwd
    # TODO decoder that starts from input embedding (after first 1x1 Conv2d)
    # TODO check input/output size in features/channels
    # TODO try to remove border artifacts (borders are important!)
    # TODO hierarchical deconv
    # Decoder input = raw input channels + LSTM hidden state (+ z when it is
    # appended after the LSTM). NOTE(review): reads self.zsize when
    # z_after_lstm is True — assumes z_after_lstm implies with_z; confirm.
    self.decoder = decoder(self.dec_convsize, self.dec_convsize)(
        conv =self.convmod,
        non_lin =self.nonlin,
        input_size =self.nchannel + self.hid_dim + (self.zsize if self.z_after_lstm else 0),
        interm_size =self.dec_embsize,
        output_size =self.dec_embsize,
        depth =self.dec_depth,
    )
    # Modules
    if self.bypass_encoder:
        # Embeds the pooled raw per-unit features to match the encoder width.
        self.sum_pool_embed = nn.Linear(self.nfeat, self.enc_embsize)
        if self.top_pooling == 'all':
            self.weight_poolings = nn.Linear(self.enc_embsize, self.enc_embsize * 2)