in pytorch-transformers/pytorch_transformers/modeling_transfo_xl.py [0:0]
def _forward(self, dec_inp, mems=None, head_mask=None):
    """Run the layer stack over one input segment.

    Args:
        dec_inp: 2-D tensor of size [qlen, bsz]; it is fed to
            ``self.word_emb``.
        mems: optional per-layer sequence of memory tensors. ``mems[i]`` is
            handed to layer ``i`` and ``mems[0].size(0)`` is used as the
            memory length ``mlen``. ``None`` means no memory (``mlen == 0``).
        head_mask: optional tensor of shape [num_heads] or
            [num_hidden_layers x num_heads]; 1.0 means the head is kept.

    Returns:
        A list ``[last_hidden, new_mems, (all_hidden), (all_attentions)]``.
        ``last_hidden`` is the final layer output transposed on its first two
        axes, ``new_mems`` is whatever ``self._update_mems`` returns,
        ``all_hidden`` is present only when ``self.output_hidden_states`` is
        set and ``all_attentions`` only when ``self.output_attentions`` is set.

    Raises:
        ValueError: if ``self.attn_type`` is not one of 0, 1, 2 or 3.
    """
    qlen, bsz = dec_inp.size()

    # Prepare head mask if needed
    # 1.0 in head_mask indicate we keep the head
    # attention_probs has shape bsz x n_heads x N x N
    # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer)
    # and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head]
    if head_mask is not None:
        if head_mask.dim() == 1:
            head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0)
            head_mask = head_mask.expand(self.n_layer, -1, -1, -1, -1)
        elif head_mask.dim() == 2:
            head_mask = head_mask.unsqueeze(1).unsqueeze(1).unsqueeze(1)
        # switch to float if need + fp16 compatibility
        head_mask = head_mask.to(dtype=next(self.parameters()).dtype)
    else:
        # One placeholder per layer so that head_mask[i] is always valid below.
        head_mask = [None] * self.n_layer

    word_emb = self.word_emb(dec_inp)

    mlen = mems[0].size(0) if mems is not None else 0
    klen = mlen + qlen
    if self.same_length:
        all_ones = word_emb.new_ones(qlen, klen)
        mask_len = klen - self.mem_len
        if mask_len > 0:
            mask_shift_len = qlen - mask_len
        else:
            mask_shift_len = qlen
        # Upper triangle above diagonal ``mlen`` plus the lower triangle at or
        # below diagonal ``-mask_shift_len``; a trailing axis is added.
        dec_attn_mask = (torch.triu(all_ones, 1 + mlen)
                         + torch.tril(all_ones, -mask_shift_len)).bool()[:, :, None]  # -1
    else:
        # Upper triangle above diagonal ``mlen`` only; a trailing axis is added.
        dec_attn_mask = torch.triu(
            word_emb.new_ones(qlen, klen), diagonal=1 + mlen).bool()[:, :, None]

    hids = []
    attentions = []
    if self.attn_type == 0:  # default
        # Positions run from klen - 1 down to 0.
        pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
                               dtype=word_emb.dtype)
        if self.clamp_len > 0:
            pos_seq.clamp_(max=self.clamp_len)
        pos_emb = self.pos_emb(pos_seq)

        core_out = self.drop(word_emb)
        pos_emb = self.drop(pos_emb)

        for i, layer in enumerate(self.layers):
            # hids[i] is the input of layer i.
            hids.append(core_out)
            mems_i = None if mems is None else mems[i]
            layer_outputs = layer(core_out, pos_emb, dec_attn_mask=dec_attn_mask,
                                  mems=mems_i, head_mask=head_mask[i])
            core_out = layer_outputs[0]
            if self.output_attentions:
                attentions.append(layer_outputs[1])
    elif self.attn_type == 1:  # learnable
        core_out = self.drop(word_emb)
        for i, layer in enumerate(self.layers):
            hids.append(core_out)
            if self.clamp_len > 0:
                # Keep only the last clamp_len entries of this layer's tables.
                r_emb = self.r_emb[i][-self.clamp_len:]
                r_bias = self.r_bias[i][-self.clamp_len:]
            else:
                r_emb, r_bias = self.r_emb[i], self.r_bias[i]

            mems_i = None if mems is None else mems[i]
            layer_outputs = layer(core_out, r_emb, self.r_w_bias[i],
                                  r_bias, dec_attn_mask=dec_attn_mask,
                                  mems=mems_i, head_mask=head_mask[i])
            core_out = layer_outputs[0]
            if self.output_attentions:
                attentions.append(layer_outputs[1])
    elif self.attn_type == 2:  # absolute
        pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
                               dtype=word_emb.dtype)
        if self.clamp_len > 0:
            pos_seq.clamp_(max=self.clamp_len)
        pos_emb = self.pos_emb(pos_seq)

        # The last qlen position embeddings are added to the current segment.
        core_out = self.drop(word_emb + pos_emb[-qlen:])

        for i, layer in enumerate(self.layers):
            hids.append(core_out)
            mems_i = None if mems is None else mems[i]
            if mems_i is not None and i == 0:
                # NOTE(review): in-place add; mems_i is mems[0], so the
                # caller's memory tensor is modified and the modified tensor
                # is what _update_mems receives below. Kept as-is because an
                # out-of-place add would change the returned memories.
                mems_i += pos_emb[:mlen]
            layer_outputs = layer(core_out, dec_attn_mask=dec_attn_mask,
                                  mems=mems_i, head_mask=head_mask[i])
            core_out = layer_outputs[0]
            if self.output_attentions:
                attentions.append(layer_outputs[1])
    elif self.attn_type == 3:
        core_out = self.drop(word_emb)

        for i, layer in enumerate(self.layers):
            hids.append(core_out)
            mems_i = None if mems is None else mems[i]
            if mems_i is not None and mlen > 0:
                # Everything but the last qlen rows of this layer's table is
                # used for the memory positions.
                cur_emb = self.r_emb[i][:-qlen]
                cur_size = cur_emb.size(0)
                if cur_size < mlen:
                    # Too few rows: repeat the first row in front to reach mlen.
                    cur_emb_pad = cur_emb[0:1].expand(mlen - cur_size, -1, -1)
                    cur_emb = torch.cat([cur_emb_pad, cur_emb], 0)
                else:
                    cur_emb = cur_emb[-mlen:]
                # NOTE(review): in-place add; this also modifies the caller's
                # mems[i], which is later passed to _update_mems.
                mems_i += cur_emb.view(mlen, 1, -1)
            # NOTE(review): in-place add; core_out is the same object that was
            # just appended to hids, so hids[i] receives the addition as well.
            core_out += self.r_emb[i][-qlen:].view(qlen, 1, -1)

            layer_outputs = layer(core_out, dec_attn_mask=dec_attn_mask,
                                  mems=mems_i, head_mask=head_mask[i])
            core_out = layer_outputs[0]
            if self.output_attentions:
                attentions.append(layer_outputs[1])
    else:
        # Without this branch core_out would never be bound and the code below
        # would fail with an UnboundLocalError that hides the real cause.
        raise ValueError(
            "Unsupported attn_type: {!r} (expected 0, 1, 2 or 3)".format(self.attn_type))

    core_out = self.drop(core_out)

    new_mems = self._update_mems(hids, mems, mlen, qlen)

    # We transpose back here to shape [bsz, len, hidden_dim]
    outputs = [core_out.transpose(0, 1).contiguous(), new_mems]
    if self.output_hidden_states:
        # Add last layer and transpose to library standard shape [bsz, len, hidden_dim]
        hids.append(core_out)
        hids = [t.transpose(0, 1).contiguous() for t in hids]
        outputs.append(hids)
    if self.output_attentions:
        # Transpose to library standard shape [bsz, n_heads, query_seq_len, key_seq_len]
        attentions = [t.permute(2, 3, 0, 1).contiguous() for t in attentions]
        outputs.append(attentions)
    return outputs  # last hidden state, new_mems, (all hidden states), (all attentions)