in models/transformer_seq.py
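# Module-level imports this excerpt relies on (a sketch; the surrounding
# file's actual imports are not shown here):
import math

import torch
import torch.nn.functional as F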
def forward(self, query, key, value):
    # query = B x M x H
    # key, value = B x (M+L) x H
    # B = batch size, M = query block length, L = attention span
    # (cached memory length), H = hidden size per head (head_dim)
    aux_loss = 0  # auxiliary loss term; stays zero in this code path
    key_pe, val_pe = self.key_pe, self.val_pe  # position embeddings (val_pe is not used further below)
    if self.args.adapt_span:
        # trim memory (and its position embeddings) beyond the current
        # maximum span to avoid unnecessary computation
        key, value, key_pe, val_pe = self.adaptive_span.trim_memory(
            key, value, key_pe, val_pe
        )
    # compute attention scores from the context
    attn = torch.matmul(
        query, key.transpose(-1, -2)
    )  # B x M (dest) x (M+L) (src)
    # keep, for each query, only the L positions it can attend to
    attn = unskew(attn)  # B x M x L
    # add the effect of the position embeddings
    attn = attn + torch.matmul(query, key_pe)  # B x M x L
    attn = attn / math.sqrt(self.args.head_dim)  # B x M x L
    attn = F.softmax(attn, dim=-1)
    if self.args.adapt_span:
        # soft-mask weights that fall outside the learned span, then
        # renormalize so each row sums to one again
        attn = self.adaptive_span(attn)
        attn = attn / (attn.sum(-1, keepdim=True) + 1e-8)
    attn = self.dropout(attn)  # B x M x L
    # scatter the windowed weights back onto the (M+L) source axis and
    # aggregate the corresponding values
    attn_cont = skew(attn, 0)  # B x M x (L+M)
    out = torch.matmul(attn_cont, value)  # B x M x H
    return out, aux_loss
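# A minimal sketch of the skew()/unskew() helpers this forward pass relies on,
# assuming the usual pad-and-reshape trick for fixed-span relative attention;
# the real helpers in this file may differ in naming and edge handling.

def unskew(X):
    # X = B x M x (M+L); returns Y = B x M x L with Y[b, i, j] = X[b, i, i + j],
    # i.e. each query keeps only the scores for the L positions preceding it
    B, M, ML = X.size()
    L = ML - M
    X = X.view(B, -1)            # B x M*(M+L)
    X = F.pad(X, (0, M))         # B x M*(M+L+1)
    X = X.view(B, M, M + L + 1)  # row i is now shifted left by i positions
    return X[:, :, :L]           # B x M x L


def skew(X, pad_value):
    # X = B x M x L; returns Y = B x M x (M+L) with Y[b, i, i + j] = X[b, i, j]
    # and pad_value elsewhere, so Y can be multiplied with the B x (M+L) x H values
    B, M, L = X.size()
    X = F.pad(X, (0, M + 1), value=pad_value)  # B x M x (L+M+1)
    X = X.view(B, -1)                          # B x M*(L+M+1)
    X = X[:, :-M]                              # drop the trailing padding
    return X.view(B, M, M + L)                 # B x M x (M+L)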