def forward()

in models/expire_span.py [0:0]


    def forward(self, x, h_prev, target=None):
        # x : B x M
        M = x.size(1)
        B = x.size(0)
        H = self.args.hid_sz
        h = self.in_emb(x)  # B x M x H

        c_prev = h_prev[-self.args.nlayers :]
        h_prev = h_prev[: -self.args.nlayers]
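        # (the incoming h_prev packs, per layer, the hidden-state caches
        # followed by the counter caches; they are concatenated back together
        # at the end of forward)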

        if self.args.expire_span_noisy and self.training:
            block_span_noise = random.random()
            for l in range(self.args.nlayers):
                layer = self.get_layer(l)
                if layer.args.expire_span:
                    layer.attn.attn.expire_span.block_span_noise = block_span_noise

        h_cache = []  # memory including the current block
        c_cache = []  # the distance (in time steps) from the first query
        aux_loss = 0
        counter = torch.linspace(0, -M + 1, steps=M).to(self.args.device)
        counter = counter.view(1, -1).expand(B, -1)  # B x M
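        # counter holds, for each new token, its offset behind the block's
        # first query: 0 for the first token down to -(M - 1) for the last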
        for l in range(self.args.nlayers):
            h_cache.append(torch.cat([h_prev[l], h], dim=1))
            c_cache.append(torch.cat([c_prev[l], counter], dim=1))
            if self.training and self.args.expire_span_layerdrop > random.random():
                # skip this layer (layerdrop), but spans are still needed to
                # update the memory cache below
                _, _, spans = self.get_layer(l)(h, h_cache[l], c_cache[l])  # spans: B x L
            else:
                h, loss, spans = self.get_layer(l)(
                    h, h_cache[l], c_cache[l]
                )  # B x M x H
                aux_loss = aux_loss + loss
            if self.get_layer(l).args.expire_span:
                # Determine which memories can be dropped.
                # Extend spans by the ramp length R because memories are still
                # used during those R steps.
                spans = spans + self.args.expire_span_ramp  # B x L
                # Since spans are measured from the 1st query of this block,
                # subtract M so that they're measured from the next block.
                spans = spans - M
                # Now we can remove memories with span <= 0.
                spans = (spans > 0).float()
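                # Illustration with hypothetical numbers: with M = 64 and
                # expire_span_ramp = 16, a memory with predicted span 70 has
                # 70 + 16 - 64 = 22 > 0 steps left, so it is kept; a span of
                # 40 gives 40 + 16 - 64 < 0, so it is dropped.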

                # Do not drop any memory from the current block, so we can
                # easily compute the relative-position embeddings for the
                # last M steps.
                spans[:, -M:].fill_(1)

                # But because of batching, we need to drop the same number of
                # memories from every sample. Find the smallest number of
                # droppable memories across the batch.
                num_drop = (spans <= 0).long().sum(-1)
                num_drop_min = num_drop.min().item()
                # Dropping arbitrary numbers of memories might cause memory
                # fragmentation, so only drop in increments of mem_sz. Using
                # mem_sz ensures that the memory size stays within the limit.
                num_drop_min = int(
                    math.floor(num_drop_min / self.args.mem_sz) * self.args.mem_sz
                )
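                # e.g. with mem_sz = 64, a raw minimum of 130 droppable
                # memories is rounded down to 128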
                # Now drop exactly num_drop_min memories from each sample,
                # choosing the ones with the smallest spans.
                # Minor speed-up: only sort when there is something to drop.
                if num_drop_min != 0:
                    spans_sorted, indices = spans.sort(dim=-1)
                    # the first num_drop_min sorted entries are all 0 (expired);
                    # mark everything after them as kept (1)
                    spans_sorted[:, num_drop_min:] = 1
                    span_mask = torch.zeros_like(spans)
                    span_mask.scatter_(-1, indices, spans_sorted)
                    span_mask = span_mask.bool()
                    c_cache[l] = c_cache[l][span_mask].view(B, -1)  # B x L'
                    h_cache[l] = h_cache[l][span_mask].view(B, -1, H)  # B x L' x H
                # advance the counters by M so distances are measured from the
                # first query of the next block
                c_cache[l] += M
            else:
                attention_lim = self.get_layer(l).args.attn_lim
                # no expiration on this layer: keep only the most recent
                # attn_lim tokens in the cache
                h_cache[l] = h_cache[l][:, -attention_lim:]  # B x L' x H
                c_cache[l] = c_cache[l][:, -attention_lim:]  # B x L'
                c_cache[l] += M

        if self.args.pre_norm:
            h = self.out_norm(h)
        out = self.out(h, target)

        # re-join the hidden and counter caches (matching the split at the top)
        h_cache.extend(c_cache)

        return out, h_cache, aux_loss
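
A minimal usage sketch (not part of this file): the returned cache is meant to be
fed back in as h_prev on the next block. The model and block_iterator names, the
batch shapes, and the empty-cache initialization are assumptions for illustration;
the actual training loop may set up the cache differently.

    import torch

    B, M = 8, 64                       # assumed batch size and block length
    H = model.args.hid_sz
    nlayers = model.args.nlayers
    device = model.args.device

    # start with empty hidden and counter caches for every layer
    h_prev = [torch.zeros(B, 0, H, device=device) for _ in range(nlayers)]
    h_prev += [torch.zeros(B, 0, device=device) for _ in range(nlayers)]

    total_aux = 0.0
    for x, y in block_iterator:        # x, y: B x M token ids (assumed)
        out, h_prev, aux_loss = model(x, h_prev, target=y)
        total_aux = total_aux + aux_loss  # out comes from self.out(h, target)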