def forward()

in persistent_memory.py [0:0]


    def forward(self, query, attn):
        # persistent key/value parameters are shared across the batch:
        # key is K x head_dim x N and val is K x N x head_dim,
        # where K = nb_heads and N = number of persistent vectors
        key = self.key.unsqueeze(0)  # 1 x K x head_dim x N
        val = self.val.unsqueeze(0)  # 1 x K x N x head_dim

        # score the queries against the persistent keys, per head;
        # the sqrt(head_dim) factor here cancels the joint scaling applied below
        query = query.view((-1, self.nb_heads) + query.size()[1:])  # B x K x M x head_dim
        attn_pers = torch.matmul(query, key * math.sqrt(self.head_dim))
        attn_pers = attn_pers.view((-1,) + attn_pers.size()[2:])  # (B*K) x M x N

        # concatenate context and persistent scores and compute the softmax jointly
        attn = torch.cat((attn, attn_pers), dim=-1)
        attn = attn / math.sqrt(self.head_dim)  # (B*K) x M x (L + N)
        attn = F.softmax(attn, dim=-1)
        # split back into persistent and context attention weights
        attn_pers = attn[:, :, -key.size(-1):]  # (B*K) x M x N
        attn = attn[:, :, :-key.size(-1)]  # (B*K) x M x L

        # adapt the attention span over the context part only
        if self.adaptive_span is not None:
            attn = self.adaptive_span(attn, normalize=False)
            # renormalize the sum jointly over context and persistent weights
            attn = torch.cat((attn, attn_pers), dim=-1)
            attn = attn / (attn.sum(-1, keepdim=True) + 1e-8)
            attn_pers = attn[:, :, -key.size(-1):]  # (B*K) x M x N
            attn = attn[:, :, :-key.size(-1)]  # (B*K) x M x L

        attn_pers = self.dropout(attn_pers)  # (B*K) x M x N

        # weighted sum of the persistent values
        attn_pers = attn_pers.view((-1, self.nb_heads) + attn_pers.size()[1:])  # B x K x M x N
        out = torch.matmul(attn_pers, val * math.sqrt(self.size))
        out = out.view((-1,) + out.size()[2:])  # (B*K) x M x head_dim
        return attn, out
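
For context, here is a minimal sketch of how a module exposing this forward() might be constructed and called. The class name PersistentMemory, the constructor arguments, and the parameter initialization are assumptions made for illustration, not taken from persistent_memory.py; only the attribute names and their shapes (key, val, nb_heads, head_dim, size, dropout, adaptive_span) are inferred from the matmuls above. Note that the incoming attn is expected to hold raw context attention scores, since the softmax is applied jointly inside forward().

    import torch
    import torch.nn as nn


    class PersistentMemory(nn.Module):
        """Hypothetical skeleton matching the attributes used by forward() above."""

        def __init__(self, size, nb_heads, nb_persistent, dropout=0.0, adaptive_span=None):
            super().__init__()
            self.size = size                    # model hidden size
            self.nb_heads = nb_heads            # K
            self.head_dim = size // nb_heads
            # persistent key/value parameters; shapes follow the matmuls in forward():
            #   key: K x head_dim x N,  val: K x N x head_dim
            # placeholder initialization; the real module's init may differ
            self.key = nn.Parameter(torch.randn(nb_heads, self.head_dim, nb_persistent) * 0.02)
            self.val = nn.Parameter(torch.randn(nb_heads, nb_persistent, self.head_dim) * 0.02)
            self.dropout = nn.Dropout(dropout)
            self.adaptive_span = adaptive_span  # e.g. an AdaptiveSpan module, or None

        # forward() would be the method shown above.


    if __name__ == "__main__":
        B, K, M, L = 2, 4, 16, 32
        mem = PersistentMemory(size=64, nb_heads=K, nb_persistent=8)
        query = torch.randn(B * K, M, mem.head_dim)   # per-head queries, (B*K) x M x head_dim
        attn_cont = torch.randn(B * K, M, L)          # raw context attention scores, (B*K) x M x L
        # attn, out = mem(query, attn_cont)
        # attn: (B*K) x M x L, out: (B*K) x M x head_dim

With the forward() above attached to this class, the commented call would return the context attention weights and the persistent-memory output with the shapes noted in the comments.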