in models.py:
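# module-level imports assumed by the snippet below
# (used via math.sqrt, torch.matmul and F.softmax)
import math

import torch
import torch.nn.functional as F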
def forward(self, query, key, value, key_pe):
    # query size = B x M x H
    # key, value sizes = B x (M+L) x H
    if self.adapt_span_enabled:
        # [optional] trim out memory to reduce unnecessary computation
        key, value, key_pe = self.adaptive_span.trim_memory(
            query, key, value, key_pe)
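        # (trimming keeps only the most recent key/value/key_pe columns that
        # the learned spans can still reach, so the matmuls below shrink with
        # the current span rather than the full memory length L)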
    # compute attention from context
    # B x M (dest) x (M+L) (src)
    attn_cont = torch.matmul(query, key.transpose(-1, -2))
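    # _unskew maps the dense B x M x (M+L) scores down to B x M x L so that
    # row m keeps only query m's scores over the L positions immediately
    # preceding it (relative indexing; see the sketch after this snippet)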
    attn_cont = _unskew(attn_cont)  # B x M x L

    # compute the effect of position embedding
    attn_pos = torch.matmul(query, key_pe)  # B x M x L_pos
    attn = attn_cont + attn_pos

    if self.persistent_memory is not None:
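        # [optional] persistent memory: the module presumably scores the query
        # against its learned persistent vectors and normalizes the combined
        # scores itself (note that scaling and softmax sit in the else branch);
        # its contribution comes back as pers_mem_out and is added to out below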
        attn, pers_mem_out = self.persistent_memory(query, attn)
    else:
        attn = attn / math.sqrt(self.hidden_size)  # B x M x L_pos
        attn = F.softmax(attn, dim=-1)
    if self.adapt_span_enabled:
        # trim attention lengths according to the learned span
        attn = self.adaptive_span(attn)
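        # (adaptive_span applies a differentiable soft mask along the last
        # dimension, zeroing weights that fall beyond each head's learned span)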
    attn = self.dropout(attn)  # B x M x L_pos
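    # _skew is the inverse mapping: spread the B x M x L relative weights back
    # over absolute key positions (filling with 0) so they line up with value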
    attn_cont = _skew(attn, 0)  # B x M x (L+M)
    out = torch.matmul(attn_cont, value)  # B x M x H

    if self.persistent_memory is not None:
        out = out + pers_mem_out
    return out
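For context, here is a minimal sketch of what the _skew / _unskew helpers referenced above might look like, assuming the usual pad-and-reshape trick for relative position indexing. The function names and shapes come from the snippet; the bodies are an illustration of that trick, not necessarily the repository's exact implementation.

import torch.nn.functional as F

def _skew(X, pad_value):
    # shift row m right by m steps: out[b, m, m + l] = X[b, m, l], pad elsewhere
    B, M, L = X.size()
    X = F.pad(X, (0, M + 1), value=pad_value)  # B x M x (L+M+1)
    X = X.view(B, -1)[:, :-M]                  # drop the trailing pad
    return X.view(B, M, M + L)                 # B x M x (M+L)

def _unskew(X):
    # inverse of _skew: out[b, m, l] = X[b, m, m + l]
    B, M, ML = X.size()
    L = ML - M
    X = X.view(B, -1)
    X = F.pad(X, (0, M))                       # restore the dropped pad
    X = X.view(B, M, M + L + 1)
    return X[:, :, :L]                         # B x M x L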