def residual()

in lib/xf.py [0:0]


    def residual(self, X_bte, state):
        """Compute this layer's attention branch for ``X_bte``.

        The input is passed through ``self.ln_x`` and then through the
        query, key and value layers. When ``state`` is truthy,
        ``self.update_state`` supplies the replacement state together with
        the keys and values that are attended over. ``self.attn.preproc_qkv``
        transforms Q/K/V and hands back a closure, which is applied to the
        attention output before the final ``self.proj_layer``.

        NOTE(review): the ``_bte`` suffix presumably stands for
        (batch, time, embedding) -- confirm against the rest of the file.

        Args:
            X_bte: Input to the branch.
            state: Value forwarded to ``self.update_state``. The update is
                skipped whenever this is falsy (e.g. ``None``).

        Returns:
            A ``(projected_output, state)`` tuple. ``state`` is whatever
            ``self.update_state`` returned if it ran, otherwise the argument
            as it was passed in.
        """
        normed_bte = self.ln_x(X_bte)
        queries_bte = self.q_layer(normed_bte)
        keys_bte = self.k_layer(normed_bte)
        values_bte = self.v_layer(normed_bte)

        if state:
            # update_state returns the new state plus the keys/values to use
            # in place of the freshly computed ones.
            state, keys_bte, values_bte = self.update_state(state, keys_bte, values_bte)

        # preproc_qkv yields a closure (applied to the attention output below)
        # followed by the transformed queries, keys and values.
        undo_preproc, queries_bte, keys_bte, values_bte = self.attn.preproc_qkv(
            queries_bte, keys_bte, values_bte
        )

        if self.relattn:
            # Extra logits are derived from the normalized input and the size
            # of the keys' second axis (after any state update/preprocessing).
            extra_logits = self.relattn_logits(normed_bte, keys_bte.shape[1])
        else:
            extra_logits = None

        attended_bte = attention(
            queries_bte,
            keys_bte,
            values_bte,
            mask=self.attn.mask,
            extra_btT=extra_logits,
            maxlen=self.maxlen,
            dtype=self.dtype,
            # The sentinel check is requested only for StridedAttn instances.
            check_sentinel=isinstance(self.attn, StridedAttn),
            use_muP_factor=self.use_muP_factor,
        )

        return self.proj_layer(undo_preproc(attended_bte)), state