# in NMT/src/model/lm.py
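# Context assumed from the rest of the file (this excerpt starts mid-file):
# the imports, the module logger, and the parameter-name templates used for
# weight tying. LSTM_PARAMS follows PyTorch's nn.LSTM attribute naming
# ('weight_ih_l0', 'bias_hh_l1', ...); the exact list is an assumption
# reconstructed from how the names are used in __init__ below.
from logging import getLogger

from torch import nn

logger = getLogger()

LSTM_PARAMS = ['weight_ih_l%i', 'weight_hh_l%i', 'bias_ih_l%i', 'bias_hh_l%i']


class SubLM(nn.Module):
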
    def __init__(self, params, model, is_encoder, reverse):
        """
        Language model initialization.

        Depending on `params`, the input embeddings, the first LSTM layers,
        and the output projection can be tied to those of `model` (the
        encoder or the decoder).
        """
        super(SubLM, self).__init__()
        assert type(is_encoder) is bool and type(reverse) is bool
        assert reverse is False or (is_encoder and params.attention)
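        # `reverse` is only valid for an encoder-side LM in an attention
        # model: there the encoder LSTM is presumably bidirectional, so it
        # exposes the '*_reverse' parameters tied below.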

        # model parameters
        self.n_langs = params.n_langs
        self.n_words = params.n_words
        self.emb_dim = params.emb_dim
        self.hidden_dim = params.hidden_dim
        self.dropout = params.dropout
        self.pad_index = params.pad_index
        self.is_enc_lm = is_encoder
        s_name = "encoder" if is_encoder else "decoder"
        assert 0 <= params.lm_share_enc <= params.n_enc_layers
        assert 0 <= params.lm_share_dec <= params.n_dec_layers
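        # lm_share_enc / lm_share_dec give the number of bottom LSTM layers
        # the LM shares with the encoder / decoder; neither can exceed the
        # number of layers that module actually has.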

        # embedding layers
        if params.lm_share_emb:
            embeddings = model.embeddings
            logger.info("Sharing language model input embeddings with the %s" % s_name)
        else:
            embeddings = []
            for n_words in self.n_words:
                layer_i = nn.Embedding(n_words, self.emb_dim, padding_idx=self.pad_index)
                nn.init.normal_(layer_i.weight, 0, 0.1)
                nn.init.constant_(layer_i.weight[self.pad_index], 0)
                embeddings.append(layer_i)
            embeddings = nn.ModuleList(embeddings)
        self.embeddings = embeddings
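        # Note: when not shared, one fresh embedding table is built per
        # language, initialized N(0, 0.1**2) with the padding row zeroed.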

        # LSTM layers / shared layers
        n_rec_share = params.lm_share_enc if is_encoder else params.lm_share_dec
        lstm = [
            nn.LSTM(self.emb_dim, self.hidden_dim, num_layers=max(n_rec_share, 1), dropout=self.dropout)
            for _ in range(self.n_langs)
        ]
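        # One LSTM per language. num_layers=max(n_rec_share, 1) guarantees at
        # least one layer even when nothing is shared; PyTorch applies
        # `dropout` only between stacked layers, so it is a no-op when
        # num_layers == 1.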
        for k in range(n_rec_share):
            logger.info("Sharing language model LSTM parameters for layer %i with the %s" % (k, s_name))
            for i in range(self.n_langs):
                for name in LSTM_PARAMS:
                    if is_encoder or not params.attention:
                        _name = name if reverse is False else ('%s_reverse' % name)
                        setattr(lstm[i], name % k, getattr(model.lstm[i], _name % k))
                    elif k == 0:
                        setattr(lstm[i], name % k, getattr(model.lstm1[i], name % k))
                    else:
                        setattr(lstm[i], name % k, getattr(model.lstm2[i], name % (k - 1)))
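        # The setattr calls tie the LM's layer-k parameters directly to the
        # corresponding encoder/decoder parameters. An attention decoder is
        # assumed to split its stack into `lstm1` (first layer, before the
        # attention module) and `lstm2` (remaining layers), hence the k == 0
        # special case and the (k - 1) index shift.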
        self.lstm = nn.ModuleList(lstm)

        # projection layers
        if params.lm_share_proj and not is_encoder:
            logger.info("Sharing language model projection layer with the decoder")
            proj = model.proj
        else:
            proj = nn.ModuleList([nn.Linear(self.hidden_dim, n_words) for n_words in self.n_words])
        self.proj = proj
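

# Minimal usage sketch (hypothetical; `params`, `encoder` and `decoder` are
# built elsewhere in the training pipeline):
#
#   enc_lm = SubLM(params, encoder, is_encoder=True, reverse=False)
#   rev_lm = SubLM(params, encoder, is_encoder=True, reverse=True)   # attention models only
#   dec_lm = SubLM(params, decoder, is_encoder=False, reverse=False)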