in xlm/model/transformer.py
def __init__(self, params, dico, is_encoder, with_output):
"""
Transformer model (encoder or decoder).
"""
super().__init__()
# encoder / decoder, output layer
self.is_encoder = is_encoder
self.is_decoder = not is_encoder
self.with_output = with_output
# dictionary / languages
self.n_langs = params.n_langs
self.n_words = params.n_words
self.eos_index = params.eos_index
self.pad_index = params.pad_index
self.dico = dico
self.id2lang = params.id2lang
self.lang2id = params.lang2id
self.use_lang_emb = getattr(params, 'use_lang_emb', True)
assert len(self.dico) == self.n_words
assert len(self.id2lang) == len(self.lang2id) == self.n_langs
# model parameters
self.dim = params.emb_dim # 512 by default
self.hidden_dim = self.dim * 4 # 2048 by default
self.n_heads = params.n_heads # 8 by default
self.n_layers = params.n_layers
self.dropout = params.dropout
self.attention_dropout = params.attention_dropout
assert self.dim % self.n_heads == 0, 'transformer dim must be a multiple of n_heads'
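    # with the defaults above, each attention head operates on 512 // 8 = 64 channels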

    # embeddings
    self.position_embeddings = Embedding(N_MAX_POSITIONS, self.dim)
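    # optionally overwrite the learned position table in place with sinusoidal values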
    if params.sinusoidal_embeddings:
        create_sinusoidal_embeddings(N_MAX_POSITIONS, self.dim, out=self.position_embeddings.weight)
    if params.n_langs > 1 and self.use_lang_emb:
        self.lang_embeddings = Embedding(self.n_langs, self.dim)
    self.embeddings = Embedding(self.n_words, self.dim, padding_idx=self.pad_index)
    self.layer_norm_emb = nn.LayerNorm(self.dim, eps=1e-12)

    # transformer layers
    self.attentions = nn.ModuleList()
    self.layer_norm1 = nn.ModuleList()
    self.ffns = nn.ModuleList()
    self.layer_norm2 = nn.ModuleList()
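    # decoder layers additionally get a source-attention sublayer (encoder_attn) with its own layer norm (layer_norm15)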
    if self.is_decoder:
        self.layer_norm15 = nn.ModuleList()
        self.encoder_attn = nn.ModuleList()

    # memories
    self.memories = nn.ModuleDict()
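    # memories are keyed '<layer_id>_<in|after>', i.e. by the layer they attach to and where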
    if getattr(params, 'use_memory', False):
        mem_positions = params.mem_enc_positions if is_encoder else params.mem_dec_positions
        for layer_id, pos in mem_positions:
            assert 0 <= layer_id <= params.n_layers - 1
            assert pos in ['in', 'after']
            self.memories['%i_%s' % (layer_id, pos)] = HashingMemory.build(self.dim, self.dim, params)

    for layer_id in range(self.n_layers):
        self.attentions.append(MultiHeadAttention(self.n_heads, self.dim, dropout=self.attention_dropout))
        self.layer_norm1.append(nn.LayerNorm(self.dim, eps=1e-12))
        if self.is_decoder:
            self.layer_norm15.append(nn.LayerNorm(self.dim, eps=1e-12))
            self.encoder_attn.append(MultiHeadAttention(self.n_heads, self.dim, dropout=self.attention_dropout))
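        # a memory registered 'in' this layer takes the place of its feed-forward sublayer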
        if ('%i_in' % layer_id) in self.memories:
            self.ffns.append(None)
        else:
            self.ffns.append(TransformerFFN(self.dim, self.hidden_dim, self.dim, dropout=self.dropout, gelu_activation=params.gelu_activation))
        self.layer_norm2.append(nn.LayerNorm(self.dim, eps=1e-12))

    # output layer
    if self.with_output:
        self.pred_layer = PredLayer(params)
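        # optionally tie the output projection to the input embedding matrix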
        if params.share_inout_emb:
            self.pred_layer.proj.weight = self.embeddings.weight
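
Usage sketch (not taken from the file above): the constructor only reads plain attributes off params and, in __init__ itself, only length-checks and stores dico. The class name TransformerModel, the argparse Namespace standing in for params, and the list standing in for dico are assumptions for illustration.

from argparse import Namespace
from xlm.model.transformer import TransformerModel  # assumed class name

# every attribute below is one the constructor actually reads
params = Namespace(
    n_langs=2, n_words=1000, eos_index=1, pad_index=2,
    id2lang={0: 'en', 1: 'fr'}, lang2id={'en': 0, 'fr': 1},
    use_lang_emb=True, emb_dim=512, n_heads=8, n_layers=6,
    dropout=0.1, attention_dropout=0.1,
    sinusoidal_embeddings=False, gelu_activation=True,
    use_memory=False, share_inout_emb=True,
)
dico = list(range(params.n_words))  # stand-in: __init__ only checks len(dico) == n_words
encoder = TransformerModel(params, dico, is_encoder=True, with_output=False)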