in models/feedback.py
def cache_initprocess(self, h_cache):
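    """Precompute key/value projections for the cached memory.

    The cached hidden states are projected into keys and values once
    here and wrapped in sliding-window buffers, instead of being
    recomputed at every attention call.
    """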
    # compute key and value vectors only once
    key_cache, val_cache = [], []
    for l in range(self.args.nlayers):
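        # in feedback mode, all layers share a single memory, so the
        # hidden states come from the same cache entry for every layer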
        if self.args.feedback:
            h = h_cache[0]  # M x B x H
        else:
            h = h_cache[l]  # M x B x H
        windows = [self.args.attn_lim]
        if self.args.adapt_span:
            if self.args.share_proj_kv:
                # keys and values are shared across layers, but each
                # layer still attends with its own adaptive span
                windows = []
                for ll in range(self.args.nlayers):
                    trim_len = self.get_layer(ll).attn.attn.adaptive_span.get_trim_len()
                    windows.append(self.args.attn_lim - trim_len)
            else:
                # trim the memory to this layer's span to avoid
                # unnecessary computation
                trim_len = self.get_layer(l).attn.attn.adaptive_span.get_trim_len()
                h = h[trim_len:]
                windows = [self.args.attn_lim - trim_len]
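                # e.g. with attn_lim=2048 and trim_len=1536 (hypothetical
                # values), only the last 512 memory steps get projected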
        if self.args.pre_norm:
            h = self.get_layer(l).norm1(h)
        key = self.get_layer(l).attn.proj_key(h)  # M x B x H
        val = self.get_layer(l).attn.proj_val(h)  # M x B x H
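        # fold the attention heads into the batch dimension:
        # B_K = B * (H // head_dim)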
        key = key.view(h.size(0), -1, self.args.head_dim)  # M x B_K x D
        val = val.view(h.size(0), -1, self.args.head_dim)  # M x B_K x D
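        # the buffers hold the M cached steps plus room for the next
        # mem_sz steps; layers sharing a buffer read their own window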
        key = SlidingWindowBuffer(key, key.size(0) + self.args.mem_sz, windows)
        val = SlidingWindowBuffer(val, val.size(0) + self.args.mem_sz, windows)
        key_cache.append(key)
        val_cache.append(val)
        if self.args.share_proj_kv:
            # keys and values are identical across layers, so compute
            # them only for the first layer
            break
    # keep the original cache because it will be needed in the future
    h_cache = {"key": key_cache, "val": val_cache, "hid_prev": h_cache}
    h_cache["hid"] = [[] for _ in range(self.args.nlayers)]
    return h_cache
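

# SlidingWindowBuffer is defined elsewhere in the repo. For reference, here is
# a minimal sketch matching the constructor calls above -- (data, total size,
# per-consumer windows). It is an assumption for illustration only: the class
# name and the append()/get() accessors below are hypothetical and may differ
# from the actual implementation.
import torch


class SlidingWindowBufferSketch:
    def __init__(self, data, size, windows):
        # data: M x B_K x D tensor of projected memory steps
        # size: capacity, i.e. memory length plus the upcoming block
        # windows: attention span of each layer that reads this buffer
        m, bk, d = data.size()
        self.buffer = data.new_zeros(size, bk, d)
        self.buffer[:m] = data
        self.filled = m
        self.windows = windows

    def append(self, new_steps):
        # write the newest steps after the current fill point
        n = new_steps.size(0)
        self.buffer[self.filled:self.filled + n] = new_steps
        self.filled += n

    def get(self, layer_idx):
        # the last `window` filled steps, as seen by the given layer
        window = self.windows[layer_idx]
        return self.buffer[max(0, self.filled - window):self.filled]


# usage sketch (shapes follow the M x B_K x D comments above)
mem = torch.randn(8, 4, 16)
buf = SlidingWindowBufferSketch(mem, size=8 + 2, windows=[6])
buf.append(torch.randn(2, 4, 16))
ctx = buf.get(0)  # last 6 of the 10 filled steps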