def pack()

in optimum/quanto/tensor/weights/marlin/int4/packed.py [0:0]


def pack(unpacked: torch.Tensor):
    """Pack a 2-D tensor into int32 words, eight 4-bit slots per word.

    The input of shape ``(N, K)`` is transposed, regrouped into 16x16 tiles,
    shuffled with the module-level ``_perm`` index tensor, and finally every
    eight consecutive-stride values are folded into a single 32-bit word.

    Args:
        unpacked: 2-D tensor of shape ``(N, K)``. Both dimensions are
            integer-divided by 16 for the tile reshape, and the total number
            of elements is regrouped into rows of ``_perm.numel()`` entries.
            NOTE(review): values are presumably already in ``[0, 15]``; they
            are not masked here, so wider values would spill into the
            neighbouring 4-bit slots -- confirm against the caller.

    Returns:
        An ``int32`` tensor of shape ``(K // 16, N * 2)`` placed on the same
        device as ``unpacked``.
    """
    out_dim, in_dim = unpacked.shape
    tile_rows = in_dim // 16

    # Marlin works on 16x16 tiles: transpose to (K, N), split both axes into
    # tiles of 16, then interleave so that each tile is laid out contiguously.
    tiled = (
        unpacked.t()
        .reshape((tile_rows, 16, out_dim // 16, 16))
        .permute((0, 2, 1, 3))
        .reshape((tile_rows, out_dim * 16))
    )

    # _perm covers 1024 entries (four 16x16 tiles); reorder the weights in
    # chunks of that size for efficient mma + dequantization.
    chunk = _perm.numel()
    shuffled = tiled.reshape((-1, chunk))[:, _perm].reshape(tiled.shape)

    # The bit packing itself is done with NumPy on the host, on unsigned
    # 32-bit integers.
    n_rows, n_cols = shuffled.shape
    packed = np.zeros((n_rows, n_cols // 8), dtype=np.uint32)
    nibbles = shuffled.cpu().numpy().astype(np.uint32)

    # Slot `s` of every output word takes the columns s, s + 8, s + 16, ...
    # shifted into bits [4 * s, 4 * s + 3].
    for slot, shift in enumerate(range(0, 32, 4)):
        packed |= nibbles[:, slot::8] << shift

    # Reinterpret as int32 and move the result back to the input's device.
    return torch.from_numpy(packed.astype(np.int32)).to(unpacked.device)