in lib/xf.py [0:0]
def preproc_qkv(self, Q_bte, K_bte, V_bte):
    """Pad Q/K/V so their dim-1 length is a multiple of ``self.stride``, then preprocess.

    Padding happens in two stages, always filling with ``SENTINEL``:

    1. Q is padded at the *end* of its second-to-last dimension by the
       amount ``_required_padding`` reports for ``Q_bte.shape[1]``; K and V
       (when given) receive the same trailing padding.
    2. K is then padded at the *front* by the amount ``_required_padding``
       reports for its (possibly already extended) length; V, when given,
       receives the same leading padding.

    NOTE(review): the ``_bte`` suffix and the ``shape[1]`` / ``x[:, :-pad]``
    indexing suggest 3-D (batch, time, embed) tensors -- confirm with callers.

    Args:
        Q_bte: Query tensor; must not be None.
        K_bte: Key tensor, or None.
        V_bte: Value tensor, or None. It may be None independently of
            ``K_bte``.

    Returns:
        A 4-tuple ``(postproc, Q, K, V)``:
        ``postproc`` is ``misc.compose_undo(undo, postproc_q)``, where
        ``undo`` strips the trailing Q padding along dim 1 (or is None when
        no padding was added) and ``postproc_q`` comes from
        ``self._preproc(Q_bte, "Q")``; ``Q`` is the preprocessed query;
        ``K`` / ``V`` are the results of ``self._preproc`` for the key and
        value, or None when the corresponding input was None.
    """
    q_pad = _required_padding(Q_bte.shape[1], self.stride)
    if q_pad:
        # F.pad reads the tuple from the last dimension backwards:
        # (0, 0) leaves the last dim alone, (0, q_pad) appends q_pad
        # entries to the end of the second-to-last dim.
        trailing = (0, 0, 0, q_pad)
        Q_bte = F.pad(Q_bte, trailing, value=SENTINEL)
        if K_bte is not None:
            K_bte = F.pad(K_bte, trailing, value=SENTINEL)
        if V_bte is not None:
            V_bte = F.pad(V_bte, trailing, value=SENTINEL)

        def undo(x, pad=q_pad):
            # Drop the trailing entries that were appended to Q above.
            return x[:, :-pad]

    else:
        undo = None

    if K_bte is not None:
        k_pad = _required_padding(K_bte.shape[1], self.stride)
        if k_pad:
            # (k_pad, 0) prepends entries to the second-to-last dim.
            leading = (0, 0, k_pad, 0)
            K_bte = F.pad(K_bte, leading, value=SENTINEL)
            # V is allowed to be None everywhere else in this method, so
            # guard here too instead of handing None to F.pad.
            if V_bte is not None:
                V_bte = F.pad(V_bte, leading, value=SENTINEL)

    # Internal invariants: after padding, every present tensor must have a
    # dim-1 length divisible by the stride.
    assert Q_bte.shape[1] % self.stride == 0
    assert K_bte is None or K_bte.shape[1] % self.stride == 0
    assert V_bte is None or V_bte.shape[1] % self.stride == 0

    Q, postproc, Q_t, Q_pad = self._preproc(Q_bte, "Q")
    # NOTE(review): compose_undo presumably chains the two undo steps and
    # tolerates undo=None -- confirm in misc.
    postproc = misc.compose_undo(undo, postproc)

    K = None
    if K_bte is not None:
        K = self._preproc(K_bte, "K", Q_t=Q_t, Q_pad=Q_pad)
    V = None
    if V_bte is not None:
        V = self._preproc(V_bte, "V", Q_t=Q_t, Q_pad=Q_pad)
    return (postproc, Q, K, V)