def forward()

in models/expire_span.py


    def forward(self, attn, memory_hid, current_counter):
        # Since we're dropping memories, here L can be smaller than attn_lim.
        # attn : B x M x L, where B = B' * nheads (batch and heads flattened)
        # memory_hid : B' x L x H'
        # current_counter : B' x L
        B, M, L = attn.size()

        # Compute the maximum span (number of steps) a memory can stay
        max_span = self.span_predictor(
            memory_hid / self.args.expire_span_pre_div
        ).squeeze(-1)  # B' x L
        max_span = torch.sigmoid(max_span) * self.size
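        # The sigmoid maps predictions into (0, 1), so each predicted span
        # lies strictly between 0 and self.size.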

        if self.training:
            # Log span statistics, measured only for the current block
            # (the last M memories).
            self.avg_span_log.append(max_span[:, -M:].mean().item())
            self.max_span_log = max(self.max_span_log, max_span[:, -M:].max().item())

        # Compute remaining spans measured from the 1st query.
        remaining_offset = max_span - current_counter  # B' x L

        # Optionally add noise during training by clamping spans to a
        # per-block limit.
        if self.args.expire_span_noisy and self.training:
            noisy_span_lim = self.block_span_noise * self.size
            max_span_noisy = max_span.clamp(max=noisy_span_lim)
            remaining_offset_noisy = max_span_noisy - current_counter  # B' x L
        else:
            remaining_offset_noisy = remaining_offset
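        # Note: block_span_noise is set outside this function (per block);
        # clamping spans to it appears to act as a span regularizer.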

        # Remaining spans measured from all queries.
        remaining_span = remaining_offset_noisy.unsqueeze(1)  # B' x 1 x L
        remaining_span = remaining_span.expand(-1, M, -1).contiguous()  # B' x M x L
        remaining_span = remaining_span - torch.linspace(0, M - 1, M).view(1, -1, 1).to(
            device=remaining_span.device
        )
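        # Query i in the block (i = 0..M-1) occurs i steps after the first
        # query, so it sees a span shorter by i. E.g. with remaining_offset
        # = 5 and M = 3, the three queries see remaining spans 5, 4, 3.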

        # Compute the mask:
        #   mask=1 if remaining_span >= 0
        #   mask=0 if remaining_span <= -ramp_size
        #   In between, linearly interpolate between those two.
        mask = remaining_span / self.args.expire_span_ramp + 1.0
        mask = mask.clamp(0, 1)  # B' x M x L
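        # E.g. with expire_span_ramp = 16: remaining_span = 0 gives mask 1.0,
        # remaining_span = -8 gives mask 0.5, remaining_span = -16 gives 0.0.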

        # Loss to encourage spans to be small.
        # Compute the loss for memories only under the ramp
        ramp_mask = (mask > 0) & (mask < 1)  # B' x M x L
        span_loss = remaining_span * ramp_mask.float()  # B' x M x L
        loss = span_loss.sum(dim=-1).sum(dim=-1)  # B'
        # Scale to match previous versions:
        # - Divide by R (= expire_span_ramp) because each memory incurs up to R losses
        # - Divide by M because we're averaging over a block
        loss = loss / self.args.expire_span_ramp / M
        loss = loss * self.args.expire_span_loss  # B'
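        # Only memories on the ramp (0 < mask < 1) contribute to this loss,
        # so gradients reach the span predictor only through those memories.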

        # Replicate for each head.
        mask = mask.unsqueeze(1)  # B' x 1 x M x L
        mask = mask.expand(-1, self.args.nheads, -1, -1)  # B' x K x M x L
        mask = mask.flatten(0, 1)  # B x M x L

        return mask, loss, remaining_offset
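
A minimal, self-contained sketch of the masking math above, with the span
predictor replaced by hand-picked spans and hypothetical values for the ramp
and block size (none of these numbers come from the repo):

    import torch

    ramp = 16.0  # stands in for args.expire_span_ramp
    M = 4        # queries in the current block
    max_span = torch.tensor([[5.0, 20.0, 40.0]])          # B' x L
    current_counter = torch.tensor([[30.0, 25.0, 10.0]])  # age of each memory

    # Remaining span from the first query, then shifted per query.
    remaining_offset = max_span - current_counter  # B' x L
    remaining_span = remaining_offset.unsqueeze(1).expand(-1, M, -1).contiguous()
    remaining_span = remaining_span - torch.arange(M, dtype=torch.float).view(1, -1, 1)

    # Soft mask: 1 while alive, linear ramp down to 0 over `ramp` steps.
    mask = (remaining_span / ramp + 1.0).clamp(0, 1)  # B' x M x L
    print(mask)  # memory 0 expired (0.0), memory 1 on the ramp, memory 2 kept (1.0)

The three memories land in the three regimes of the piecewise-linear mask:
fully expired, partially expired (on the ramp), and fully kept.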