in models/expire_span.py [0:0]
def forward(self, x, h_prev, target=None):
# x : B x M
M = x.size(1)
B = x.size(0)
H = self.args.hid_sz
h = self.in_emb(x) # B x M x H
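# The incoming h_prev packs, per layer, the cached hidden states followed by the
# cached distance counters (mirroring the h_cache.extend(c_cache) at the end).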
c_prev = h_prev[-self.args.nlayers :]
h_prev = h_prev[: -self.args.nlayers]
if self.args.expire_span_noisy and self.training:
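# A single noise scale is sampled per block and shared by every expire-span layer.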
block_span_noise = random.random()
for l in range(self.args.nlayers):
if self.get_layer(l).args.expire_span:
self.get_layer(
l
).attn.attn.expire_span.block_span_noise = block_span_noise
h_cache = [] # memory including the current block
c_cache = [] # per-memory distance (in time steps) from the current block's first query
aux_loss = 0
counter = torch.linspace(0, -M + 1, steps=M).to(self.args.device)
counter = counter.view(1, -1).expand(B, -1) # B x M
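# counter[:, i] = -i: the i-th position of the current block lies i steps after
# the block's first query, hence a negative distance.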
for l in range(self.args.nlayers):
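# Append the current block to this layer's memory; the caches now have length
# L = previous memory size + M.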
h_cache.append(torch.cat([h_prev[l], h], dim=1))
c_cache.append(torch.cat([c_prev[l], counter], dim=1))
if self.training and self.args.expire_span_layerdrop > random.random():
# Skip this layer (layerdrop), but still compute the spans so that memories expire consistently.
_, _, spans = self.get_layer(l)(h, h_cache[l], c_cache[l]) # B x M x H
else:
h, loss, spans = self.get_layer(l)(
h, h_cache[l], c_cache[l]
) # B x M x H
aux_loss = aux_loss + loss
if self.get_layer(l).args.expire_span:
# Determine which memories can be dropped.
# Extend spans by the ramp length R because memories are still
# used during those R steps.
spans = spans + self.args.expire_span_ramp # B x L
# Since spans are measured from the 1st query of this block,
# subtract M so that they're measured from the next block.
spans = spans - M
# Now we can remove memories with span <= 0.
spans = (spans > 0).float()
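# spans is now a binary keep indicator: 1 = keep the memory, 0 = expirable.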
# Do not drop any memory from the current block, so that the relative-position
# embedding for the last M steps can be computed easily.
spans[:, -M:].fill_(1)
# Because of batching, the same number of memories must be dropped from every
# sample, so find the smallest number of droppable memories within the batch.
num_drop = (spans <= 0).long().sum(-1)
num_drop_min = num_drop.min().item()
# Dropping arbitrary numbers of memories might cause memory fragmentation,
# so only drop in increments of mem_sz. Using mem_sz also ensures that the
# memory size stays within the limit.
num_drop_min = int(
math.floor(num_drop_min / self.args.mem_sz) * self.args.mem_sz
)
# Now drop exactly num_drop_min memories from each sample, choosing the
# memories with the smallest spans.
# Minor speed-up: only sort when there is actually something to drop.
if num_drop_min != 0:
spans_sorted, indices = spans.sort(dim=-1)
# Mark everything beyond the num_drop_min smallest spans as kept (set to 1).
spans_sorted[:, num_drop_min:] = 1
span_mask = torch.zeros_like(spans)
span_mask.scatter_(-1, indices, spans_sorted)
span_mask = span_mask.bool()
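# Every row has at least num_drop_min zero spans, so each sample keeps exactly
# L - num_drop_min memories and the masked selection reshapes cleanly.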
c_cache[l] = c_cache[l][span_mask].view(B, -1) # B x L'
h_cache[l] = h_cache[l][span_mask].view(B, -1, H) # B x L' x H
# Advance the counters by M so distances are measured from the next block's first query.
c_cache[l] += M
else:
attention_lim = self.get_layer(l).args.attn_lim
# No expiration on this layer: keep only the attn_lim most recent memories.
h_cache[l] = h_cache[l][:, -attention_lim:] # B x L' x H
c_cache[l] = c_cache[l][:, -attention_lim:] # B x L'
c_cache[l] += M
if self.args.pre_norm:
h = self.out_norm(h)
out = self.out(h, target)
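# Pack the hidden-state caches and the distance counters into one flat list;
# the caller feeds it back as h_prev for the next block, where it is split again.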
h_cache.extend(c_cache)
return out, h_cache, aux_loss
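
To make the expiration bookkeeping above easier to trace, here is a small illustrative sketch of the span-to-mask arithmetic (ramp extension, shift by M, thresholding, and rounding the drop count down to a multiple of mem_sz). It is not part of the repository; the values chosen for M, expire_span_ramp, and mem_sz are assumptions picked only to keep the numbers readable.

import torch

# Illustrative only: B = 2 samples, cache length L = 4, block size M = 2,
# expire_span_ramp = 1, mem_sz = 2.
spans = torch.tensor([[0.5, 0.2, 3.0, 5.0],
                      [0.1, 0.3, 4.0, 6.0]])  # B x L predicted spans
spans = spans + 1                              # extend by the ramp length
spans = spans - 2                              # re-measure from the next block
keep = (spans > 0).float()                     # 1 = keep, 0 = expirable
keep[:, -2:] = 1                               # never drop the current block
num_drop = (keep <= 0).long().sum(-1)          # droppable memories per sample: [2, 2]
num_drop_min = int(num_drop.min().item())      # batch-wide minimum: 2
num_drop_min = (num_drop_min // 2) * 2         # round down to a multiple of mem_sz: 2
# Each sample then drops its two smallest-span memories, leaving a length-2 cache.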