def __init__()

in conv_lstm_models.py [0:0]


    def __init__(self, args,
                 bypass_encoder=False,
                 enc_convsize=3,
                 enc_embsize=256,
                 enc_depth=3,
                 inp_embsize=256,
                 top_pooling='mean',
                 with_z=False,
                 z_opt=th.optim.SGD,
                 z_lr=0.01,
                 z_lambda=100,
                 z_pred_cut_gradient=False, # Whether to cut (stop) gradients flowing back from the Z loss into the predictor
                 z_after_lstm=False,  # Whether to do the z model after the LSTM
                 zbwd_init_zfwd=True,  # Whether to init zbwd with zfwd (zpred)
                 zbwd_to_convergence=True,  # Whether to optimize zbwd to convergence of the final loss for each game
                 zbwd_single=False,  # Whether to have only zbwd(game) instead of zbwd(time,game)
                 zfwd_zbwd_ratio=0,  # Ratio of how much of zfwd / (zfwd + zbwd) to put as input to the decoder:
                                     # 0 means only zbwd, 1 means only zfwd, 0.5 means half each.
                 **kwargs):
        """Build the conv-LSTM model: 1x1 input projection, conv encoder,
        LSTM trunk (constructed by the superclass), optional latent-z
        machinery, and a conv decoder.

        Args:
            args: experiment/config namespace; only `args.predict_delta` is
                read here, the rest is forwarded to the superclass.
            bypass_encoder: if True, a pooled input embedding is concatenated
                to the encoder output, doubling the LSTM input width.
            enc_convsize, enc_depth: accepted but not used in this method —
                the encoder below hardcodes its conv sizes/depth.
                NOTE(review): possibly dead parameters; confirm against callers.
            enc_embsize: channel width of the encoder output.
            inp_embsize: channel width after the 1x1 input convolution.
            top_pooling: pooling mode over the top features ('mean' or 'all').
            with_z: enables the latent-z forward/backward machinery below.
            z_opt: optimizer class used for z (stored, not instantiated here).
            z_lr, z_lambda: learning rate and loss weight for the z objective.
            **kwargs: forwarded to the superclass (model_name defaults to
                'simple').
        """
        kwargs.setdefault('model_name', 'simple')

        self.z_opt = z_opt
        self.enc_embsize = enc_embsize
        self.with_z = with_z
        self.z_after_lstm = z_after_lstm
        # LSTM input is the encoder embedding, doubled when the raw (pooled)
        # input embedding is also fed in alongside it.
        rnn_input_size = (enc_embsize * 2 if bypass_encoder else enc_embsize)
        if self.with_z:  # TODO replace by class decorator
            self.zsize = 64
            # When z is injected before the LSTM it widens the LSTM input;
            # when injected after, it widens the decoder input instead (below).
            if not self.z_after_lstm:
                rnn_input_size += self.zsize
            logging.info("rnn input size: {}".format(rnn_input_size))

        # The superclass builds the LSTM trunk and defines attributes used
        # below: dec_convsize, dec_embsize, dec_depth, hid_dim, nchannel,
        # nfeat, convmod, nonlin, residual.
        super(simple, self).__init__(
            args, lstm_nlayers=1,
            rnn_input_size=rnn_input_size, **kwargs
        )
        # Odd kernel size keeps the decoder convolution spatially centered.
        assert (self.dec_convsize % 2) == 1, \
            "ERROR: the size of the decoder convolution is not odd"

        self.bypass_encoder = bypass_encoder
        self.append_to_decoder_input = []
        self.predict_delta = args.predict_delta
        self.top_pooling = top_pooling
        self.inp_embsize = inp_embsize

        # Overrides
        # 1x1 conv projects raw input channels to the working embedding width.
        self.conv1x1 = nn.Conv2d(self.nchannel, self.inp_embsize, 1)  # TODO do that before trunk?
        if self.residual:
            # Residual connections need matching in/out widths.
            assert self.inp_embsize == self.enc_embsize, "can't residual from {} to {}".format(self.inp_embsize, self.enc_embsize)
        # NOTE(review): conv sizes/depth/strides are hardcoded here; the
        # enc_convsize/enc_depth parameters above are ignored — confirm intended.
        self.encoder = convnet(
            convsize_0 =3,
            convsize   =5,
            padding_0  =1,
            padding    =2,
            conv       =self.convmod,
            non_lin    =self.nonlin,
            input_size =self.inp_embsize,
            interm_size=self.inp_embsize,
            output_size=self.enc_embsize,
            depth      =2,
            stride_0   =1,
            stride     =2
        )
        self.z_pred_cut_gradient = z_pred_cut_gradient
        if self.with_z:  # TODO replace by class decorator
            self.game_name = None
            zlinear = None
            # zpred maps either the LSTM hidden state (z after LSTM) or the
            # encoder features (z before LSTM) down to the z dimension.
            if self.z_after_lstm:
                zlinear = nn.Linear(self.hid_dim, self.zsize)
            else:
                # NOTE(review): the condition yields enc_embsize + 0 when
                # bypass_encoder is False and enc_embsize * 2 otherwise,
                # mirroring rnn_input_size above.
                zlinear = nn.Linear(self.enc_embsize + (0 if not self.bypass_encoder else self.enc_embsize), self.zsize)
            self.zpred = zlinear
            # zbwd is a free (optimized) latent, kept outside the module
            # parameters; requires_grad is set manually so z_opt can update it.
            # NOTE(review): Variable and .cuda typing are legacy PyTorch APIs.
            self.zbwd = Variable(th.zeros(1,1,self.zsize).type(th.cuda.FloatTensor))
            self.zbwd.requires_grad = True
            self.zs = {}  # TODO replace by LookUpTable
            # NOTE(review): size_average is deprecated in modern PyTorch
            # (use reduction='mean'); kept as-is for the pinned version.
            self.zlossfn = nn.MSELoss(size_average=True)
            # ^ could also change loss, and make sure the z_lr is small enough!
            self.z_lr = z_lr
            self.z_lambda = z_lambda
            self.zbwd_init_zfwd = zbwd_init_zfwd
            self.zbwd_to_convergence = zbwd_to_convergence
            self.zbwd_single = zbwd_single
            if self.zbwd_single:
                # A single per-game zbwd only makes sense with zfwd init and
                # per-game convergence optimization.
                assert self.zbwd_init_zfwd
                assert self.zbwd_to_convergence
            self.zfwd_zbwd_ratio = zfwd_zbwd_ratio
            if self.zfwd_zbwd_ratio > 0:
                # Mixing in zfwd requires zbwd to be initialized from it.
                assert self.zbwd_init_zfwd

        # TODO decoder that starts from input embedding (after first 1x1 Conv2d)
        # TODO check input/output size in features/channels
        # TODO try to remove border artifacts (borders are important!)
        # TODO hierarchical deconv
        # Decoder consumes raw channels + LSTM hidden state, plus z when z is
        # injected after the LSTM.
        self.decoder = decoder(self.dec_convsize, self.dec_convsize)(
            conv        =self.convmod,
            non_lin     =self.nonlin,
            input_size  =self.nchannel + self.hid_dim + (self.zsize if self.z_after_lstm else 0),
            interm_size =self.dec_embsize,
            output_size =self.dec_embsize,
            depth       =self.dec_depth,
        )

        # Modules
        if self.bypass_encoder:
            # Projects pooled raw features to the encoder embedding width so
            # they can be concatenated with the encoder output.
            self.sum_pool_embed = nn.Linear(self.nfeat, self.enc_embsize)
        if self.top_pooling == 'all':
            self.weight_poolings = nn.Linear(self.enc_embsize, self.enc_embsize * 2)