def update_state()

in lib/xf.py [0:0]


    def update_state(self, state, K_bte, V_bte):
        """
        Combine cached keys/values with the new ones and roll the cache forward.

        Keys and values are processed identically and independently; both are
        concatenated and sliced along dim 1.
        NOTE(review): the `_bte` suffix suggests (batch, time, embed) tensors,
        and `self.cache_keep_len` presumably equals `maxlen - 1` -- confirm
        against the enclosing class.

        :param state: pair `(cached_K, cached_V)` produced by the previous call.
        :param K_bte: keys for the new timesteps.
        :param V_bte: values for the new timesteps.
        :returns: `((next_K, next_V), full_K, full_V)` where
            - `full_*` is the last `cache_keep_len` cached steps followed by
              all new steps; it is what the current forward pass should
              attend over, so the first new step sees as much cached context
              as is kept.
            - `next_*` is the last `cache_keep_len` steps of `full_*` (or all
              of it when it is shorter); it becomes the cache for the next
              call.
        """

        def extend(cached, fresh):
            # Discard cached steps older than the last `cache_keep_len`.
            first_used = max(cached.shape[1] - self.cache_keep_len, 0)
            combined = th.cat([cached[:, first_used:], fresh], dim=1)
            # The tail of `combined` is carried over as the next cache.
            # Sanity check of the slicing: when `cache_keep_len` is 0,
            # `combined` is exactly `fresh` and the carried-over cache is
            # empty.
            first_kept = max(combined.shape[1] - (self.cache_keep_len), 0)
            return combined[:, first_kept:], combined

        cached_K, cached_V = state
        next_K, full_K = extend(cached_K, K_bte)
        next_V, full_V = extend(cached_V, V_bte)
        # Invariant: the carried-over cache never exceeds the keep length.
        assert next_K.shape[-2] <= self.cache_keep_len
        return (next_K, next_V), full_K, full_V