in models/expire_span.py
def forward(self, query, key, value, memory_hid, memory_counter):
    # query = B x M x H
    # key, value = B x L x H
    B, M, _ = query.size()
    _, L, _ = key.size()
    aux_loss = 0
    key_pe, val_pe = self.key_pe, self.val_pe
    spans = None
    # Compute attention scores from the context
    attn = torch.matmul(query, key.transpose(-1, -2))  # B x M x L
    # Since some memories are dropped, we cannot switch to relative alignment
    # anymore. So we need to work with absolute position alignment.
    # Mask out expired memories
    if self.args.expire_span:
        mask, expire_loss, spans = self.expire_span(
            attn, memory_hid, memory_counter
        )
        aux_loss = aux_loss + expire_loss
    else:
        mask = 1.0
    # Mask out attention to future steps (and t -> t)
    mask = mask * self.mask_causal[:, -L:]
    # Compute the effect of the position embeddings.
    # Assume no memory is dropped from the previous block.
    attn_pos = torch.matmul(query, key_pe)  # B x M x L
    attn_pos = skew(attn_pos, 0)
    attn[:, :, -2 * M :] += attn_pos
    # Pre-softmax masking with -inf so fully masked positions get zero weight
    mask_pre = torch.zeros_like(attn).masked_fill_(mask.eq(0), float("-inf"))
    attn = attn + mask_pre
    attn = attn / math.sqrt(self.args.head_dim)  # B x M x L
    attn = F.softmax(attn, dim=-1)
    # The expire-span mask can be fractional, so apply it after the softmax
    # and renormalize; the epsilon guards against division by zero.
    attn = attn * mask
    attn = attn / (attn.sum(-1, keepdim=True) + 1e-8)
    attn = self.dropout(attn)  # B x M x L
    out = torch.matmul(attn, value)  # B x M x H
    return out, aux_loss, spans
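
The skew helper called above is not part of this excerpt. Below is a minimal sketch of one common way such a helper is implemented for relative position scores, using a pad-flatten-reshape trick; the function body here is an assumption for illustration, not code taken from this repository.

import torch
import torch.nn.functional as F

def skew(X, pad_value):
    # Shift row i of X right by i positions, filling with pad_value.
    # X = B x M x L  ->  output = B x M x (M + L)
    B, M, L = X.size()
    X = F.pad(X, (0, M + 1), value=pad_value)  # B x M x (L + M + 1)
    X = X.view(B, -1)                          # B x M*(L + M + 1)
    X = X[:, :-M]                              # drop the trailing pad
    X = X.view(B, M, M + L)                    # B x M x (M + L)
    return X

# Shape check: a B x M x L input becomes B x M x (M + L), which is what
# allows the skewed position scores to be added to a trailing slice of attn.
pos_scores = torch.randn(2, 4, 4)  # B=2, M=4, L=4
print(skew(pos_scores, 0).size())  # torch.Size([2, 4, 8])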