in models/feedback.py [0:0]
def cache_preprocess(self, h_cache, t):
    """Build per-layer key and value lists from the raw cache containers.

    Every cached element has its first two axes swapped via
    ``transpose(0, 1)`` before being returned.

    Args:
        h_cache: mapping with ``"key"`` and ``"val"`` entries. Each entry is
            a sequence of cache objects exposing ``get()``, which returns an
            indexable collection (presumably of tensors -- confirm against
            the cache class).
        t: not used by this method; kept so the signature stays unchanged.

    Returns:
        A ``(key_cache, val_cache)`` tuple of lists.
    """
    def swap_axes(item):
        return item.transpose(0, 1)

    args = self.args

    if not args.share_proj_kv:
        # Separate projections: take the first element held by each cache
        # object listed under "key" / "val".
        key_cache = [swap_axes(holder.get()[0]) for holder in h_cache["key"]]
        val_cache = [swap_axes(holder.get()[0]) for holder in h_cache["val"]]
        return key_cache, val_cache

    # Shared projections: only the first cache object of each kind is read.
    shared_keys = h_cache["key"][0].get()
    shared_vals = h_cache["val"][0].get()

    if args.adapt_span:
        # Keys and values differ in their spans, so every stored element
        # becomes its own output entry.
        key_cache = [swap_axes(item) for item in shared_keys]
        val_cache = [swap_axes(item) for item in shared_vals]
    else:
        # A single set of keys and values: the same transposed object is
        # repeated ``nlayers`` times (the entries alias one another).
        key_cache = [swap_axes(shared_keys[0])] * args.nlayers
        val_cache = [swap_axes(shared_vals[0])] * args.nlayers

    return key_cache, val_cache