# in lib/xf.py [0:0]
def update_state(self, state, K_bte, V_bte):
    """
    Combine the cached keys/values in `state` with this step's new ones.

    `state` is a pair (cached_K, cached_V). For keys and for values alike,
    only the most recent `self.cache_keep_len` cached timesteps are kept.
    They are concatenated in front of the new timesteps along dim 1.
    Dim 1 is presumably the time axis of a (batch, time, embed) tensor, as
    the `_bte` suffix suggests -- TODO confirm.

    Returns ((next_K, next_V), K_full, V_full), where:
    - next_K / next_V are the last `self.cache_keep_len` timesteps of the
      concatenation (all of it if shorter). They form the cache for the
      next call.
    - K_full / V_full are the full concatenations for the current forward
      pass. The first new timestep therefore has up to
      `self.cache_keep_len` earlier timesteps of context in front of it.

    NOTE(review): the "maxlen == 1 gives an empty cache" reasoning suggests
    cache_keep_len == maxlen - 1. That value is set outside this method --
    confirm against the constructor.
    """
    keep = self.cache_keep_len

    def _extend(cached, fresh):
        # Cached timesteps older than the last `keep` can never be needed
        # again, so drop them before concatenating.
        n_cached = cached.shape[1]
        recent = cached[:, max(n_cached - keep, 0):]
        context = th.cat([recent, fresh], dim=1)
        # The next cache is the last `keep` timesteps of the context.
        # Clamping the start at 0 (instead of slicing with -keep) matters
        # when keep == 0: `[-0:]` would return everything, whereas this
        # returns an empty cache.
        tail_start = max(context.shape[1] - keep, 0)
        return context[:, tail_start:], context

    cached_K, cached_V = state
    next_K, K_bte = _extend(cached_K, K_bte)
    next_V, V_bte = _extend(cached_V, V_bte)
    # Invariant: the outgoing cache never exceeds the configured size.
    assert next_K.shape[-2] <= keep
    return (next_K, next_V), K_bte, V_bte