in models/expire_span.py [0:0]
def forward(self, attn, memory_hid, current_counter):
    """Compute the expire-span soft attention mask and the span loss.

    Each memory predicts how many steps it may stay alive (``max_span``).
    Seen from a query, a memory is fully visible while its remaining span
    is >= 0. It fades out linearly over ``self.args.expire_span_ramp``
    steps and is fully masked once the remaining span is <= -ramp.

    Args:
        attn: attention tensor, B x M x L. Only its size is used. L can be
            smaller than the attention limit because memories are dropped.
        memory_hid: memory hidden states, B' x L x H'.
        current_counter: B' x L. It is subtracted from the predicted span to
            give the remaining span measured from the first query
            (presumably the age of each memory -- confirm against caller).

    Returns:
        mask: soft mask in [0, 1]. The B' x M x L mask is replicated for
            each of ``self.args.nheads`` heads and flattened, giving
            (B' * nheads) x M x L.
        loss: span loss, shape B', scaled by ``self.args.expire_span_loss``.
        remaining_offset: B' x L remaining span from the first query. It is
            computed from the unclamped (noise-free) span.
    """
    _, M, _ = attn.size()

    # Maximum number of steps each memory may stay, squashed into
    # (0, self.size) by a sigmoid.
    max_span = self.span_predictor(
        memory_hid / self.args.expire_span_pre_div
    ).squeeze(-1)  # B' x L
    max_span = torch.sigmoid(max_span) * self.size
    if self.training:
        # Log statistics only for the current block (the last M memories).
        current_block = max_span[:, -M:]
        self.avg_span_log.append(current_block.mean().item())
        self.max_span_log = max(self.max_span_log, current_block.max().item())

    # Remaining spans measured from the first query.
    remaining_offset = max_span - current_counter  # B' x L
    if self.args.expire_span_noisy and self.training:
        # Training-time noise: cap spans at block_span_noise * size when
        # building the mask. The returned remaining_offset stays uncapped.
        noisy_span_lim = self.block_span_noise * self.size
        max_span_noisy = max_span.clamp(max=noisy_span_lim)
        remaining_offset_noisy = max_span_noisy - current_counter  # B' x L
    else:
        remaining_offset_noisy = remaining_offset

    # Remaining spans measured from every query: query m sees each memory
    # m steps later than the first query does. Broadcasting builds the
    # B' x M x L tensor directly (no expand + contiguous copy). The offsets
    # are created on the right device and in the right dtype, so there is
    # no per-call host-to-device copy and no promotion of half inputs.
    query_offsets = torch.arange(
        M, device=remaining_offset_noisy.device, dtype=remaining_offset_noisy.dtype
    ).view(1, M, 1)
    remaining_span = remaining_offset_noisy.unsqueeze(1) - query_offsets  # B' x M x L

    # Soft mask (R = expire_span_ramp):
    #   mask = 1 if remaining_span >= 0
    #   mask = 0 if remaining_span <= -R
    #   linear interpolation in between.
    mask = remaining_span / self.args.expire_span_ramp + 1.0
    mask = mask.clamp(0, 1)  # B' x M x L

    # Loss to encourage small spans, applied only to memories on the ramp.
    # On the ramp remaining_span is negative. Minimising the sum pushes
    # those remaining spans (and hence the predicted spans) down.
    on_ramp = (mask > 0) & (mask < 1)  # B' x M x L
    span_loss = remaining_span * on_ramp.to(remaining_span.dtype)  # B' x M x L
    loss = span_loss.sum(dim=-1).sum(dim=-1)  # B'
    # Scale to match previous versions:
    # - divide by R because each memory has R losses applied
    # - divide by M because we are averaging over a block
    loss = loss / self.args.expire_span_ramp / M
    loss = loss * self.args.expire_span_loss  # B'

    # Replicate for each head: B' x M x L -> B' x K x M x L -> B x M x L.
    mask = mask.unsqueeze(1).expand(-1, self.args.nheads, -1, -1).flatten(0, 1)
    return mask, loss, remaining_offset