def __init__()

in codes/models.py [0:0]


    def __init__(self, hparams, logbook=None):
        """Assemble the BERT model, LSTM, and decoder/discriminator MLPs.

        Args:
            hparams: Configuration object. This constructor reads
                ``bert_model``, ``downsample``, ``learn_down``, ``fix_down``,
                ``down_dim``, ``bidirectional``, ``decoder_hidden``,
                ``dropout`` and ``train_mode`` from it, plus ``corrupt_type``
                when ``train_mode`` is ``"nce"``.
            logbook: Forwarded unchanged to the parent constructor.

        Raises:
            NotImplementedError: If ``hparams.bert_model`` does not contain
                ``"base-uncased"``.
        """
        super(TransitionPredictorMaxPool, self).__init__(hparams, logbook)
        self.hparams = hparams
        self.copy_hparams = copy.deepcopy(self.hparams)

        # Placeholders and flags; nothing in this constructor fills them in.
        self.train_data = None
        self.test_data = None
        self.down_model = None
        self.bert_input = True
        self.is_transition_fn = True

        # Only "base-uncased" model names are accepted; the feature size is
        # then taken to be 768.
        if "base-uncased" not in hparams.bert_model:
            raise NotImplementedError()
        feat_dim = 768

        with torch.no_grad():
            self.bert = self.init_bert_model()

        if hparams.downsample:
            if hparams.learn_down:
                # (feat_dim x down_dim) matrix, Xavier-normal initialised and
                # excluded from gradient updates when `fix_down` is set.
                down_proj = nn.Parameter(torch.randn(feat_dim, hparams.down_dim))
                nn.init.xavier_normal_(down_proj)
                if hparams.fix_down:
                    down_proj.requires_grad = False
                self.down = down_proj
            # Everything built below works at the reduced size.
            feat_dim = hparams.down_dim

        self.lstm = nn.LSTM(
            input_size=feat_dim,
            hidden_size=feat_dim,
            num_layers=2,
            bidirectional=hparams.bidirectional,
            batch_first=True,
            dropout=0,
        )

        # W has twice as many rows when the LSTM is bidirectional.
        w_rows = feat_dim * 2 if hparams.bidirectional else feat_dim
        self.W = nn.Parameter(torch.randn(w_rows, feat_dim))

        hidden = hparams.decoder_hidden
        self.decoder = nn.Sequential(
            nn.Linear(feat_dim * 4, hidden),
            nn.ReLU(),
            nn.Dropout(hparams.dropout),
            nn.Linear(hidden, 1),
        )
        self.context_discriminator = nn.Sequential(
            nn.Linear(feat_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        )

        # W is re-initialised only after the two MLPs have been constructed;
        # the position is kept so random-number consumption order is stable.
        nn.init.xavier_uniform_(self.W)

        # NOTE(review): the return value is discarded here (self.bert was
        # already assigned above). Presumably relied on for side effects, if
        # any -- TODO confirm whether this call is needed.
        self.init_bert_model()

        self.sigmoid = nn.Sigmoid()

        # "nce" training pairs BCE loss with an NCE collate function (the
        # context variant for "all_context" corruption); any other mode uses
        # MSE loss with the plain collate function.
        nce_mode = self.hparams.train_mode == "nce"
        self.loss_fn = nn.BCELoss() if nce_mode else nn.MSELoss()
        if not nce_mode:
            self.collate_fn = id_collate_fn
        elif self.hparams.corrupt_type == "all_context":
            self.collate_fn = context_collate_nce_fn
        else:
            self.collate_fn = id_collate_nce_fn

        self.minibatch_step = 0