in NMT/src/model/seq2seq.py
def __init__(self, params, encoder):
    """
    Decoder initialization.
    """
    super(Decoder, self).__init__()

    # model parameters
    self.n_langs = params.n_langs
    self.n_words = params.n_words
    self.share_lang_emb = params.share_lang_emb
    self.share_encdec_emb = params.share_encdec_emb
    self.share_decpro_emb = params.share_decpro_emb
    self.share_output_emb = params.share_output_emb
    self.share_lstm_proj = params.share_lstm_proj
    self.share_dec = params.share_dec
    self.emb_dim = params.emb_dim
    self.hidden_dim = params.hidden_dim
    self.lstm_proj = params.lstm_proj
    self.dropout = params.dropout
    self.n_dec_layers = params.n_dec_layers
    self.enc_dim = params.enc_dim
    self.init_encoded = params.init_encoded
    self.freeze_dec_emb = params.freeze_dec_emb
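    # sanity checks: shared language embeddings require identical vocabulary sizes;
    # tying input embeddings with the output projection (share_decpro_emb) requires
    # either the post-LSTM projection or emb_dim == hidden_dim; the number of shared
    # decoder layers cannot exceed n_dec_layers; initializing the decoder state from
    # the encoder output (init_encoded) requires enc_dim == hidden_dim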
    assert not self.share_lang_emb or len(set(params.n_words)) == 1
    assert not self.share_decpro_emb or self.lstm_proj or self.emb_dim == self.hidden_dim
    assert 0 <= self.share_dec <= self.n_dec_layers
    assert self.enc_dim == self.hidden_dim or not self.init_encoded
    # indexes
    self.eos_index = params.eos_index
    self.pad_index = params.pad_index
    self.bos_index = params.bos_index

    # words allowed for generation
    self.vocab_mask_neg = params.vocab_mask_neg if len(params.vocab) > 0 else None

    # embedding layers
    if self.share_encdec_emb:
        logger.info("Sharing encoder and decoder input embeddings")
        embeddings = encoder.embeddings
    else:
        if self.share_lang_emb:
            logger.info("Sharing decoder input embeddings")
            layer_0 = nn.Embedding(self.n_words[0], self.emb_dim, padding_idx=self.pad_index)
            nn.init.normal_(layer_0.weight, 0, 0.1)
            nn.init.constant_(layer_0.weight[self.pad_index], 0)
            embeddings = [layer_0 for _ in range(self.n_langs)]
        else:
            embeddings = []
            for n_words in self.n_words:
                layer_i = nn.Embedding(n_words, self.emb_dim, padding_idx=self.pad_index)
                nn.init.normal_(layer_i.weight, 0, 0.1)
                nn.init.constant_(layer_i.weight[self.pad_index], 0)
                embeddings.append(layer_i)
        embeddings = nn.ModuleList(embeddings)
    self.embeddings = embeddings

    # LSTM layers / shared layers
    input_dim = self.emb_dim + (0 if self.init_encoded else self.enc_dim)
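    # the encoder output is concatenated to the word embedding at every decoding step,
    # except with init_encoded, where it is only used to initialize the decoder state
    # (hence the enc_dim == hidden_dim assert above)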
    lstm = [
        nn.LSTM(input_dim, self.hidden_dim, num_layers=self.n_dec_layers, dropout=self.dropout)
        for _ in range(self.n_langs)
    ]
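    # share the parameters of the first `share_dec` LSTM layers across languages by
    # aliasing each per-language stack's layer-k parameters (names in LSTM_PARAMS)
    # to those of lstm[0]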
    for k in range(self.n_dec_layers):
        if k + 1 <= self.share_dec:
            logger.info("Sharing decoder LSTM parameters for layer %i" % k)
            for i in range(1, self.n_langs):
                for name in LSTM_PARAMS:
                    setattr(lstm[i], name % k, getattr(lstm[0], name % k))
    self.lstm = nn.ModuleList(lstm)

    # projection layers between LSTM and output embeddings
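    # the optional projection maps hidden_dim down to emb_dim, which makes it possible
    # to tie the output projection with the (emb_dim-sized) input embeddings even when
    # hidden_dim != emb_dim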
    if self.lstm_proj:
        lstm_proj_layers = [nn.Linear(self.hidden_dim, self.emb_dim) for _ in range(self.n_langs)]
        if self.share_lstm_proj:
            logger.info("Sharing decoder post-LSTM projection layers")
            for i in range(1, self.n_langs):
                lstm_proj_layers[i].weight = lstm_proj_layers[0].weight
                lstm_proj_layers[i].bias = lstm_proj_layers[0].bias
        self.lstm_proj_layers = nn.ModuleList(lstm_proj_layers)
        proj_output_dim = self.emb_dim
    else:
        self.lstm_proj_layers = [None for _ in range(self.n_langs)]
        proj_output_dim = self.hidden_dim

    # projection layers
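    # one output projection per language, mapping decoder states to vocabulary logits;
    # with share_decpro_emb, each projection weight is tied to the corresponding input
    # embedding matrix (both have shape (n_words, emb_dim), so no transpose is needed)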
    proj = [nn.Linear(proj_output_dim, n_words) for n_words in self.n_words]
    if self.share_decpro_emb:
        logger.info("Sharing input embeddings and projection matrix in the decoder")
        for i in range(self.n_langs):
            proj[i].weight = self.embeddings[i].weight
        if self.share_lang_emb:
            assert self.share_output_emb
            logger.info("Sharing decoder projection matrices")
            for i in range(1, self.n_langs):
                proj[i].bias = proj[0].bias
    elif self.share_output_emb:
        assert self.share_lang_emb
        logger.info("Sharing decoder projection matrices")
        for i in range(1, self.n_langs):
            proj[i].weight = proj[0].weight
            proj[i].bias = proj[0].bias
    self.proj = nn.ModuleList(proj)
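A minimal sketch (not from the file above) of the parameter-tying pattern this constructor relies on: assigning one nn.Parameter to several modules stores a single tensor that is updated through every module referencing it.

import torch
import torch.nn as nn

emb = nn.Embedding(10, 4, padding_idx=0)       # weight shape: (10, 4)
proj = nn.Linear(4, 10)                        # weight shape: (10, 4)
proj.weight = emb.weight                       # tie output projection to input embeddings
assert proj.weight.data_ptr() == emb.weight.data_ptr()

logits = proj(emb(torch.tensor([[1, 2, 3]])))  # (batch=1, seq=3, vocab=10)
print(logits.shape)                            # torch.Size([1, 3, 10])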