in model/mem_transformer.py
def forward(self, w, r, r_w_bias, r_r_bias, attn_mask=None, mems=None):
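    """Relative multi-head attention with memory (Transformer-XL style).

    Argument shapes as used below (descriptive comment added for clarity):
        w: current-segment input, [qlen x bsz x d_model].
        r: relative positional embeddings projected by self.r_net
            (rlen along the first dimension).
        r_w_bias, r_r_bias: learned global content / position biases
            (u and v in the Transformer-XL paper), [n_head x d_head].
        attn_mask: optional boolean mask, [qlen x klen] or [bsz x qlen x klen];
            True marks positions to be masked out.
        mems: cached hidden states from previous segments, prepended to w
            along the length dimension.
    """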
    qlen, rlen, bsz = w.size(0), r.size(0), w.size(1)

    if mems is not None:
        cat = torch.cat([mems, w], 0)
        if self.pre_lnorm:
            w_heads = self.qkv_net(self.layer_norm(cat))
        else:
            w_heads = self.qkv_net(cat)
        r_head_k = self.r_net(r)

        w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
        w_head_q = w_head_q[-qlen:]
    else:
        if self.pre_lnorm:
            w_heads = self.qkv_net(self.layer_norm(w))
        else:
            w_heads = self.qkv_net(w)
        r_head_k = self.r_net(r)

        w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
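
    # Key length covers both the cached memory and the current segment,
    # i.e. klen = mlen + qlen when mems is given, otherwise klen = qlen.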
    klen = w_head_k.size(0)

    w_head_q = w_head_q.view(
        qlen, bsz, self.n_head, self.d_head
    )  # qlen x bsz x n_head x d_head
    w_head_k = w_head_k.view(
        klen, bsz, self.n_head, self.d_head
    )  # klen x bsz x n_head x d_head
    w_head_v = w_head_v.view(
        klen, bsz, self.n_head, self.d_head
    )  # klen x bsz x n_head x d_head

    r_head_k = r_head_k.view(
        rlen, self.n_head, self.d_head
    )  # rlen x n_head x d_head

    #### compute attention score
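    # Following the Transformer-XL paper, the score splits into a content-based
    # term AC (terms (a) + (c), using the r_w_bias / u bias) and a position-based
    # term BD (terms (b) + (d), using the r_r_bias / v bias and the relative
    # positional projection r_head_k).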
    rw_head_q = w_head_q + r_w_bias  # qlen x bsz x n_head x d_head
    AC = torch.einsum(
        "ibnd,jbnd->bnij", (rw_head_q, w_head_k)
    )  # bsz x n_head x qlen x klen

    rr_head_q = w_head_q + r_r_bias
    BD = torch.einsum(
        "ibnd,jnd->bnij", (rr_head_q, r_head_k)
    )  # bsz x n_head x qlen x rlen
    BD = self._rel_shift(BD)
    # [bsz x n_head x qlen x klen]

    attn_score = AC + BD
    attn_score.mul_(self.scale)
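    # Note: self.scale is presumably 1 / sqrt(d_head), i.e. standard scaled
    # dot-product attention (assumption; the attribute is defined elsewhere
    # in the module).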
    #### compute attention probability
    if attn_mask is not None:
        if attn_mask.dim() == 2:
            attn_score.masked_fill_(attn_mask[None, None, :, :], -float("inf"))
        elif attn_mask.dim() == 3:
            attn_score.masked_fill_(attn_mask[:, None, :, :], -float("inf"))
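
    # For the 2-D case, a typical causal mask with memory would be built as
    # torch.triu(torch.ones(qlen, klen, dtype=torch.bool), diagonal=1 + mlen),
    # so query i may attend to keys 0 .. i + mlen (an illustrative assumption,
    # not taken from this file).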
    # [bsz x n_head x qlen x klen]
    attn_prob = F.softmax(attn_score, dim=3)
    attn_prob = self.dropatt(attn_prob)

    #### compute attention vector
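    # attn_prob is [bsz x n_head x qlen x klen]; contract over the key axis (j)
    # against the value heads to get one vector per query position.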
    attn_vec = torch.einsum("bnij,jbnd->ibnd", (attn_prob, w_head_v))
    # [qlen x bsz x n_head x d_head]
    attn_vec = attn_vec.contiguous().view(
        attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head
    )
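    # heads merged back to model width: [qlen x bsz x n_head * d_head]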

    ##### linear projection
    attn_out = self.o_net(attn_vec)
    attn_out = self.drop(attn_out)
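
    # With pre_lnorm, LayerNorm was already applied to the block input above,
    # so only the residual is added here; otherwise LayerNorm is applied after
    # the residual, as in the post-LN Transformer.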
    if self.pre_lnorm:
        ##### residual connection
        output = w + attn_out
    else:
        ##### residual connection + layer normalization
        output = self.layer_norm(w + attn_out)

    return output
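

# ---------------------------------------------------------------------------
# Illustrative sketch (an assumption, not code from this file): one common way
# to implement the relative shift that self._rel_shift performs above, written
# for the [bsz x n_head x qlen x rlen] score layout produced by the einsum
# calls in forward(). The standalone name `_rel_shift_sketch` is hypothetical;
# `torch` is assumed to be imported at the top of the module.
# ---------------------------------------------------------------------------
def _rel_shift_sketch(x):
    """Shift x so that x[b, n, i, j] lines up with absolute key position j."""
    bsz, n_head, qlen, rlen = x.size()
    zero_pad = x.new_zeros(bsz, n_head, qlen, 1)
    x_padded = torch.cat([zero_pad, x], dim=3)             # [bsz, n_head, qlen, rlen + 1]
    x_padded = x_padded.view(bsz, n_head, rlen + 1, qlen)  # reinterpret the padded block
    return x_padded[:, :, 1:, :].reshape_as(x)             # drop the pad row, restore shape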