in optimum/quanto/tensor/weights/marlin/int4/packed.py [0:0]
def pack(unpacked: torch.Tensor):
    """Pack 4-bit weight values into marlin's interleaved int32 layout.

    Takes an (N, K) tensor of int4-range values and returns an int32 tensor
    where every 32-bit word holds 8 packed 4-bit values, arranged in the
    tile/permutation order marlin's kernels expect.
    """
    N, K = unpacked.shape
    # Work on the transposed (K, N) view from here on.
    tiled = unpacked.t()
    # Marlin consumes 16x16 tiles: split both axes into tiles, then make each
    # tile contiguous so a row of the result holds a full row of tiles.
    tiled = tiled.reshape(K // 16, 16, N // 16, 16)
    tiled = tiled.permute(0, 2, 1, 3).reshape(K // 16, N * 16)
    # _perm covers 1024 elements (four 16x16 tiles); applying it interleaves
    # the values for efficient mma + dequantization in the kernel.
    permuted = tiled.reshape(-1, _perm.numel())[:, _perm].reshape(tiled.shape)
    # Fold 8 strided 4-bit values into each uint32 word on the CPU.
    vals = permuted.cpu().numpy().astype(np.uint32)
    words = np.zeros((vals.shape[0], vals.shape[1] // 8), dtype=np.uint32)
    for nibble in range(8):
        words |= vals[:, nibble::8] << 4 * nibble
    # Reinterpret as int32 and move back to the input tensor's device.
    return torch.from_numpy(words.astype(np.int32)).to(tiled.device)