in lib/xf.py [0:0]
def preproc_qkv(self, *xs):
    """Split every input into heads and build the matching post-processor.

    The size of the last axis of the first input is taken as the reference
    width; every input must share it. The head count is ``self.h`` when that
    is truthy, otherwise ``misc.exact_div(width, self.head_dim)``.

    Args:
        *xs: Objects exposing ``.shape``; at least one is required, since
            the first one is indexed unconditionally.

    Returns:
        A tuple whose first item is ``self.postproc_a`` with ``h`` bound to
        the resolved head count, followed by ``split_heads(x, h)`` for each
        input, in the order given.

    Raises:
        AssertionError: If the inputs disagree on the size of their last
            axis.
    """
    width = xs[0].shape[-1]
    assert all(
        tensor.shape[-1] == width for tensor in xs
    ), "embedding dimensions do not match"

    # NOTE(review): exact_div presumably requires width to be an exact
    # multiple of head_dim -- confirm against misc.
    n_heads = self.h or misc.exact_div(width, self.head_dim)

    # The partial is built before any split_heads call, so the order in which
    # attributes are looked up and helpers are invoked stays the same.
    return (
        functools.partial(self.postproc_a, h=n_heads),
        *[split_heads(tensor, n_heads) for tensor in xs],
    )