in models/feedback.py [0:0]
def prepare_adapt_span(self):
    """Cache squeezed position embeddings and, if enabled, the adaptive-span mask.

    If ``self.args.adapt_span`` is truthy, ``self.key_pe`` and ``self.val_pe``
    are first passed through ``self.adaptive_span.trim_memory``, with ``None``
    for the first two arguments. Afterwards, the mask is prepared for a span of
    ``self.args.attn_lim - self.adaptive_span.get_trim_len()``. Otherwise the
    embeddings are used as-is and no mask is prepared.

    Each resulting embedding that is not ``None`` has dim 0 squeezed. It is then
    stored on ``self.key_pe_trimmed`` / ``self.val_pe_trimmed``. A ``None``
    embedding leaves the corresponding attribute unchanged.
    """
    use_span = self.args.adapt_span
    key_pe, val_pe = self.key_pe, self.val_pe

    if use_span:
        # Compute the adaptive-span trimming once per block for efficiency.
        _, _, key_pe, val_pe = self.adaptive_span.trim_memory(
            None, None, key_pe, val_pe
        )

    # NOTE(review): squeeze(0) presumably drops a leading singleton dim of the
    # position embeddings — confirm against where key_pe / val_pe are created.
    if key_pe is not None:
        self.key_pe_trimmed = key_pe.squeeze(0)
    if val_pe is not None:
        self.val_pe_trimmed = val_pe.squeeze(0)

    if use_span:
        trimmed_by = self.adaptive_span.get_trim_len()
        self.adaptive_span.mask.prepare_mask(self.args.attn_lim - trimmed_by)