server/text_generation_server/layers/compressed_tensors/w8a8_int.py (196 lines of code) (raw):

from typing import List, Optional, Union, TypeVar
from dataclasses import dataclass

from loguru import logger
import torch
from compressed_tensors.quantization import QuantizationArgs, QuantizationType

from text_generation_server.layers.fp8 import _load_scalar_or_matrix_scale
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader

# The int8 quantization/matmul kernels are only loaded on CUDA systems; on
# every other system `quantization` is None and the code paths below that
# need it assert on that.
if SYSTEM == "cuda":
    quantization = load_kernel(
        module="quantization", repo_id="kernels-community/quantization"
    )
else:
    quantization = None


class W8A8IntLoader(WeightsLoader):
    """
    Loader for w8a8 integer compressed-tensors parameters.

    Input (activation) quantization is always performed dynamically at
    inference time; only the weight scales may be loaded from the checkpoint.
    """

    def __init__(
        self,
        *,
        input_args: Optional[QuantizationArgs],
        weight_args: QuantizationArgs,
    ):
        """
        Validate the quantization configuration and record how to load it.

        Args:
            input_args: Quantization arguments for the inputs, or `None`, in
                which case symmetric input quantization is used.
            weight_args: Quantization arguments for the weights.

        Raises:
            ValueError: If the weights are not 8-bit integers or are
                quantized asymmetrically.
        """
        # Both conditions must hold for a w8a8 int checkpoint, so reject the
        # checkpoint as soon as *either* of them is violated.
        if weight_args.type != QuantizationType.INT or weight_args.num_bits != 8:
            raise ValueError(
                f"{type(self).__name__} only supports w8a8 int checkpoints"
            )

        if not weight_args.symmetric:
            raise ValueError("Checkpoints with asymmetric weights are not supported")

        # Static weight scales are stored in the checkpoint; dynamic ones are
        # computed when the linear layer is constructed.
        self.load_weight_scale = not weight_args.dynamic

        if input_args is not None:
            self.input_symmetric = input_args.symmetric

            if not input_args.dynamic:
                log_once(
                    logger.warning,
                    "Forcing dynamic input quantization for compressed_tensors w8a8 int checkpoint (for better accuracy).",
                )
        else:
            self.input_symmetric = True

    def __str__(self) -> str:
        """Return a human-readable summary of the loader configuration."""

        def scale_to_str(scale):
            return "static" if scale else "dynamic"

        def symmetric_to_str(symmetric):
            return "symmetric" if symmetric else "asymmetric"

        return f"{self.__class__.__name__} (w8a8 int, input: dynamic/{symmetric_to_str(self.input_symmetric)}, weight: {scale_to_str(self.load_weight_scale)}/symmetric)"

    def get_weights(self, weights: "Weights", prefix: str):
        """
        Load the unsharded weight (and static weight scale, if any) at `prefix`.

        Returns:
            An `Int8Weight`; its `weight_scale` is a flattened tensor, or
            `None` when weight scales are dynamic.
        """
        # `to_dtype=False`: the tensors are loaded without dtype conversion.
        w = weights.get_tensor(f"{prefix}.weight", to_dtype=False)

        weight_scale = None
        if self.load_weight_scale:
            weight_scale = weights.get_tensor(
                f"{prefix}.weight_scale", to_dtype=False
            ).reshape(-1)

        return Int8Weight(
            input_symmetric=self.input_symmetric,
            weight=w,
            weight_scale=weight_scale,
        )

    def get_weights_col_packed(
        self,
        weights: Weights,
        prefix: str,
        block_sizes: Union[int, List[int]],
    ):
        """
        Load a packed weight at `prefix`, sharded along dimension 0.

        Args:
            weights: Source of the checkpoint tensors.
            prefix: Tensor name prefix of the packed projection.
            block_sizes: Block sizes forwarded to `get_packed_sharded`.

        Returns:
            An `Int8Weight`; its `weight_scale` is a flattened tensor, or
            `None` when weight scales are dynamic.
        """
        w = weights.get_packed_sharded(
            f"{prefix}.weight", dim=0, block_sizes=block_sizes, to_dtype=False
        )

        weight_scale = None
        if self.load_weight_scale:
            weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
            # A scale with more than one element is sharded the same way as
            # the weight; a single-element scale is used as loaded.
            if weight_scale.numel() > 1:
                weight_scale = weights.get_packed_sharded(
                    f"{prefix}.weight_scale",
                    dim=0,
                    block_sizes=block_sizes,
                    to_dtype=False,
                )
            weight_scale = weight_scale.reshape(-1)

        return Int8Weight(
            input_symmetric=self.input_symmetric,
            weight=w,
            weight_scale=weight_scale,
        )

    def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
        """
        Load the weights at `prefixes` and concatenate them along `dim`.

        Each weight is sharded along dimension 0 before concatenation.

        Returns:
            An `Int8Weight`; its `weight_scale` has shape `(-1, 1)`, or is
            `None` when weight scales are dynamic.
        """
        w = [
            weights.get_sharded(f"{p}.weight", dim=0, to_dtype=False) for p in prefixes
        ]
        # The per-shard shapes are needed to load the matching scales.
        shapes = [x.shape for x in w]

        w = torch.cat(w, dim=dim)

        weight_scale = None
        if self.load_weight_scale:
            weight_scale = [
                _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
                for p, shape in zip(prefixes, shapes)
            ]
            # NOTE(review): scales are always concatenated along dimension 0,
            # independently of `dim` -- presumably callers pass dim=0; confirm.
            weight_scale = torch.cat(weight_scale, dim=0).reshape(-1, 1)

        return Int8Weight(
            input_symmetric=self.input_symmetric,
            weight=w,
            weight_scale=weight_scale,
        )

    def get_weights_row(self, weights: "Weights", prefix: str):
        """
        Load the weight at `prefix`, sharded along dimension 1.

        The static weight scale, if any, is loaded whole (not sharded) and
        flattened.

        Returns:
            An `Int8Weight`; its `weight_scale` is `None` when weight scales
            are dynamic.
        """
        w = weights.get_sharded(f"{prefix}.weight", dim=1, to_dtype=False)

        weight_scale = None
        if self.load_weight_scale:
            weight_scale = weights.get_tensor(
                f"{prefix}.weight_scale", to_dtype=False
            ).reshape(-1)

        return Int8Weight(
            input_symmetric=self.input_symmetric,
            weight=w,
            weight_scale=weight_scale,
        )


