in lib/xf.py [0:0]
def residual(self, X_bte, state):
    """Run the attention sub-block on ``X_bte`` and return ``(output, state)``.

    The input is passed through ``self.ln_x``, projected into queries, keys
    and values, optionally combined with ``state`` via ``self.update_state``,
    run through ``attention`` and finally through ``self.proj_layer``.

    Args:
        X_bte: Input activations. NOTE(review): the ``_bte`` suffix presumably
            stands for (batch, time, embedding) -- confirm against callers.
        state: Carried-over state. When falsy, ``self.update_state`` is not
            called and ``state`` is returned exactly as it was passed in.

    Returns:
        A ``(projected_output, state)`` tuple, where ``state`` is whatever
        ``self.update_state`` returned if it was invoked, otherwise the
        ``state`` argument unchanged.
    """
    normed = self.ln_x(X_bte)
    queries = self.q_layer(normed)
    keys = self.k_layer(normed)
    values = self.v_layer(normed)

    # A truthy state is merged with the freshly computed keys/values, and the
    # updated state is what gets handed back to the caller.
    if state:
        state, keys, values = self.update_state(state, keys, values)

    # preproc_qkv returns a closure that is applied to the attention output
    # below, together with the transformed q/k/v.
    finish, queries, keys, values = self.attn.preproc_qkv(queries, keys, values)

    # Extra logits are only computed when relattn is enabled; they are sized
    # from the second dimension of the (preprocessed) keys.
    if self.relattn:
        extra_logits = self.relattn_logits(normed, keys.shape[1])
    else:
        extra_logits = None

    uses_strided_attn = isinstance(self.attn, StridedAttn)
    attended = attention(
        queries,
        keys,
        values,
        mask=self.attn.mask,
        extra_btT=extra_logits,
        maxlen=self.maxlen,
        dtype=self.dtype,
        check_sentinel=uses_strided_attn,
        use_muP_factor=self.use_muP_factor,
    )

    projected = self.proj_layer(finish(attended))
    return projected, state