tinynn/llm_quant/modules.py:

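"""Quantized linear modules for LLM inference.

Implements weight-only int4/int8 quantization and dynamic per-token int8
quantization for nn.Linear, with optional acceleration via the `easyquant`
kernel library, plus a helper that swaps LlamaAttention for a fused variant.
"""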
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from transformers.models.llama.modeling_llama import LlamaAttention

from tinynn.llm_quant.llama import LlamaAttentionFused
from tinynn.util.util import get_logger

from .util import _init_patch_easyquant, get_submodule_with_parent_from_name

log = get_logger(__name__, 'INFO')

SPEEDUP = True
try:
    if sys.platform == "win32":
        _init_patch_easyquant()
    from easyquant import (
        decompress_int4,
        decompress_int8,
        quantize_per_token,
        gemm,
        dequantize_bias_per_token,
        dequantize_per_token,
    )
except (ImportError, OSError):
    log.warning('easyquant is not installed, the inference performance may be degraded')
    SPEEDUP = False


def compress_int(data_tensor, bit_width, per_channel=True, per_token=False):
    # symmetric quantization range, e.g. [-127, 127] for 8-bit
    q_max = 2 ** (bit_width - 1) - 1
    q_min = -q_max
    assert not (per_channel and per_token)
    if per_channel:
        # for weights, use w_max / quant_max as the scale and store int8 to save memory
        scale = 2 * (data_tensor.abs().max(dim=-1).values.float() / (2**bit_width - 1))
        quantized_tensor = torch.clamp(torch.round(data_tensor.float() / scale[:, None]), q_min, q_max).to(torch.int8)
    elif per_token:
        # per-token quantization: one scale per row of the (possibly batched) activation
        scales = data_tensor.abs().max(dim=-1).values.float() / q_max
        if data_tensor.dim() == 3:
            scales = scales[:, :, None]
        elif data_tensor.dim() == 2:
            scales = scales[:, None]
        else:
            assert False, f'unsupported input rank {data_tensor.dim()}'
        quantized_tensor = torch.clamp(torch.round(data_tensor.float() / scales), q_min, q_max).to(torch.int8)
        scale = scales
    else:
        # per-tensor quantization
        scale = data_tensor.abs().max().float() / q_max
        quantized_tensor = torch.clamp(torch.round(data_tensor.float() / scale), q_min, q_max).to(torch.int8)
    return scale, quantized_tensor
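
# A minimal round-trip sketch of compress_int (an illustrative addition, not
# part of the original module): per-channel int8 quantization should
# reconstruct each weight to within half a quantization step per channel.
def _example_compress_roundtrip():
    w = torch.randn(8, 16)
    scale, w_q = compress_int(w, 8, per_channel=True)
    w_deq = w_q.float() * scale[:, None]
    # the worst-case rounding error per channel is about scale / 2
    assert torch.all((w - w_deq).abs().max(dim=-1).values <= scale * 0.5 + 1e-6)
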
class QLinear(nn.Module):
    def __init__(self, fc: nn.Linear, quant_mode: str):
        super().__init__()
        assert quant_mode in ('weight4', 'weight8', 'dynamic')
        weight_bit_width = 4 if quant_mode == 'weight4' else 8
        self.weight_bit_width = weight_bit_width
        self.quant_mod = quant_mode
        self.in_features = fc.in_features
        self.out_features = fc.out_features
        bias = None if fc.bias is None else fc.bias.data
        # compress the weight to the given bit width, per-channel, clamped to [-127, 127] / [-7, 7]
        scale, weight_q = compress_int(fc.weight.data, weight_bit_width)
        if self.in_features % 4 != 0 and quant_mode == 'dynamic':
            weight_q = F.pad(weight_q, (0, 4 - self.in_features % 4))
        if self.weight_bit_width == 4:
            weight_shape = weight_q.shape
            assert len(weight_shape) == 2
            assert weight_shape[1] % 2 == 0
            # pack two signed 4-bit values into one int8, high nibble first
            pre_packed = weight_q.view(weight_shape[0], weight_shape[1] // 2, 2)
            weight_q = ((pre_packed[..., 0] & 0b00001111) << 4) | (pre_packed[..., 1] & 0b00001111)
        self.weight = nn.Parameter(weight_q, requires_grad=False)
        self.weight_scale = nn.Parameter(scale, requires_grad=False)
        self.bias = nn.Parameter(bias, requires_grad=False) if bias is not None else None
        fc.weight = None
        fc.bias = None

    def forward(self, input: Tensor) -> Tensor:
        input_device = input.device
        input_dtype = input.dtype
        input_shape = input.shape
        if self.quant_mod == 'static':
            assert False, f'{self.quant_mod} not supported'
        if self.quant_mod == 'weight4':
            if SPEEDUP:
                weight_fp = torch.empty(
                    (self.out_features, self.in_features), dtype=torch.float16, device=input_device
                )
                decompress_int4(weight_fp, self.weight, self.weight_scale)
            else:
                # unpack the nibbles with sign-extending shifts, then dequantize
                weight_fp = (
                    torch.stack((self.weight >> 4, self.weight << 4 >> 4), -1)
                    .view(self.out_features, self.in_features)
                    .to(dtype=torch.float32)
                    * self.weight_scale[:, None]
                ).to(dtype=torch.half)
        elif self.quant_mod == 'weight8':
            if SPEEDUP:
                weight_fp = torch.empty_like(self.weight.data, dtype=input_dtype, device=input_device)
                decompress_int8(weight_fp, self.weight, self.weight_scale)
            else:
                weight_fp = (self.weight.to(dtype=torch.float32) * self.weight_scale[:, None]).to(dtype=torch.half)
        if 'dynamic' in self.quant_mod:
            if SPEEDUP:
                # the real dynamic quantization process: first quantize the input to int8,
                # then run the int8 GEMM, and finally dequantize the output to float
                input_viewed = input.view(-1, input_shape[-1])
                # pad the input columns to a multiple of 4 to match the padded weight
                padding_num = 4 - self.in_features % 4 if self.in_features % 4 != 0 else 0
                # preallocate the outputs of the easyquant kernels
                input_q = torch.empty(
                    (input_viewed.shape[0], input_viewed.shape[1] + padding_num),
                    dtype=torch.int8,
                    device=input_device,
                )
                scale_shape = input_viewed.shape[0] if 'token' in self.quant_mod else 1
                input_scale = torch.zeros(scale_shape, device=input_device)
                out_q = torch.empty((input_viewed.shape[0], self.out_features), dtype=torch.int32, device=input_device)
                output = torch.empty_like(out_q, dtype=torch.float16, device=input_device)
                # use the easyquant kernels to accelerate computation
                quantize_per_token(input_q, input_viewed, input_scale)
                gemm(out_q, input_q, self.weight)
                if self.bias is not None:
                    dequantize_bias_per_token(output, out_q, input_scale, self.weight_scale, self.bias)
                else:
                    dequantize_per_token(output, out_q, input_scale, self.weight_scale)
                output = output.view(input_shape[:-1] + (output.shape[-1],))
            else:
                # simulate dynamic quantization with float GEMMs
                input_scale, input_q = compress_int(input, 8, per_channel=False, per_token=True)
                if self.in_features % 4 != 0:
                    # drop the padding columns added in __init__
                    weight_used = self.weight[:, : self.in_features % 4 - 4]
                else:
                    weight_used = self.weight
                output = F.linear(input_q.float(), weight_used.float())
                # input_scale keeps a trailing singleton dim from compress_int,
                # so it broadcasts correctly for both 2D and 3D inputs
                output = (output * (self.weight_scale * input_scale)).half()
                # add the bias after dequantization so it is not scaled
                if self.bias is not None:
                    output = output + self.bias
        else:
            output = F.linear(input, weight_fp, self.bias)
        return output


class TDQLinear_noinit(QLinear):
    def forward(self, input: Tensor) -> Tensor:
        input_shape = input.shape
        bs, seq, _ = input_shape
        input_device = input.device
        input_viewed = input.view(-1, self.in_features)
        # pad the input columns to a multiple of 4 to match the padded weight
        padding_num = 4 - self.in_features % 4 if self.in_features % 4 != 0 else 0
        input_q = torch.empty(
            (input_viewed.shape[0], self.in_features + padding_num), dtype=torch.int8, device=input_device
        )
        input_scale = torch.empty(bs * seq, device=input_device)
        out_q = torch.empty((bs * seq, self.out_features), dtype=torch.int32, device=input_device)
        output = torch.empty_like(out_q, dtype=torch.float16, device=input_device)
        quantize_per_token(input_q, input_viewed, input_scale)
        gemm(out_q, input_q, self.weight)
        dequantize_per_token(output, out_q, input_scale, self.weight_scale)
        output = output.view(input_shape[:-1] + (output.shape[-1],))
        return output
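
# An illustrative sketch (not part of the original module) of the int4 packing
# used in QLinear.__init__ and the shift-based unpacking in QLinear.forward:
# two signed 4-bit values share one int8, high nibble first, and arithmetic
# right shifts on int8 restore the sign when unpacking.
def _example_int4_pack_unpack():
    vals = torch.tensor([[-7, 3, 0, -1]], dtype=torch.int8)
    pre_packed = vals.view(1, 2, 2)
    packed = ((pre_packed[..., 0] & 0b00001111) << 4) | (pre_packed[..., 1] & 0b00001111)
    high = packed >> 4       # sign-extending shift recovers the high nibble
    low = packed << 4 >> 4   # shift up then back down recovers the low nibble
    assert torch.equal(torch.stack((high, low), -1).view(1, 4), vals)
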
@torch.no_grad()
def fuse_atten(model: nn.Module):
    """Fuse the qkv linears of attention; scaled_dot_product_attention is also fused if torch >= 1.13."""
    for name, mod in model.named_modules():
        if isinstance(mod, LlamaAttention):
            _, parent_mod, last_name = get_submodule_with_parent_from_name(model, name)
            fused_attn = LlamaAttentionFused(mod)
            setattr(parent_mod, last_name, fused_attn)


@torch.no_grad()
def quant_fc(model: nn.Module, quant_mod='weight8', fuse_qkv=False):
    """Convert all fully-connected layers of the LLM model to quantized linears in place.

    Args:
        model: the given LLM model.
        quant_mod: the working quantization mode. Defaults to 'weight8';
            options: ['weight4', 'dynamic']. Dynamic quantization uses the
            easyquant lib to accelerate the int8 GEMM when it is available.
        fuse_qkv: whether to fuse the qkv linears of attention to speed up
            inference; scaled-dot-product-attention will also be fused if the
            PyTorch version is >= 1.13.
    """
    model.cpu()
    log.info(f'use quant mod {quant_mod} speedup={SPEEDUP}')
    if fuse_qkv:
        fuse_atten(model)
        log.info('qkv has been fused')
    for name, mod in model.named_modules():
        if 'lm_head' in name:
            continue
        if isinstance(mod, nn.Linear):
            _, parent_mod, last_name = get_submodule_with_parent_from_name(model, name)
            if quant_mod == 'dynamic' and SPEEDUP:
                quantized_fc_cls = TDQLinear_noinit
            else:
                quantized_fc_cls = QLinear
            quantized_fc = quantized_fc_cls(mod, quant_mod)
            setattr(parent_mod, last_name, quantized_fc)
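
# Usage sketch (illustrative; the checkpoint path and dtype handling below are
# assumptions, not taken from the original file):
#
#   from transformers import LlamaForCausalLM
#   model = LlamaForCausalLM.from_pretrained('/path/to/llama').half()
#   quant_fc(model, quant_mod='weight8', fuse_qkv=True)  # quantize fcs in place
#   model.cuda()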