OtherT = TypeVar("OtherT")


def _get_tensor_or_else(
    weights: Weights, prefix: str, other: OtherT
) -> Union[torch.Tensor, OtherT]:
    """
    Return the tensor named `prefix` if the checkpoint has it, else `other`.

    The tensor is loaded without dtype conversion.
    """
    # Even if a checkpoint uses e.g. zero-points, they can be elided:
    # https://github.com/neuralmagic/compressed-tensors/blob/db6ccb25b265e8370813ecab5e95714a6728b5a6/src/compressed_tensors/compressors/quantized_compressors/base.py#L105
    if weights.has_tensor(prefix):
        return weights.get_tensor(prefix, to_dtype=False)
    else:
        return other


@dataclass
class Int8Weight(Weight):
    """Weight (plus optional static scale) for a w8a8 int linear layer."""

    # Whether inputs are quantized symmetrically (no zero point).
    input_symmetric: bool
    # The weight tensor as loaded from the checkpoint.
    weight: torch.Tensor
    # Static weight scale from the checkpoint, or None when the scale must be
    # computed by quantizing `weight` in `get_linear`.
    weight_scale: Optional[torch.Tensor]

    def get_linear(self, bias: Optional[torch.Tensor]):
        """
        Build the `W8A8IntLinear` layer for this weight.

        When no weight scale was loaded, the weight is quantized here with
        the `quantization` kernel, which yields both the quantized weight and
        its scale.

        Args:
            bias: Bias forwarded unchanged to the linear layer.
        """
        if self.weight_scale is None:
            assert quantization is not None
            # The third returned value is unused for weights.
            qweight, weight_scale, _ = quantization.scaled_int8_quant(self.weight)
            return W8A8IntLinear(
                bias=bias,
                input_symmetric=self.input_symmetric,
                weight=qweight,
                weight_scale=weight_scale,
            )
        else:
            return W8A8IntLinear(
                bias=bias,
                input_symmetric=self.input_symmetric,
                weight=self.weight,
                weight_scale=self.weight_scale,
            )


class W8A8IntLinear(torch.nn.Module):
    """Linear layer with int8 weights and dynamically quantized int8 inputs."""

    def __init__(
        self,
        *,
        bias: Optional[torch.Tensor],
        input_symmetric: bool,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
    ):
        """
        Args:
            bias: Optional bias passed to the matmul kernel.
            input_symmetric: Whether inputs are quantized symmetrically.
            weight: Quantized weight; stored transposed.
            weight_scale: Weight scale; stored as float32.
        """
        super().__init__()

        weight_scale = weight_scale.to(torch.float32)

        self.bias = bias
        self.input_symmetric = input_symmetric
        # cutlass kernels require transposed weights.
        self.weight = weight.t()
        self.weight_scale = weight_scale

        if input_symmetric:
            self.zero_point_adj = None
        else:
            # Precomputed zero-point adjustment term: the int32 sum of the
            # transposed weight over dimension 0.
            # https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md#scaledepilogueazp
            self.zero_point_adj = self.weight.sum(
                dim=0, keepdim=True, dtype=torch.int32
            )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Quantize `input` dynamically and multiply it with the int8 weight.

        The output is requested from the kernel in the dtype of `input`.
        """
        assert quantization is not None

        # No scale or zero point is supplied (`scale=None`, `azp=None`); the
        # kernel returns the quantized input together with a scale and a
        # zero point.
        qinput, input_scale, input_zero_point = quantization.scaled_int8_quant(
            input=input,
            scale=None,
            azp=None,
            symmetric=self.input_symmetric,
        )

        if self.input_symmetric:
            return quantization.cutlass_scaled_mm(
                a=qinput,
                b=self.weight,
                scale_a=input_scale,
                scale_b=self.weight_scale,
                out_dtype=input.dtype,
                bias=self.bias,
            )
        else:
            # Asymmetric inputs need the zero point and the precomputed
            # adjustment term.
            assert (
                self.zero_point_adj is not None
                and input_scale is not None
                and input_zero_point is not None
            )

            return quantization.cutlass_scaled_mm_azp(
                a=qinput,
                b=self.weight,
                scale_a=input_scale,
                scale_b=self.weight_scale,
                out_dtype=input.dtype,
                azp_adj=self.zero_point_adj,
                azp=input_zero_point,
                bias=self.bias,
            )