in pytorch_transformers/modeling_transfo_xl.py [0:0]
def __init__(self, config):
    """Build the embedding, decoder stack and positional parameters from ``config``.

    Args:
        config: configuration object. This constructor reads
            ``output_attentions``, ``output_hidden_states``, ``n_token``,
            ``d_embed``, ``d_model``, ``n_head``, ``d_head``, ``d_inner``,
            ``cutoffs``, ``div_val``, ``dropout``, ``dropatt``, ``pre_lnorm``,
            ``n_layer``, ``tgt_len``, ``mem_len``, ``ext_len``, ``attn_type``,
            ``untie_r``, ``same_length`` and ``clamp_len`` from it.

    ``config.attn_type`` selects the decoder layer class and the positional
    parameters that are created:

    * ``0`` (default attention): ``RelPartialLearnableDecoderLayer`` + ``pos_emb``
    * ``1`` (learnable embeddings): ``RelLearnableDecoderLayer`` + ``r_emb``, ``r_bias``
    * ``2`` (absolute standard): ``DecoderLayer`` + ``pos_emb``
    * ``3`` (absolute deeper SA): ``DecoderLayer`` + ``r_emb``

    Any other value leaves ``self.layers`` empty and creates no positional
    parameters; no error is raised here.
    """
    super(TransfoXLModel, self).__init__(config)

    # Output switches.
    self.output_attentions = config.output_attentions
    self.output_hidden_states = config.output_hidden_states

    # Sizes copied from the configuration.
    self.n_token = config.n_token
    self.d_embed = config.d_embed
    self.d_model = config.d_model
    self.n_head = config.n_head
    self.d_head = config.d_head

    self.word_emb = AdaptiveEmbedding(
        config.n_token,
        config.d_embed,
        config.d_model,
        config.cutoffs,
        div_val=config.div_val,
    )
    self.drop = nn.Dropout(config.dropout)

    self.n_layer = config.n_layer
    self.tgt_len = config.tgt_len
    self.mem_len = config.mem_len
    self.ext_len = config.ext_len
    # Longest key length: target segment + extended context + memory.
    self.max_klen = config.tgt_len + config.ext_len + config.mem_len
    self.attn_type = config.attn_type

    # With tied biases (untie_r falsy) one pair of bias parameters is owned
    # by the model and handed to every layer; otherwise layers receive None.
    # The tensors are allocated without initial values here.
    if not config.untie_r:
        bias_shape = (self.n_head, self.d_head)
        self.r_w_bias = nn.Parameter(torch.FloatTensor(*bias_shape))
        self.r_r_bias = nn.Parameter(torch.FloatTensor(*bias_shape))

    def build_layer(layer_cls, **length_kwargs):
        # One decoder layer; only the class and the optional length
        # keywords differ between attention types.
        return layer_cls(
            config.n_head, config.d_model, config.d_head, config.d_inner, config.dropout,
            dropatt=config.dropatt,
            pre_lnorm=config.pre_lnorm,
            r_w_bias=None if config.untie_r else self.r_w_bias,
            r_r_bias=None if config.untie_r else self.r_r_bias,
            output_attentions=self.output_attentions,
            **length_kwargs
        )

    # Only the two relative-attention layer types are given the segment lengths.
    segment_lengths = dict(
        tgt_len=config.tgt_len, ext_len=config.ext_len, mem_len=config.mem_len)

    if config.attn_type == 0:  # the default attention
        layer_cls, length_kwargs = RelPartialLearnableDecoderLayer, segment_lengths
    elif config.attn_type == 1:  # learnable embeddings
        layer_cls, length_kwargs = RelLearnableDecoderLayer, segment_lengths
    elif config.attn_type in [2, 3]:  # absolute embeddings
        layer_cls, length_kwargs = DecoderLayer, {}
    else:
        # Unsupported attn_type: the layer stack stays empty.
        layer_cls, length_kwargs = None, {}

    self.layers = nn.ModuleList()
    if layer_cls is not None:
        for _ in range(config.n_layer):
            self.layers.append(build_layer(layer_cls, **length_kwargs))

    self.same_length = config.same_length
    self.clamp_len = config.clamp_len

    # Positional parameters, again keyed on the attention type.
    if self.attn_type in (0, 2):  # default attention / absolute standard
        self.pos_emb = PositionalEmbedding(self.d_model)
    elif self.attn_type in (1, 3):  # learnable / absolute deeper SA
        self.r_emb = nn.Parameter(torch.FloatTensor(
            self.n_layer, self.max_klen, self.n_head, self.d_head))
        if self.attn_type == 1:
            # Only the learnable variant also carries a per-position bias.
            self.r_bias = nn.Parameter(torch.FloatTensor(
                self.n_layer, self.max_klen, self.n_head))

    # Run init_weights (defined outside this method) over every sub-module.
    self.apply(self.init_weights)