# optimum/quanto/tensor/weights/awq/packed.py

import torch

def pack_v2(unpacked: torch.Tensor) -> torch.Tensor:
"""
Pack uint4 weights in an int16 tensor as expected by AWQ second generation mixed mm kernel
As compared to the standard packing, this adds three specific formatting:
- permute rows to counter implicit permutation on Turing and Ampere architecture,
- permute rows for faster dequantization,
- interleave groups of 'interleave' rows for efficient parallel processing.
Note that this formatting expects a group size of 128.
Args:
unpacked (`torch.Tensor`):
The un-packed `torch.uint8` tensor
Returns:
A int16 `torch.Tensor`.
"""
assert unpacked.device.type in ["cuda", "xpu"]
assert unpacked.ndim == 2
N, K = unpacked.shape
# These two values are hard-coded in the optimized kernels:
# - I represents the 'interleave', i.e. the number of values packed at a single coordinate (16 bits / 4 bits),
# - S represents the 'kernel stride', and is related to the group size (TBC).
I = 4
S = 64
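    # Shape requirements implied by the reshapes below: N must be divisible by I,
    # and K must be divisible by S (and hence by the 32-wide blocks of step 1).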
# 1. For faster dequantization, the tensor rows must be permuted as explained in:
# https://github.com/NVIDIA/TensorRT-LLM/blob/035b99e0d09d4f2dfdb949810cf7245112aa4165/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.cpp#L161
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ...] => [0, 1, 8, 9, 16, 17, 24, 25, ...]
packed = unpacked.reshape(N, K // 32, 4, 4, 2).permute(0, 1, 3, 2, 4)
    # Reorder each group of 8 weights for fast dequantization
# From: "Who Says Elephants Can’t Run: Bringing Large Scale MoE Models into Cloud Scale Production"
# https://arxiv.org/pdf/2211.10017
# [0, 1, 2, 3, 4, 5, 6, 7] => [0, 2, 4, 6, 1, 3, 5, 7]
packed = packed.permute(0, 1, 2, 4, 3)
packed = packed.reshape(N, K)
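    # At this point each row still contains its original K values, only permuted
    # within consecutive blocks of 32 columns.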
# 2. For efficient parallelization, the rows are grouped and interleaved by blocks of kstride into a single row, as explained in:
# https://github.com/NVIDIA/TensorRT-LLM/blob/d37b507f41a87457fe9f10f7459d08f5db235745/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h#L69
# interleaving (N, K) -> (N // I, I, K // S, S)
packed = packed.reshape(N // I, I, K // S, S)
# transpose (N // I, I, K // S, S) -> (N // I, K // S, I, S)
packed = packed.permute(0, 2, 1, 3)
# reshape (N // I, K // S, I, S) -> (N // I, K // S, S, I)
packed = packed.reshape(N // I, K // S, S, I)
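    # After this row-major reshape, each group of I consecutive values along the
    # last axis comes from a single source row: the 4 uint4 values fused into one
    # int16 below are 4 neighbouring values of one row, and the I interleaved
    # rows alternate every S // I = 16 positions along the packed row.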
    # pack (N // I, K // S, S, I) -> (N // I, K // S, S), 4 uint4 values per int16
packed = packed.to(torch.int32)
packed = packed[..., 0] | (packed[..., 1] << 4) | (packed[..., 2] << 8) | (packed[..., 3] << 12)
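    # e.g. uint4 values [1, 2, 3, 4] along the last axis pack to 0x4321 = 17185,
    # with the first value in the least-significant nibble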
    # reshape (N // I, K // S, S) -> (N // I, K)
packed = packed.reshape(N // I, K)
return packed.to(torch.int16).contiguous()
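

# Illustration only, not part of the original file: a minimal sketch of the
# inverse transform, undoing each packing step in reverse order. The name
# `unpack_v2` and its signature are hypothetical; I = 4 and S = 64 mirror the
# hard-coded values of `pack_v2` above.
def unpack_v2(packed: torch.Tensor, N: int, K: int) -> torch.Tensor:
    I, S = 4, 64
    # Split each int16 back into 4 uint4 values (mask away the sign extension)
    w = packed.to(torch.int32) & 0xFFFF
    nibbles = torch.stack([(w >> (4 * j)) & 0xF for j in range(I)], dim=-1)
    # (N // I, K, I) -> (N // I, K // S, S, I)
    unpacked = nibbles.reshape(N // I, K // S, S, I)
    # Undo the row-major (I, S) -> (S, I) reshape, then the transpose
    unpacked = unpacked.reshape(N // I, K // S, I, S).permute(0, 2, 1, 3)
    # De-interleave (N // I, I, K // S, S) -> (N, K)
    unpacked = unpacked.reshape(N, K)
    # Undo the two dequantization permutations (each swap is self-inverse)
    unpacked = unpacked.reshape(N, K // 32, 4, 2, 4).permute(0, 1, 2, 4, 3)
    unpacked = unpacked.permute(0, 1, 3, 2, 4).reshape(N, K)
    return unpacked.to(torch.uint8).contiguous()


# Usage sketch (assumes a CUDA device, N divisible by 4, K divisible by 64):
# weights = torch.randint(0, 16, (128, 4096), dtype=torch.uint8, device="cuda")
# packed = pack_v2(weights)  # shape (32, 4096), dtype torch.int16
# assert torch.equal(unpack_v2(packed, 128, 4096), weights)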