def forward()

in models/compressive.py [0:0]


    def forward(self, query, key, value, ckey, cvalue):
        # query = B x M x H
        # key, value = B x (M+L) x H
        # ckey, cvalue = B x M/c+C x H
        aux_loss = 0
        B, M, _ = query.size()
        c = self.args.compress_rate
        C = self.args.compress_size // self.args.compress_rate
        assert M % c == 0
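        # Illustrative numbers (not from the source): with compress_rate c = 4
        # and compress_size = 64, C = 64 // 4 = 16 compressed slots are carried
        # over, each standing for c = 4 original positions, i.e. c * C = 64
        # timesteps older than the L-step uncompressed span.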

        # compute attention from context
        attn = torch.matmul(
            query, key.transpose(-1, -2)
        )  # B x M (dest) x (M+L) (src)
        attn = unskew(attn)  # B x M x L

        # compressed memory attention
        cattn = torch.matmul(query, ckey.transpose(-1, -2))  # B x M x M/c+C
        # Note that there is 1 extra memory. This ensures that two consecutive
        # compressed memories overlap without any gap.
        cattn = unskew_step(cattn, c)  # B x M x C+1
        attn = torch.cat([cattn, attn], dim=-1)  # B x M x C+L+1

        # compute the effect of position embedding
        attn = attn + torch.matmul(query, self.key_pe)  # B x M x C+L+1

        attn = attn / math.sqrt(self.args.head_dim)  # B x M x C+L+1
        attn = F.softmax(attn, dim=-1)

        attn = self.dropout(attn)  # B x M x C+L+1

        out = 0

        # compressed memory output
        cattn = attn[:, :, :C+1]  # B x M x C+1
        attn = attn[:, :, C+1:]
        cattn = skew_step(cattn, c, 0)  # B x M x M/c+C
        out = out + torch.matmul(cattn, cvalue)  # B x M x H

        attn_cont = skew(attn, 0)  # B x M x (L+M)
        out = out + torch.matmul(attn_cont, value)  # B x M x H

        return out, aux_loss
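
The `skew`, `unskew`, `skew_step` and `unskew_step` helpers are not shown in this excerpt. The sketch below illustrates the standard banding trick that `unskew`/`skew` are assumed to implement here; the names `unskew_sketch`/`skew_sketch` and the exact padding details are assumptions for illustration, not the repository's definitions. The idea: `unskew` keeps, for each of the M queries, only its own window of L preceding key positions out of the full B x M x (M+L) score matrix, and `skew` maps those banded weights back to absolute key positions so they can be multiplied with the B x (M+L) x H value tensor.

    import torch.nn.functional as F


    def unskew_sketch(X):
        # X = B x M x (M+L): scores of every query against every key.
        # Returns B x M x L, where row i keeps only query i's scores against
        # the L key positions immediately preceding it (a diagonal band of X).
        B, M, ML = X.size()
        L = ML - M
        X = X.reshape(B, -1)             # B x M*(M+L)
        X = F.pad(X, (0, M))             # B x M*(M+L)+M
        X = X.view(B, M, M + L + 1)      # B x M x (M+L+1)
        return X[:, :, :L]               # B x M x L


    def skew_sketch(X, pad_value):
        # Reverse of unskew_sketch: B x M x L -> B x M x (M+L). Each query's
        # L window weights are shifted back to their absolute key positions;
        # off-window entries are filled with pad_value.
        B, M, L = X.size()
        X = F.pad(X, (0, M + 1), value=pad_value)  # B x M x (L+M+1)
        X = X.reshape(B, -1)                       # B x M*(L+M+1)
        X = X[:, :-M]                              # B x M*(L+M)
        return X.view(B, M, M + L)                 # B x M x (M+L)

In `forward`, `skew(attn, 0)` fills the off-window positions with zeros, so those keys contribute nothing to the final matmul with `value`. `unskew_step`/`skew_step` are presumably the strided counterparts used for the compressed memory, extracting a window of C+1 compressed slots that advances by one slot every `c` query positions; their exact definitions are not shown here.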