in models/compressive.py
# module-level imports assumed by this excerpt (the surrounding class is omitted)
import math

import torch
import torch.nn.functional as F


def forward(self, query, key, value, ckey, cvalue):
    # query = B x M x H
    # key, value = B x (M+L) x H
    # ckey, cvalue = B x (M/c+C) x H
    aux_loss = 0
    B, M, _ = query.size()
    c = self.args.compress_rate
    C = self.args.compress_size // self.args.compress_rate
    assert M % c == 0
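    # bookkeeping: each compressed slot summarizes c timesteps, so the current
    # block of M steps yields M/c new slots; C is the number of slots kept in
    # the compressed cache (compress_size timesteps' worth)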
    # compute attention from the uncompressed context
    attn = torch.matmul(
        query, key.transpose(-1, -2)
    )  # B x M (dest) x (M+L) (src)
    attn = unskew(attn)  # B x M x L
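    # unskew keeps, for each query row, only its own L-step window of valid
    # past positions out of the (M+L)-wide score matrix (see the hedged
    # sketch of unskew/skew after this function)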
    # compressed memory attention
    cattn = torch.matmul(query, ckey.transpose(-1, -2))  # B x M x (M/c+C)
    # Note that there is 1 extra memory. This ensures that two memories
    # overlap without any gap.
    cattn = unskew_step(cattn, c)  # B x M x (C+1)
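    # all c queries of a block share one window of C+1 compressed slots,
    # advancing by one slot per block (see the unskew_step sketch below)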
    attn = torch.cat([cattn, attn], dim=-1)  # B x M x (C+L+1)
    # add the effect of the position embeddings
    attn = attn + torch.matmul(query, self.key_pe)  # B x M x (C+L+1)
    attn = attn / math.sqrt(self.args.head_dim)  # B x M x (C+L+1)
    attn = F.softmax(attn, dim=-1)
    attn = self.dropout(attn)  # B x M x (C+L+1)
    out = 0
    # compressed memory output
    cattn = attn[:, :, :C + 1]  # B x M x (C+1)
    attn = attn[:, :, C + 1:]  # B x M x L
    cattn = skew_step(cattn, c, 0)  # B x M x (M/c+C)
    out = out + torch.matmul(cattn, cvalue)  # B x M x H
    # uncompressed context output
    attn_cont = skew(attn, 0)  # B x M x (L+M)
    out = out + torch.matmul(attn_cont, value)  # B x M x H
    return out, aux_loss
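

# --- hedged sketch: unskew / skew --------------------------------------------
# These helpers are not defined in this excerpt. Below is a minimal sketch of
# the standard flatten-and-pad shift trick that matches the shape comments
# above (unskew: B x M x (M+L) -> B x M x L; skew is its inverse). Treat it as
# an assumption about the real helpers, not the repo's actual code.


def skew(X, pad_value):
    # shift row m right by m steps: B x M x L -> B x M x (L+M)
    B, M, L = X.size()
    X = F.pad(X, (0, M + 1), value=pad_value)  # B x M x (L+M+1)
    X = X.reshape(B, -1)[:, :-M]  # dropping the tail shifts each row by one
    return X.reshape(B, M, L + M)  # out[m, m+l] = in[m, l]


def unskew(X):
    # inverse of skew: B x M x (L+M) -> B x M x L with out[m, l] = in[m, m+l]
    B, M, N = X.size()
    L = N - M
    X = F.pad(X.reshape(B, -1), (0, M))  # B x M*(L+M+1)
    return X.reshape(B, M, L + M + 1)[:, :, :L]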
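
# --- hedged sketch: unskew_step / skew_step ----------------------------------
# Also an assumption rather than the repo's code: the "step" variants apply
# the same trick per block of c queries, so that a whole block shares one
# window of C+1 compressed slots (hence the overlapping extra memory noted in
# forward()).


def unskew_step(X, c):
    # B x M x (M/c+C) -> B x M x (C+1); block g keeps slots g..g+C
    B, M, N = X.size()
    G = M // c  # compressed slots produced by the current block
    C = N - G  # compressed slots carried over from the cache
    X = X.reshape(B, G, c, N).transpose(1, 2).reshape(B * c, G * N)
    X = F.pad(X, (0, G))  # length G*(N+1): one-slot shift per block
    X = X.reshape(B * c, G, N + 1)[:, :, :C + 1]  # out[g, j] = in[g, g+j]
    return X.reshape(B, c, G, C + 1).transpose(1, 2).reshape(B, M, C + 1)


def skew_step(X, c, pad_value):
    # inverse: B x M x (C+1) -> B x M x (M/c+C) with out[g, g+j] = in[g, j]
    B, M, W = X.size()  # W = C + 1
    G = M // c
    N = G + W - 1  # M/c + C
    X = X.reshape(B, G, c, W).transpose(1, 2).reshape(B * c, G, W)
    X = F.pad(X, (0, N + 1 - W), value=pad_value)  # B*c x G x (N+1)
    X = X.reshape(B * c, G * (N + 1))[:, :G * N]  # drop the shift padding
    return X.reshape(B, c, G, N).transpose(1, 2).reshape(B, M, N)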
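
# quick shape/round-trip checks for the sketches (hypothetical sizes):
#   x = torch.randn(2, 8, 16)                     # B=2, M=8, L=16
#   assert unskew(skew(x, 0.0)).eq(x).all()
#   y = torch.randn(2, 8, 5)                      # c=2 -> G=4, C=1 -> N=5
#   assert unskew_step(y, 2).size() == (2, 8, 2)  # windows of C+1 = 2 slots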