def pack_v2()

in optimum/quanto/tensor/weights/awq/packed.py


def pack_v2(unpacked: torch.Tensor) -> torch.Tensor:
    """
    Pack uint4 weights into an int16 tensor, as expected by the AWQ second-generation mixed matmul kernel.

    Compared to the standard packing, this adds three specific formatting steps:

    - permute rows to counter the implicit permutation on Turing and Ampere architectures,
    - permute rows for faster dequantization,
    - interleave groups of 'interleave' rows for efficient parallel processing.

    Note that this formatting expects a group size of 128.

    Args:
        unpacked (`torch.Tensor`):
            The unpacked `torch.uint8` tensor.

    Returns:
        An int16 `torch.Tensor`.
    """
    assert unpacked.device.type in ["cuda", "xpu"]
    assert unpacked.ndim == 2
    N, K = unpacked.shape
    # These two values are hard-coded in the optimized kernels:
    # - I represents the 'interleave', i.e. the number of values packed at a single coordinate (16 bits / 4 bits),
    # - S represents the 'kernel stride', and is related to the group size (TBC).
    I = 4
    S = 64

    # 1. For faster dequantization, the tensor rows must be permuted as explained in:
    # https://github.com/NVIDIA/TensorRT-LLM/blob/035b99e0d09d4f2dfdb949810cf7245112aa4165/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.cpp#L161
    # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ...] => [0, 1, 8, 9, 16, 17, 24, 25, ...]
    packed = unpacked.reshape(N, K // 32, 4, 4, 2).permute(0, 1, 3, 2, 4)

    # Reorder each group of 8 weights for fast dequantization
    # From: "Who Says Elephants Can’t Run: Bringing Large Scale MoE Models into Cloud Scale Production"
    # https://arxiv.org/pdf/2211.10017
    # [0, 1, 2, 3, 4, 5, 6, 7] => [0, 2, 4, 6, 1, 3, 5, 7]
    packed = packed.permute(0, 1, 2, 4, 3)
    packed = packed.reshape(N, K)
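    # Net permutation within each group of 32 columns, combining both steps:
    # [0, 8, 16, 24, 1, 9, 17, 25, 2, 10, 18, 26, ...]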

    # 2. For efficient parallelization, the rows are grouped and interleaved by blocks of S (the kernel stride) into a single row, as explained in:
    # https://github.com/NVIDIA/TensorRT-LLM/blob/d37b507f41a87457fe9f10f7459d08f5db235745/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h#L69
    # interleaving (N, K) -> (N // I, I, K // S, S)
    packed = packed.reshape(N // I, I, K // S, S)
    # transpose (N // I, I, K // S, S) -> (N // I, K // S, I, S)
    packed = packed.permute(0, 2, 1, 3)
    # reshape (N // I, K // S, I, S) -> (N // I, K // S, S, I)
    packed = packed.reshape(N // I, K // S, S, I)
    # Packing (N // I, K // S, S, I) -> (N // I, K // S, S)
    packed = packed.to(torch.int32)
    packed = packed[..., 0] | (packed[..., 1] << 4) | (packed[..., 2] << 8) | (packed[..., 3] << 12)
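    # e.g. four uint4 values [1, 2, 3, 4] along the last axis pack to 0x4321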
    # Reshape (N // I, K // S, S) -> (N // I, K)
    packed = packed.reshape(N // I, K)
    return packed.to(torch.int16).contiguous()
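
A minimal usage sketch (not from the source file): the import path follows the module location above, the shapes are hypothetical, and a CUDA device is assumed. N must be divisible by the interleave factor (4), and K by 32 and the kernel stride (64).

import torch

from optimum.quanto.tensor.weights.awq.packed import pack_v2

N, K = 128, 256  # hypothetical output/input features
unpacked = torch.randint(0, 16, (N, K), dtype=torch.uint8, device="cuda")
packed = pack_v2(unpacked)
assert packed.dtype == torch.int16
assert packed.shape == (N // 4, K)  # 4 rows interleaved into each packed row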