in codes/models.py [0:0]
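# Imports assumed from the top of models.py; the collate helpers referenced
# below (id_collate_fn, id_collate_nce_fn, context_collate_nce_fn) are
# likewise assumed to be module-level functions there.
import copy

import torch
import torch.nn as nn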
def __init__(self, hparams, logbook=None):
    super(TransitionPredictorMaxPool, self).__init__(hparams, logbook)
    self.hparams = hparams
    self.copy_hparams = copy.deepcopy(self.hparams)
    self.train_data = None
    self.test_data = None
    self.down_model = None
    self.bert_input = True
    self.is_transition_fn = True
    # BERT base models emit 768-dim hidden states; other encoders are
    # not supported yet.
    if "base-uncased" in hparams.bert_model:
        dim = 768
    else:
        raise NotImplementedError()
    # Build the BERT encoder without tracking gradients during construction.
    with torch.no_grad():
        self.bert = self.init_bert_model()
    if hparams.downsample:
        if hparams.learn_down:
            # Learnable projection from BERT's hidden size to down_dim,
            # optionally frozen at its random initialization.
            self.down = nn.Parameter(torch.randn(dim, hparams.down_dim))
            nn.init.xavier_normal_(self.down)
            if hparams.fix_down:
                self.down.requires_grad = False
        dim = hparams.down_dim
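    # The forward pass is not shown in this excerpt, but given the shapes, the
    # projection is presumably applied as `states @ self.down`, mapping
    # (batch, seq, 768) to (batch, seq, down_dim) before the LSTM below.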
    # Two-layer LSTM over the (optionally downsampled) token embeddings.
    self.lstm = nn.LSTM(
        dim,
        dim,
        num_layers=2,
        bidirectional=hparams.bidirectional,
        batch_first=True,
        dropout=0,
    )
    # When the LSTM is bidirectional its output width is 2 * dim, so W
    # projects it back to dim; otherwise W is a square dim-to-dim map.
    if hparams.bidirectional:
        self.W = nn.Parameter(torch.randn(dim * 2, dim))
    else:
        self.W = nn.Parameter(torch.randn(dim, dim))
    # self.mpca = nn.Parameter(torch.randn(768, dim))
    # MLP head that scores a 4 * dim feature vector.
    self.decoder = nn.Sequential(
        nn.Linear(dim * 4, hparams.decoder_hidden),
        nn.ReLU(),
        nn.Dropout(hparams.dropout),
        nn.Linear(hparams.decoder_hidden, 1),
    )
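    # The 4 * dim input width matches the common pair-feature pattern
    # [u; v; |u - v|; u * v] for two dim-sized representations, though the
    # forward method is not shown here, so this is an inference.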
    # Auxiliary head scoring a single dim-sized context representation.
    self.context_discriminator = nn.Sequential(
        nn.Linear(dim, hparams.decoder_hidden),
        nn.ReLU(),
        nn.Linear(hparams.decoder_hidden, 1),
    )
    nn.init.xavier_uniform_(self.W)
    self.sigmoid = nn.Sigmoid()
    if self.hparams.train_mode == "nce":
        # Noise-contrastive training: binary real-vs-corrupted labels.
        self.loss_fn = nn.BCELoss()
        self.collate_fn = id_collate_nce_fn
        if self.hparams.corrupt_type == "all_context":
            self.collate_fn = context_collate_nce_fn
    else:
        # Otherwise regress real-valued transition scores.
        self.loss_fn = nn.MSELoss()
        self.collate_fn = id_collate_fn
    self.minibatch_step = 0
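
# Usage sketch (illustrative, not part of the original file). The constructor
# only reads the hparams fields listed below, so a SimpleNamespace can stand in
# for the project's real hparams object when experimenting with this class.
if __name__ == "__main__":
    from types import SimpleNamespace

    hparams = SimpleNamespace(
        bert_model="bert-base-uncased",  # selects the 768-dim branch
        downsample=True,
        learn_down=True,
        fix_down=False,
        down_dim=256,
        bidirectional=True,
        decoder_hidden=512,
        dropout=0.1,
        train_mode="nce",            # selects BCELoss
        corrupt_type="all_context",  # selects context_collate_nce_fn
    )
    model = TransitionPredictorMaxPool(hparams)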