in NMT/src/model/transformer.py [0:0]
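# NOTE: excerpt of the Transformer decoder's __init__; it relies on names
# defined elsewhere in transformer.py or its imports (math, nn, logger,
# Embedding, PositionalEmbedding, TransformerDecoderLayer).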
def __init__(self, args, encoder):
    super().__init__()
    self.dropout = args.dropout
    self.n_langs = args.n_langs
    self.n_words = args.n_words
    self.share_lang_emb = args.share_lang_emb
    self.share_encdec_emb = args.share_encdec_emb
    self.share_decpro_emb = args.share_decpro_emb
    self.share_output_emb = args.share_output_emb
    self.share_dec = args.share_dec
    self.freeze_dec_emb = args.freeze_dec_emb
    self.encoder_class = encoder.__class__
    self.beam_size = args.beam_size
    self.length_penalty = args.length_penalty

    # indexes
    self.eos_index = args.eos_index
    self.pad_index = args.pad_index
    self.bos_index = args.bos_index

    # words allowed for generation
    self.vocab_mask_neg = args.vocab_mask_neg if len(args.vocab) > 0 else None  # TODO: implement

    # embedding layers
    self.emb_dim = args.decoder_embed_dim
    if self.share_encdec_emb:
        logger.info("Sharing encoder and decoder input embeddings")
        embeddings = encoder.embeddings
    else:
        if self.share_lang_emb:
            logger.info("Sharing decoder input embeddings")
            layer_0 = Embedding(self.n_words[0], self.emb_dim, padding_idx=self.pad_index)
            embeddings = [layer_0 for _ in range(self.n_langs)]
        else:
            embeddings = [Embedding(n_words, self.emb_dim, padding_idx=self.pad_index) for n_words in self.n_words]
        embeddings = nn.ModuleList(embeddings)
    self.embeddings = embeddings
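    # embedding scale factor (sqrt(emb_dim)) and positional embeddings, as in
    # the standard Transformer; 1024 appears to be the maximum number of
    # positions supported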
    self.embed_scale = math.sqrt(self.emb_dim)
    self.embed_positions = PositionalEmbedding(
        1024, self.emb_dim, self.pad_index,
        left_pad=args.left_pad_target,
    )
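
    # one stack of decoder layers per language; the bottom `share_dec` layers
    # are shared across languages by appending the same module object for
    # every language, which ties their parameters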
    self.layers = nn.ModuleList()
    for k in range(args.decoder_layers):
        # share bottom share_dec layers
        layer_is_shared = (k < args.share_dec)
        if layer_is_shared:
            logger.info("Sharing decoder transformer parameters for layer %i" % k)
        self.layers.append(nn.ModuleList([
            # layer for first lang
            TransformerDecoderLayer(args)
        ]))
        for i in range(1, self.n_langs):
            # layer for lang i
            if layer_is_shared:
                # share layer from lang 0
                self.layers[k].append(self.layers[k][0])
            else:
                self.layers[k].append(TransformerDecoderLayer(args))
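
    # output projections map decoder states to per-language vocabulary logits;
    # depending on the sharing flags they may be tied to the input embeddings
    # and/or shared across languages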
    # projection layers
    proj = [nn.Linear(self.emb_dim, n_words) for n_words in self.n_words]
    if self.share_decpro_emb:
        logger.info("Sharing input embeddings and projection matrix in the decoder")
        for i in range(self.n_langs):
            proj[i].weight = self.embeddings[i].weight
        if self.share_lang_emb:
            assert self.share_output_emb
            logger.info("Sharing decoder projection matrices")
            for i in range(1, self.n_langs):
                proj[i].bias = proj[0].bias
    elif self.share_output_emb:
        assert self.share_lang_emb
        logger.info("Sharing decoder projection matrices")
        for i in range(1, self.n_langs):
            proj[i].weight = proj[0].weight
            proj[i].bias = proj[0].bias
    self.proj = nn.ModuleList(proj)
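
# --- Illustrative sketch, not part of transformer.py ------------------------
# Minimal, hypothetical example of the weight tying performed above when
# `share_decpro_emb` is set: assigning the embedding's weight Parameter to the
# output projection makes both modules train the same matrix. Names and
# dimensions below are illustrative only.
import torch
from torch import nn

n_words, emb_dim = 100, 16
emb = nn.Embedding(n_words, emb_dim, padding_idx=0)  # decoder input embedding
proj = nn.Linear(emb_dim, n_words)                   # projection to vocabulary logits

proj.weight = emb.weight           # tie: both modules now hold the same Parameter
assert proj.weight is emb.weight

logits = proj(torch.randn(2, emb_dim))
logits.sum().backward()
assert emb.weight.grad is not None  # gradients reach the shared embedding matrix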