trim_memory() — excerpt from modules/adaptive_span.py


    def trim_memory(self, key, value, key_pe, val_pe):
        """Trim cached keys/values and positional embeddings to the adaptive span.

        The first `trim_len` positions (those outside the current span) are
        dropped from the cache and from the positional embeddings. If the
        cache turns out shorter than expected, it is zero-padded on the left
        instead and a warning is printed.

        Returns the (possibly trimmed/padded) `key, value, key_pe, val_pe`.
        """
        trim_len = self.get_trim_len()

        if key is not None:
            # NOTE(review): in feedback mode the whole time dim counts as
            # cache; otherwise the trailing mem_sz positions are excluded —
            # presumably they belong to the current block (verify in caller).
            if self.args.feedback:
                cache_size = key.size(1)
            else:
                cache_size = key.size(1) - self.args.mem_sz

            # How much of the cache extends beyond what the span needs.
            excess = trim_len - (self.size - cache_size)

            if self.args.feedback:
                # keys and values must have cut to the right sizes beforehand.
                # Also adapt_span_cache=False, so cache can't be shorter.
                assert excess == 0

            if excess > 0:
                # Cache is too long: drop the oldest positions (dim 1).
                key = key[:, excess:, :]
                value = value[:, excess:, :]
            elif excess < 0:
                print(
                    "warning: cache is too short. cache_size={} trim_len={}".format(
                        cache_size, trim_len
                    )
                )
                # F.pad pads from the last dim backwards: (0, 0) leaves the
                # feature dim alone, (-excess, 0) left-pads the time dim.
                key = F.pad(key, [0, 0, -excess, 0])
                value = F.pad(value, [0, 0, -excess, 0])

        if trim_len > 0:
            # Positional embeddings: keep only the last in-span positions.
            # Note key_pe is trimmed along its last dim, val_pe along dim 1.
            if key_pe is not None:
                key_pe = key_pe[:, :, trim_len:]
            if val_pe is not None:
                val_pe = val_pe[:, trim_len:, :]

        return key, value, key_pe, val_pe