def cache_postprocess()

in models/feedback.py [0:0]


    def cache_postprocess(self, h_cache, h_all, t):
        for l in range(self.args.nlayers):
            h = h_all[l]
            # Pre-compute the key and value projections of the current hidden
            # state and store them in the cache so future steps can attend to them.
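            # Under pre-normalization, project the normalized state so the cached
            # entries match what the attention module computes.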
            if self.args.pre_norm:
                h = self.get_layer(l).norm1(h)

            key = self.get_layer(l).attn.proj_key(h).view(1, -1, self.args.head_dim)
            val = self.get_layer(l).attn.proj_val(h).view(1, -1, self.args.head_dim)
            h_cache["key"][l].add(key)
            h_cache["val"][l].add(val)
            if not (self.args.feedback and l > 0):
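                # Reshape from (batch, 1, hidden) to (1, batch, hidden) before caching.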
                h = h.squeeze(1).unsqueeze(0)
                h_cache["hid"][l].append(h)
            if self.args.share_proj_kv:
                # No need to process the remaining layers: with shared key/value
                # projections they would produce identical cache entries.
                break
        return h_cache
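
cache_postprocess only relies on a small interface from h_cache: the per-layer "key" and "val" entries expose an add method, while the "hid" entries behave like plain lists. The actual cache class lives elsewhere in the repository and is not shown here; the snippet below is a minimal sketch, assuming a hypothetical fixed-size FIFO buffer named SimpleCache and an init_cache helper (both names invented for illustration), of one way such a structure could be laid out.

    import torch


    class SimpleCache:
        """Hypothetical fixed-size FIFO buffer illustrating the add()
        interface that cache_postprocess relies on; the real cache class
        in this repository may differ."""

        def __init__(self, size, head_dim):
            self.size = size
            # One slot per cached position, zero-initialized.
            self.data = torch.zeros(1, size, head_dim)

        def add(self, x):
            # Append the newest projections and drop the oldest ones so the
            # buffer always holds the most recent `size` positions.
            self.data = torch.cat([self.data, x], dim=1)[:, -self.size:]


    def init_cache(nlayers, cache_size, head_dim):
        # Per-layer key/value buffers plus a plain list of hidden states,
        # matching the add() / append() calls used in cache_postprocess.
        return {
            "key": [SimpleCache(cache_size, head_dim) for _ in range(nlayers)],
            "val": [SimpleCache(cache_size, head_dim) for _ in range(nlayers)],
            "hid": [[] for _ in range(nlayers)],
        }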