in models/feedback.py [0:0]
def cache_postprocess(self, h_cache, h_all, t):
    for l in range(self.args.nlayers):
        h = h_all[l]
        # Compute the key and value projections of the current state in
        # advance and store them in the cache for use in future steps.
        if self.args.pre_norm:
            # With pre-norm, cache the projections of the normalized state,
            # matching what the attention module will see later.
            h = self.get_layer(l).norm1(h)
        key = self.get_layer(l).attn.proj_key(h).view(1, -1, self.args.head_dim)
        val = self.get_layer(l).attn.proj_val(h).view(1, -1, self.args.head_dim)
        h_cache["key"][l].add(key)
        h_cache["val"][l].add(val)
        if not (self.args.feedback and l > 0):
            # Also cache the hidden state itself; when feedback attention is
            # enabled, this is only needed for the first layer.
            h = h.squeeze(1).unsqueeze(0)
            h_cache["hid"][l].append(h)
        if self.args.share_proj_kv:
            # Key/value projections are shared across layers, so the
            # remaining layers would produce identical cache entries.
            break
    return h_cache
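For context, a minimal sketch of the cache layout this method appears to expect: per-layer key/value caches exposing an `add` method under `"key"`/`"val"`, and per-layer lists of hidden states under `"hid"`. The `RingCache` class and the sizes below are illustrative assumptions, not the repository's actual cache implementation.

import torch

class RingCache:
    """Hypothetical fixed-size cache with an .add() method, standing in
    for whatever key/value cache object the repository actually uses."""
    def __init__(self, size, head_dim):
        self.buf = torch.zeros(1, size, head_dim)
        self.pos = 0

    def add(self, x):
        # x: [1, n, head_dim]; overwrite the oldest slots in-place.
        n = x.size(1)
        idx = torch.arange(self.pos, self.pos + n) % self.buf.size(1)
        self.buf[:, idx] = x
        self.pos = (self.pos + n) % self.buf.size(1)

# Assumed sizes for illustration only.
nlayers, head_dim, cache_size = 4, 64, 512
h_cache = {
    "key": [RingCache(cache_size, head_dim) for _ in range(nlayers)],
    "val": [RingCache(cache_size, head_dim) for _ in range(nlayers)],
    "hid": [[] for _ in range(nlayers)],  # per-layer hidden states, appended each step
}

With a structure like this, the loop above can call `.add(...)` on the key/value caches and `.append(...)` on the `"hid"` lists for each layer.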