def cache_initprocess()

in models/feedback.py [0:0]


    def cache_initprocess(self, h_cache):
        # compute the key and value vectors for the cached memory only once
        key_cache, val_cache = [], []
        for l in range(self.args.nlayers):
            if self.args.feedback:
                h = h_cache[0]  # M x B x H
            else:
                h = h_cache[l]  # M x B x H

            windows = [self.args.attn_lim]
            if self.args.adapt_span:
                if self.args.share_proj_kv:
                    # keys/values are shared across layers, but each layer keeps its own span window
                    windows = []
                    for ll in range(self.args.nlayers):
                        trim_len = self.get_layer(ll).attn.attn.adaptive_span.get_trim_len()
                        windows.append(self.args.attn_lim - trim_len)
                else:
                    # trim positions outside this layer's span to avoid unnecessary computation
                    trim_len = self.get_layer(l).attn.attn.adaptive_span.get_trim_len()
                    h = h[trim_len:]
                    windows = [self.args.attn_lim - trim_len]

            if self.args.pre_norm:
                h = self.get_layer(l).norm1(h)

            key = self.get_layer(l).attn.proj_key(h)  # M x B x H
            val = self.get_layer(l).attn.proj_val(h)  # M x B x H
            key = key.view(h.size(0), -1, self.args.head_dim)  # M x B_K x D
            val = val.view(h.size(0), -1, self.args.head_dim)  # M x B_K x D
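            # size the buffers to the cached length plus mem_sz extra positions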
            key = SlidingWindowBuffer(key, key.size(0) + self.args.mem_sz, windows)
            val = SlidingWindowBuffer(val, val.size(0) + self.args.mem_sz, windows)
            key_cache.append(key)
            val_cache.append(val)

            if self.args.share_proj_kv:
                # keys and values are identical across layers, so one copy is enough
                break

        # keep the original cache because it will be needed later
        h_cache = {"key": key_cache, "val": val_cache, "hid_prev": h_cache}
        h_cache["hid"] = [[] for _ in range(self.args.nlayers)]
        return h_cache
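
For reference, the returned cache is a plain dict that wraps the per-layer key/value buffers together with the original hidden-state memories. The sketch below only illustrates that layout; the tensor sizes, layer count, and batch size are assumptions for demonstration, and SlidingWindowBuffer stands in for the repository's own buffer class.

    # hypothetical illustration of the cache layout built by cache_initprocess;
    # all sizes below are made up for demonstration
    import torch

    M, B, H = 64, 8, 256      # memory length, batch size, hidden size (assumed)
    nlayers = 4               # assumed number of layers

    # per-layer hidden-state memories, as passed in via h_cache
    h_cache_in = [torch.zeros(M, B, H) for _ in range(nlayers)]

    # cache_initprocess returns a dict of the form:
    # {
    #     "key":      [<SlidingWindowBuffer>, ...],  # one per layer, or a single
    #     "val":      [<SlidingWindowBuffer>, ...],  #   entry when proj_kv is shared
    #     "hid_prev": h_cache_in,                    # original memories, kept for later use
    #     "hid":      [[] for _ in range(nlayers)],  # new hidden states collected per layer
    # }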