optimum/quanto/tensor/weights/awq/packed.py (154 lines of code) (raw):

# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import ast from copy import copy from enum import Enum import numpy as np import torch from torch.utils import _pytree as pytree from ..packing import unpack_int32_to_uint8 __all__ = ["AWQPackedTensor", "AWQPacking"] AWQ_ORDER = [0, 2, 4, 6, 1, 3, 5, 7] AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7] def pack(unpacked: torch.Tensor, reorder=False): """ Pack uint4 weights in an int32 tensor as expected by AWQ mixed mm kernel As compared to the standard packing, this adds an optional permutation of the columns for faster dequantization, as explained in "Who Says Elephants Can’t Run: Bringing Large Scale MoE Models into Cloud Scale Production", https://arxiv.org/pdf/2211.10017. Args: unpacked (`torch.Tensor`): The un-packed `torch.uint8` tensor reorder (`bool`): Whether columns should be reordered or not before packing. Returns: A int32 `torch.Tensor`. """ bits = 4 pack_num = 32 // bits packed = torch.zeros(unpacked.shape[0], unpacked.shape[1] // pack_num, dtype=torch.int32, device=unpacked.device) for col in range(unpacked.shape[1] // pack_num): if reorder: order_map = AWQ_ORDER else: order_map = [0, 1, 2, 3, 4, 5, 6, 7] for i in range(pack_num): packed_col = unpacked[:, col * pack_num + order_map[i]].to(torch.int32) packed[:, col] |= packed_col << (i * bits) return packed def reverse_awq_order(t: torch.Tensor): bits = 4 reverse_order_tensor = torch.arange( t.shape[-1], dtype=torch.int32, device=t.device, ) reverse_order_tensor = reverse_order_tensor.reshape(-1, 32 // bits) reverse_order_tensor = reverse_order_tensor[:, AWQ_REVERSE_ORDER] reverse_order_tensor = reverse_order_tensor.reshape(-1) t = t[:, reverse_order_tensor] return t def unpack(packed: torch.Tensor, reorder=False): """Unpack a packed int32 tensor to a larger uint8 tensor Applies pack operations in reverse order (see pack method for details). Args: packed (`torch.Tensor`): The packed `torch.int32` tensor reorder (`bool`): Whether columns should be reordered or not. Returns: An unpacked uint8 `torch.Tensor` expanded along the second dimension. """ unpacked = unpack_int32_to_uint8(packed, bits=4) if reorder: unpacked = reverse_awq_order(unpacked) return unpacked def pack_v2(unpacked: torch.Tensor) -> torch.Tensor: """ Pack uint4 weights in an int16 tensor as expected by AWQ second generation mixed mm kernel As compared to the standard packing, this adds three specific formatting: - permute rows to counter implicit permutation on Turing and Ampere architecture, - permute rows for faster dequantization, - interleave groups of 'interleave' rows for efficient parallel processing. Note that this formatting expects a group size of 128. Args: unpacked (`torch.Tensor`): The un-packed `torch.uint8` tensor Returns: A int16 `torch.Tensor`. """ assert unpacked.device.type in ["cuda", "xpu"] assert unpacked.ndim == 2 N, K = unpacked.shape # These two values are hard-coded in the optimized kernels: # - I represents the 'interleave', i.e. the number of values packed at a single coordinate (16 bits / 4 bits), # - S represents the 'kernel stride', and is related to the group size (TBC). I = 4 S = 64 # 1. For faster dequantization, the tensor rows must be permuted as explained in: # https://github.com/NVIDIA/TensorRT-LLM/blob/035b99e0d09d4f2dfdb949810cf7245112aa4165/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.cpp#L161 # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ...] => [0, 1, 8, 9, 16, 17, 24, 25, ...] packed = unpacked.reshape(N, K // 32, 4, 4, 2).permute(0, 1, 3, 2, 4) # Reorder each 8 weights for fast dequantization # From: "Who Says Elephants Can’t Run: Bringing Large Scale MoE Models into Cloud Scale Production" # https://arxiv.org/pdf/2211.10017 # [0, 1, 2, 3, 4, 5, 6, 7] => [0, 2, 4, 6, 1, 3, 5, 7] packed = packed.permute(0, 1, 2, 4, 3) packed = packed.reshape(N, K) # 2. For efficient parallelization, the rows are grouped and interleaved by blocks of kstride into a single row, as explained in: # https://github.com/NVIDIA/TensorRT-LLM/blob/d37b507f41a87457fe9f10f7459d08f5db235745/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h#L69 # interleaving (N, K) -> (N // I, I, K // S, S) packed = packed.reshape(N // I, I, K // S, S) # transpose (N // I, I, K // S, S) -> (N // I, K // S, I, S) packed = packed.permute(0, 2, 1, 3) # reshape (N // I, K // S, I, S) -> (N // I, K // S, S, I) packed = packed.reshape(N // I, K // S, S, I) # Packing (N // I, K // S, S, I) -> (N // I, K // S, S) packed = packed.to(torch.int32) packed = packed[..., 0] | (packed[..., 1] << 4) | (packed[..., 2] << 8) | (packed[..., 3] << 12) # Reshape to (N // I, K // S, S) -> (N // I, K) packed = packed.reshape(N // I, K) return packed.to(torch.int16).contiguous() def unpack_v2(packed): """Unpack a packed int16 tensor to a larger uint8 tensor Applies pack operations in reverse order (see pack_v2 method for details). Warning: very slow, to be used for debug only. Args: packed (`torch.Tensor`): The packed `torch.int16` tensor Returns: An unpacked uint8 `torch.Tensor` expanded along the first dimension. """ assert packed.device.type in ["cuda", "xpu"] assert packed.ndim == 2 I = 4 S = 64 N_div_I, K = packed.shape N = N_div_I * I # Reshape (N // I, K) -> (N // I, K // S, S, 1) unpacked = packed.reshape(N // I, K // S, S, 1) # Convert to uint16 (through numpy because not supported by pytorch) unpacked = unpacked.cpu().numpy().astype(np.uint16) # Unpack (N // I, K, S) -> (N // I, K // S, S, I) unpacked = torch.cat( [ torch.tensor((unpacked & 0xF).astype(np.uint8)).to(packed.device), torch.tensor(((unpacked & 0xF0) >> 4).astype(np.uint8)).to(packed.device), torch.tensor(((unpacked & 0xF00) >> 8).astype(np.uint8)).to(packed.device), torch.tensor(((unpacked & 0xF000) >> 12).astype(np.uint8)).to(packed.device), ], axis=-1, ) # reshape (N // I, K // S, S, I) -> (N // I, K // S, I, S) unpacked = unpacked.reshape(N // I, K // S, I, S) # transpose (N // I, K // S, I, S) -> (N // I, I, K // S, S) unpacked = unpacked.permute(0, 2, 1, 3) # deinterleaving (N // I, I, K // S, S) -> (N, K) unpacked = unpacked.reshape(N, K) # Final steps to reorder (see packing code for explaination) unpacked = unpacked.reshape(N, K // 32, 4, 2, 4).permute(0, 1, 2, 4, 3) unpacked = unpacked.permute(0, 1, 3, 2, 4) unpacked = unpacked.reshape(N, K) return unpacked class AWQPacking(Enum): V1 = 1 V2 = 2 class AWQPackedTensor(torch.Tensor): @staticmethod def __new__(cls, data, packing, reorder, size, stride, requires_grad=False): # AWQPackedTensor represents uint8 data and can therefore NEVER require gradient assert data.device.type in ["cuda", "xpu"] assert data.dtype == torch.int32 if packing == AWQPacking.V1 else torch.int16 assert requires_grad is False return torch.Tensor._make_wrapper_subclass( cls, size, strides=stride, dtype=torch.uint8, device=data.device, requires_grad=requires_grad ) def __init__(self, data, packing, reorder, size, stride, requires_grad=False): self._data = data self._packing = packing self._reorder = reorder def __repr__(self): return f"AWQPackedTensor({self._data}, packing={self._packing}, reorder={self._reorder})" @classmethod def pack(cls, t, packing=AWQPacking.V1, reorder=False): if packing == AWQPacking.V1: data = pack(t, reorder=reorder) else: data = pack_v2(t) # We need to store size and stride to make sure the unpacked data has the correct shape return AWQPackedTensor(data, packing, reorder, t.size(), t.stride()) def unpack(self): if self._packing == AWQPacking.V1: return unpack(self._data, self._reorder) return unpack_v2(self._data) @property def dtype(self): return torch.uint8 def __tensor_flatten__(self): inner_tensors = ["_data"] # Since meta can be used for serialization, use only AST compatible strings meta = { "packing": str(self._packing), "reorder": str(self._reorder), "size": str(list(self.size())), "stride": str(self.stride()), } return inner_tensors, meta @staticmethod def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride): assert len(inner_tensors) == 1 assert len(meta) == 4 data = inner_tensors["_data"] # Meta should contain only AST compatible strings packing = ast.literal_eval(meta["packing"]) reorder = ast.literal_eval(meta["reorder"]) size = ast.literal_eval(meta["size"]) stride = ast.literal_eval(meta["stride"]) return AWQPackedTensor(data, packing, reorder, size, stride) __torch_function__ = torch._C._disabled_torch_function_impl @classmethod def __torch_dispatch__(cls, op, types, args, kwargs=None): # Convert back to tensor before calling any operation except detach and move if op.overloadpacket is torch.ops.aten.detach: t = args[0] data = op(t._data) return AWQPackedTensor(data, t._packing, t._reorder, t.size(), t.stride()) elif op.overloadpacket in (torch.ops.aten._to_copy, torch.ops.aten.to): t = args[0] dtype = kwargs.get("dtype", torch.uint8) if dtype != torch.uint8: raise ValueError(f"AWQPackedTensor are torch.uint8 only and cannot be moved to {dtype}.") device = kwargs.get("device", t.device) # AWQPackedTensor can only be moved to CUDA devices if device.type == "cuda": data_kwargs = copy(kwargs) data_kwargs["dtype"] = t._data.dtype data = op(t._data, **data_kwargs) return AWQPackedTensor(data, t._packing, t._reorder, t.size(), t.stride()) args, kwargs = pytree.tree_map_only(AWQPackedTensor, lambda x: x.unpack(), (args, kwargs or {})) return op(*args, **kwargs) def numpy(self): return self.unpack().cpu().numpy()