def preproc_qkv()

in lib/xf.py [0:0]


    def preproc_qkv(self, *xs):
        """Validate query/key/value tensors and split each into attention heads.

        Args:
            *xs: One or more tensors (e.g. Q, K, V). Each must expose ``.shape``
                and all must share the same last dimension (the embedding size).
                NOTE(review): leading dims are presumably (batch, time, ...) —
                confirm against callers.

        Returns:
            A tuple ``(postproc, *split)``. ``postproc`` is ``self.postproc_a``
            with ``h`` (the head count) bound via ``functools.partial``.
            ``split`` holds ``split_heads(x, h)`` for each input, in order.

        Raises:
            ValueError: if no tensors are given, or their last dimensions differ.
        """
        if not xs:
            raise ValueError("preproc_qkv requires at least one tensor")
        q = xs[0].shape[-1]
        # Explicit check rather than ``assert``: assertions are stripped under
        # ``python -O``, which would let mismatched embeddings through silently.
        for x in xs[1:]:
            if x.shape[-1] != q:
                raise ValueError(
                    f"embedding dimensions do not match: expected {q}, "
                    f"got {x.shape[-1]}"
                )
        # Prefer an explicitly configured head count; otherwise derive it from
        # head_dim (misc.exact_div presumably rejects non-divisible sizes — confirm).
        h = self.h or misc.exact_div(q, self.head_dim)
        postproc = functools.partial(self.postproc_a, h=h)
        return (postproc, *(split_heads(x, h) for x in xs))