def prepare_adapt_span()

in models/feedback.py [0:0]


    def prepare_adapt_span(self):
        """Cache squeezed positional embeddings and, if enabled, the span mask.

        ``self.key_pe`` / ``self.val_pe`` are squeezed along dim 0 and stored
        as ``self.key_pe_trimmed`` / ``self.val_pe_trimmed``.

        When ``self.args.adapt_span`` is set, two extra steps happen:

        * The embeddings are first passed through
          ``self.adaptive_span.trim_memory`` before squeezing.
        * The adaptive-span mask is prepared for
          ``self.args.attn_lim - self.adaptive_span.get_trim_len()``
          positions.

        This work is done once per block for efficiency. An embedding that is
        (or is trimmed to) ``None`` leaves its ``*_trimmed`` attribute
        untouched.
        """
        use_adaptive = self.args.adapt_span
        key_pe, val_pe = self.key_pe, self.val_pe

        if use_adaptive:
            # Only the positional embeddings are trimmed here. The first two
            # inputs are passed as None and their outputs are discarded.
            _, _, key_pe, val_pe = self.adaptive_span.trim_memory(
                None, None, key_pe, val_pe
            )

        # Shared by both modes: store the (possibly trimmed) embeddings
        # squeezed along dim 0.
        if key_pe is not None:
            self.key_pe_trimmed = key_pe.squeeze(0)
        if val_pe is not None:
            self.val_pe_trimmed = val_pe.squeeze(0)

        if not use_adaptive:
            return

        # Prepare the adaptive-span mask once per block rather than per call.
        trim_len = self.adaptive_span.get_trim_len()
        remaining = self.args.attn_lim - trim_len
        self.adaptive_span.mask.prepare_mask(remaining)