in modules/adaptive_span.py [0:0]
def trim_memory(self, key, value, key_pe, val_pe):
    """Trim cached keys/values and positional embeddings to the current span.

    Drops the oldest cache positions that fall outside the adaptive span
    (as reported by ``self.get_trim_len()``) and shortens the key/value
    positional embeddings by the same amount.

    Args:
        key: cached keys or None; trimmed/padded along dim 1.
            NOTE(review): presumably (batch, mem_len, dim) — confirm with caller.
        value: cached values, same layout as ``key``.
        key_pe: key positional embedding or None; trimmed along its last dim.
        val_pe: value positional embedding or None; trimmed along dim 1.

    Returns:
        Tuple ``(key, value, key_pe, val_pe)`` with lengths reduced by the
        trim amount (inputs may be returned unchanged, e.g. when None).
    """
    trim_len = self.get_trim_len()
    if key is not None:
        if self.args.feedback:
            cache_size = key.size(1)
        else:
            # Without feedback, key also carries mem_sz positions of the
            # current block that do not count toward the cache length.
            cache_size = key.size(1) - self.args.mem_sz
        # How much of the cache must go, given it may already be short.
        trim_len_cache = trim_len - (self.size - cache_size)
        if self.args.feedback:
            # Keys and values must have been cut to the right sizes
            # beforehand. Also adapt_span_cache=False, so the cache
            # can't be shorter.
            assert trim_len_cache == 0
        if trim_len_cache > 0:
            # Drop the oldest positions that fell out of the span.
            key = key[:, trim_len_cache:, :]
            value = value[:, trim_len_cache:, :]
        elif trim_len_cache < 0:
            # Cache is shorter than the span requires: left-pad dim 1
            # with zeros so downstream sizes still line up.
            print(
                "warning: cache is too short. cache_size={} trim_len={}".format(
                    cache_size, trim_len
                )
            )
            key = F.pad(key, [0, 0, -trim_len_cache, 0])
            value = F.pad(value, [0, 0, -trim_len_cache, 0])
    if trim_len > 0:
        if key_pe is not None:
            key_pe = key_pe[:, :, trim_len:]
        if val_pe is not None:
            val_pe = val_pe[:, trim_len:, :]
    return key, value, key_pe, val_pe