In persistent_memory.py:
import math

import torch
import torch.nn.functional as F


def forward(self, query, attn):
    # Persistent key/value vectors are learned parameters shared across
    # the batch; add a leading batch dimension for broadcasting.
    key = self.key.unsqueeze(0)
    val = self.val.unsqueeze(0)
    # Unflatten the batch into heads: B x K x M x head_dim.
    query = query.view((-1, self.nb_heads) + query.size()[1:])
    # Pre-scale the keys so that the joint division by sqrt(head_dim)
    # below cancels out for the persistent scores.
    attn_pers = torch.matmul(query, key * math.sqrt(self.head_dim))
    attn_pers = attn_pers.view((-1,) + attn_pers.size()[2:])
    # compute the softmax jointly over context and persistent scores
    attn = torch.cat((attn, attn_pers), dim=-1)
    attn = attn / math.sqrt(self.head_dim)  # B x M x L_total
    attn = F.softmax(attn, dim=-1)
    # split again: the last key.size(-1) columns are the persistent part
    attn_pers = attn[:, :, -key.size(-1):]
    attn = attn[:, :, :-key.size(-1)]  # B x M x L
    # adapt the attention span over the context part only
    if self.adaptive_span is not None:
        attn = self.adaptive_span(attn, normalize=False)
        # normalize the sum jointly, since the span masking broke it!
        attn = torch.cat((attn, attn_pers), dim=-1)
        attn = attn / (attn.sum(-1, keepdim=True) + 1e-8)
        attn_pers = attn[:, :, -key.size(-1):]
        attn = attn[:, :, :-key.size(-1)]  # B x M x L
    attn_pers = self.dropout(attn_pers)  # B x M x N_pers
    attn_pers = attn_pers.view((-1, self.nb_heads) + attn_pers.size()[1:])
    # persistent values are rescaled by sqrt(size) before aggregation
    out = torch.matmul(attn_pers, val * math.sqrt(self.size))
    out = out.view((-1,) + out.size()[2:])
    return attn, out
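
For reference, here is a minimal sketch of how this forward pass might be exercised. It assumes a hypothetical wrapper module that carries the attributes the method reads (self.key, self.val, self.nb_heads, self.head_dim, self.size, self.dropout, self.adaptive_span); the class name, constructor arguments, and shapes are illustrative, not the repository's actual API, and the stub simply reuses the forward defined above.

import torch
import torch.nn as nn


class PersistentMemoryStub(nn.Module):
    # Hypothetical container for the attributes forward() expects.
    def __init__(self, size, nb_heads, nb_pers, dropout=0.1):
        super().__init__()
        self.size = size                   # model dimension (assumed)
        self.nb_heads = nb_heads
        self.head_dim = size // nb_heads
        # Learned persistent key/value vectors, one set per head.
        self.key = nn.Parameter(torch.randn(nb_heads, self.head_dim, nb_pers))
        self.val = nn.Parameter(torch.randn(nb_heads, nb_pers, self.head_dim))
        self.dropout = nn.Dropout(dropout)
        self.adaptive_span = None          # no span masking in this sketch

    forward = forward  # reuse the method defined above


B, K, M, L = 2, 4, 8, 16                   # batch, heads, query len, context len
mem = PersistentMemoryStub(size=64, nb_heads=K, nb_pers=32)
query = torch.randn(B * K, M, mem.head_dim)
attn = torch.randn(B * K, M, L)            # raw context attention scores
attn, out = mem(query, attn)
print(attn.shape, out.shape)               # (B*K, M, L) and (B*K, M, head_dim)

Under these assumed shapes, the returned attn holds the jointly normalized context weights, while out is the contribution aggregated from the persistent value vectors alone.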