def forward()

in models/expire_span.py


    def forward(self, query, key, value, memory_hid, memory_counter):
        # query = B x M x H
        # key, value = B x L x H
        B, M, _ = query.size()
        _, L, _ = key.size()
        aux_loss = 0
        key_pe, val_pe = self.key_pe, self.val_pe
        spans = None

        # Compute attention from context
        attn = torch.matmul(query, key.transpose(-1, -2))  # B x M x L
        # Since some memories are dropped, we can no longer use the relative
        # position alignment, so we work with absolute position alignment instead.

        # Mask out expired memories
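        # expire_span scores each memory from its hidden state and age
        # counter and returns a soft keep-mask in [0, 1] over the L memories,
        # an auxiliary loss that discourages overly long spans, and the
        # predicted spans themselves (see the sketch after this function).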
        if self.args.expire_span:
            mask, expire_loss, spans = self.expire_span(
                attn, memory_hid, memory_counter
            )
            aux_loss = aux_loss + expire_loss
        else:
            mask = 1.0

        # Mask out attention to future steps (and t -> t)
        mask = mask * self.mask_causal[:, -L:]

        # Compute the effect of position embedding
        # Assume no memory is dropped from the previous block.
        attn_pos = torch.matmul(query, key_pe)  # B x M x L
        attn_pos = skew(attn_pos, 0)
        attn[:, :, -2 * M :] += attn_pos
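        # skew shifts each row of attn_pos so the relative position scores
        # line up with absolute key positions; only the last 2*M keys (the
        # current block and the previous one) get a position contribution.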

        # Pre-softmax masking with -inf
        mask_pre = torch.zeros_like(attn).masked_fill_(mask.eq(0), float("-inf"))
        attn = attn + mask_pre

        attn = attn / math.sqrt(self.args.head_dim)  # B x M x L
        attn = F.softmax(attn, dim=-1)
        attn = attn * mask
        attn = attn / (attn.sum(-1, keepdim=True) + 1e-8)
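        # Two-stage masking: fully masked entries were already excluded from
        # the softmax by the -inf above, while multiplying by the soft mask
        # here down-weights partially expired memories; the renormalization
        # keeps each attention row summing to (approximately) 1.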

        attn = self.dropout(attn)  # B x M x L

        out = torch.matmul(attn, value)  # B x M x H

        return out, aux_loss, spans
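
For context, below is a minimal sketch of what the expire_span module called above might compute, following the Expire-Span formulation: a sigmoid span predictor per memory, a linear ramp of width R that turns hard expiration into a soft mask, and a penalty on the predicted spans. The class name ExpireSpanMask, its parameters (max_span, ramp_size, loss_coeff), and the exact use of memory_counter are illustrative assumptions, not the repository's implementation; per-query offsets within the current block are ignored for brevity.

    import torch
    import torch.nn as nn

    class ExpireSpanMask(nn.Module):
        """Illustrative sketch of an Expire-Span style soft masking module."""

        def __init__(self, hid_sz, max_span, ramp_size, loss_coeff):
            super().__init__()
            self.max_span = max_span      # longest span a memory can be kept
            self.ramp_size = ramp_size    # width R of the soft decay ramp
            self.loss_coeff = loss_coeff  # weight of the span-length penalty
            self.span_predictor = nn.Linear(hid_sz, 1)

        def forward(self, attn, memory_hid, memory_counter):
            # attn is accepted only to mirror the call signature used above.
            # memory_hid = B x L x H, memory_counter = B x L (age of each memory)

            # Predict a span e_i in [0, max_span] for every memory.
            spans = torch.sigmoid(self.span_predictor(memory_hid)).squeeze(-1)
            spans = spans * self.max_span  # B x L

            # r_i = e_i - age_i stays positive while the memory is still alive.
            r = spans - memory_counter  # B x L
            # Soft mask: 1 while alive, decaying linearly to 0 over ramp_size steps.
            mask = torch.clamp(r / self.ramp_size + 1.0, min=0.0, max=1.0)
            mask = mask.unsqueeze(1)  # B x 1 x L, broadcast over the M queries

            # Penalize long spans so memories expire as early as possible.
            expire_loss = self.loss_coeff * spans.mean()

            return mask, expire_loss, spans

In the full model each of the M queries sees a different age for memory i, so the real mask is effectively B x M x L rather than the broadcast B x 1 x L used here; the sketch collapses that detail to keep the example short.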