def preproc_qkv()

in lib/xf.py [0:0]


    def preproc_qkv(self, Q_bte, K_bte, V_bte):
        """Pad Q/K/V on the time axis to a multiple of ``self.stride``, then preprocess them.

        Two padding phases are applied, all filled with ``SENTINEL``:

        1. If the query length is not a multiple of ``self.stride``, Q is padded at
           the END. K and V (when given) are end-padded by the same amount,
           presumably so their tails stay aligned with Q's.
        2. If K's length is still not a multiple of ``self.stride``, K and V (when
           given) are additionally padded at the FRONT.

        Args:
            Q_bte: query tensor. The ``_bte`` suffix and the use of ``shape[1]``
                suggest a 3-D (batch, time, embed) layout -- TODO confirm. Note that
                ``F.pad(..., (0, 0, a, b))`` pads dim -2, which coincides with the
                ``shape[1]`` axis checked below only for 3-D input.
            K_bte: key tensor, or ``None``.
            V_bte: value tensor, or ``None``.

        Returns:
            A 4-tuple ``(postproc, Q, K, V)``:
            - ``postproc``: the phase-1 undo (slices the padded tail off axis 1, or
              ``None`` when Q needed no padding) combined with the undo returned by
              ``self._preproc`` for Q, via ``misc.compose_undo``.
            - ``Q``: the preprocessed query returned by ``self._preproc(Q_bte, "Q")``.
            - ``K`` / ``V``: ``self._preproc(...)`` results using Q's ``Q_t`` and
              ``Q_pad``, or ``None`` when the corresponding input is ``None``.
        """
        # Phase 1: end-pad Q (and K/V alongside it) to a multiple of the stride.
        q_pad = _required_padding(Q_bte.shape[1], self.stride)
        if q_pad:
            Q_bte = F.pad(Q_bte, (0, 0, 0, q_pad), value=SENTINEL)
            K_bte = F.pad(K_bte, (0, 0, 0, q_pad), value=SENTINEL) if K_bte is not None else None
            V_bte = F.pad(V_bte, (0, 0, 0, q_pad), value=SENTINEL) if V_bte is not None else None
            # Bind the pad amount as a default so the lambda captures its value now.
            undo = lambda x, pad=q_pad: x[:, :-pad]
        else:
            undo = None

        # Phase 2: front-pad K (and V) if K's length still isn't a stride multiple,
        # e.g. when K is longer than Q.
        if K_bte is not None:
            kv_pad = _required_padding(K_bte.shape[1], self.stride)
            if kv_pad:
                K_bte = F.pad(K_bte, (0, 0, kv_pad, 0), value=SENTINEL)
                # Guard V like phase 1 does; the return below already allows V_bte=None.
                if V_bte is not None:
                    V_bte = F.pad(V_bte, (0, 0, kv_pad, 0), value=SENTINEL)

        # Internal invariants guaranteed by the padding above.
        assert Q_bte.shape[1] % self.stride == 0
        assert K_bte is None or K_bte.shape[1] % self.stride == 0
        assert V_bte is None or V_bte.shape[1] % self.stride == 0

        Q, postproc, Q_t, Q_pad = self._preproc(Q_bte, "Q")
        postproc = misc.compose_undo(undo, postproc)
        return (
            postproc,
            Q,
            self._preproc(K_bte, "K", Q_t=Q_t, Q_pad=Q_pad) if K_bte is not None else None,
            self._preproc(V_bte, "V", Q_t=Q_t, Q_pad=Q_pad) if V_bte is not None else None,
        )