maga_transformer/tools/quant/fp8_quanter.py

import copy
import json
import time
import os
import logging
from typing import Dict, List

import torch
import safetensors
import safetensors.torch
from transformers import AutoModelForCausalLM, AutoConfig

from maga_transformer.tools.quant.base_quanter import QUANT_TYPE, BaseQuanter


class Fp8Quanter(BaseQuanter):
    # Default modelopt quantization config: FP8 (E4M3) weights and activations,
    # with the MOE router left unquantized.
    FP8_DEFAULT_CFG = {
        "quant_cfg": {
            "*weight_quantizer": {"num_bits": (4, 3), "axis": None},
            "*input_quantizer": {"num_bits": (4, 3), "axis": None},
            "*block_sparse_moe.gate*": {"enable": False},  # Skip the MOE router
            "default": {"num_bits": (4, 3), "axis": None},
        },
        "algorithm": "max",
    }

    # Output quantizers enabling KV-cache quantization, keyed by the attention
    # projection names used across the supported model families.
    KV_CACHE_CFG = {
        "*.query_key_value.output_quantizer": {"num_bits": 8, "axis": None, "enable": True},
        "*.Wqkv.output_quantizer": {"num_bits": 8, "axis": None, "enable": True},
        "*.W_pack.output_quantizer": {"num_bits": 8, "axis": None, "enable": True},
        "*.c_attn.output_quantizer": {"num_bits": 8, "axis": None, "enable": True},
        "*.k_proj.output_quantizer": {"num_bits": 8, "axis": None, "enable": True},
        "*.v_proj.output_quantizer": {"num_bits": 8, "axis": None, "enable": True},
    }

    def __init__(self, quantize_config: Dict[str, str], model_path: str, offload_folder: str):
        super().__init__()
        self.quantize_config = quantize_config
        self.model_path = model_path  # kept so _save_quantized can re-read the HF config

        # Let every visible GPU use up to 95% of its memory when loading the model.
        max_memory = {}
        per_gpu_max_memory = int(
            torch.cuda.get_device_properties(torch.device('cuda:0')).total_memory * 0.95 / 1024 / 1024 / 1024)
        cuda_devices = os.environ.get('CUDA_VISIBLE_DEVICES', None)
        cuda_device_list = cuda_devices.split(',') if cuda_devices is not None else \
            [str(i) for i in range(torch.cuda.device_count())]
        max_memory.update({int(i): f'{per_gpu_max_memory}GIB' for i in range(len(cuda_device_list))})
        logging.info(f'max_memory: {max_memory}')

        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype="auto",
            low_cpu_mem_usage=True,
            trust_remote_code=True,
            offload_folder=offload_folder,
            max_memory=max_memory)
        self.model = model.eval().half()

        self.quant_cfg = copy.deepcopy(Fp8Quanter.FP8_DEFAULT_CFG)
        kv_cache_dtype = quantize_config.get('kv_cache_dtype', None)
        if kv_cache_dtype is not None:
            if kv_cache_dtype == "fp8":
                for value in Fp8Quanter.KV_CACHE_CFG.values():
                    value.update({"num_bits": (4, 3)})  # type: ignore
            self.quant_cfg["quant_cfg"].update(Fp8Quanter.KV_CACHE_CFG)  # type: ignore

    def _quant(self, examples: List[Dict[str, torch.Tensor]]):
        # Keep the calibration input_ids as tensors so they can be fed to the model directly.
        examples = [_.get('input_ids') for _ in examples]

        import modelopt.torch.quantization as atq

        def calibrate_loop():
            """Adjusts weights and scaling factors based on selected algorithms."""
            if examples is None:
                return
            for idx, example in enumerate(examples):
                print(f"Calibrating batch {idx}")
                # The model might be mapped to different devices because the device_map is auto,
                # so move each calibration batch to the device of the first parameter.
                self.model(example.to(next(self.model.parameters()).device))

        print("Starting quantization...")
        start_time = time.time()
        atq.quantize(self.model, self.quant_cfg, forward_loop=calibrate_loop)
        end_time = time.time()
        print("Quantization done. Total time used: {:.2f} s.".format(end_time - start_time))

    @classmethod
    def quant_type(cls):
        return QUANT_TYPE.FP8

    def _save_quantized(self, output_path: str):
        # Export settings: dtype matches the .half() cast in __init__, qformat is always
        # fp8 for this quanter, and tp/pp sizes mirror the export call below. 'model_type'
        # is assumed to come from the quantize config when provided, falling back to the
        # HF config's model_type; 'kv_cache_dtype' is read the same way as in __init__.
        model_type = self.quantize_config.get(
            'model_type', getattr(self.model.config, 'model_type', None))
        dtype = 'float16'
        qformat = 'fp8'
        kv_cache_dtype = self.quantize_config.get('kv_cache_dtype', None)
        tp_size = 1
        pp_size = 1

        with torch.inference_mode():
            if model_type is None:
                print(f"Unknown model type {type(self.model).__name__}. Continue exporting...")
                model_type = f"unknown:{type(self.model).__name__}"

            export_path = output_path
            start_time = time.time()

            from modelopt.torch.export import export_tensorrt_llm_checkpoint
            export_tensorrt_llm_checkpoint(self.model,
                                           model_type,
                                           getattr(torch, dtype),
                                           export_dir=export_path,
                                           inference_tensor_parallel=tp_size,
                                           inference_pipeline_parallel=pp_size)

            with open(f"{export_path}/config.json", "r") as f:
                tensorrt_llm_config = json.load(f)

            # Workaround for MOE router quantization
            if "moe_num_experts" in tensorrt_llm_config and qformat != "full_prec":
                if "exclude_modules" not in tensorrt_llm_config["quantization"]:
                    # Append router and lm_head because we need both excluded
                    tensorrt_llm_config["quantization"]["exclude_modules"] = [
                        "router", "lm_head"
                    ]
                else:
                    tensorrt_llm_config["quantization"]["exclude_modules"].append("router")
            with open(f"{export_path}/config.json", "w") as f:
                json.dump(tensorrt_llm_config, f, indent=4)

            # Workaround for Modelopt 0.9.x fp8_kv_cache knob issue
            if qformat == 'fp8' and kv_cache_dtype is None:
                with open(f"{export_path}/config.json", "r") as f:
                    tensorrt_llm_config = json.load(f)
                tensorrt_llm_config["quantization"]["kv_cache_quant_algo"] = None
                with open(f"{export_path}/config.json", "w") as f:
                    json.dump(tensorrt_llm_config, f, indent=4)

            # Workaround for share_embedding_table
            if pp_size == 1:
                with safetensors.safe_open(f"{export_path}/rank0.safetensors",
                                           framework='pt',
                                           device='cpu') as f:
                    share_embedding_table = 'lm_head.weight' not in f.keys()
                if share_embedding_table:
                    with open(f"{export_path}/config.json", "r") as f:
                        tensorrt_llm_config = json.load(f)
                    tensorrt_llm_config["share_embedding_table"] = True
                    with open(f"{export_path}/config.json", "w") as f:
                        json.dump(tensorrt_llm_config, f, indent=4)

            # Workaround for gpt2 position embedding
            if model_type == 'gpt2':
                for rank in range(tp_size):
                    weights = {}
                    with safetensors.safe_open(f"{export_path}/rank{rank}.safetensors",
                                               framework='pt',
                                               device='cpu') as f:
                        for key in f.keys():
                            weights[key] = f.get_tensor(key)
                    if 'transformer.positional_embedding.weight' in weights:
                        weights['transformer.position_embedding.weight'] = weights.pop(
                            'transformer.positional_embedding.weight')
                    safetensors.torch.save_file(
                        weights, f"{export_path}/rank{rank}.safetensors")

            # Workaround for qwen version
            if model_type == 'qwen':
                with open(f"{export_path}/config.json", "r") as f:
                    tensorrt_llm_config = json.load(f)
                qwen_config = AutoConfig.from_pretrained(self.model_path,
                                                         trust_remote_code=True)
                tensorrt_llm_config["qwen_type"] = qwen_config.model_type
                tensorrt_llm_config["intermediate_size"] = qwen_config.intermediate_size
                with open(f"{export_path}/config.json", "w") as f:
                    json.dump(tensorrt_llm_config, f, indent=4)

            # Otherwise torch keeps holding GPU memory, and later routines such as
            # engine build have less free GPU memory to use.
            torch.cuda.empty_cache()

            end_time = time.time()
            print("Quantized model exported to {} \nTotal time used {:.2f} s.".format(
                export_path, end_time - start_time))
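
# Usage sketch: one way this quanter could be driven end to end, assuming the
# calibration examples are pre-tokenized dicts carrying an 'input_ids' tensor.
# The paths and the 'model_type' config key below are illustrative placeholders,
# and in practice the BaseQuanter entry point is expected to orchestrate these calls.
#
#   from transformers import AutoTokenizer
#
#   quantize_config = {"kv_cache_dtype": "fp8", "model_type": "qwen"}
#   quanter = Fp8Quanter(quantize_config,
#                        model_path="/path/to/hf_model",
#                        offload_folder="/tmp/offload")
#
#   tokenizer = AutoTokenizer.from_pretrained("/path/to/hf_model", trust_remote_code=True)
#   examples = [tokenizer(text, return_tensors="pt")
#               for text in ["calibration sample one", "calibration sample two"]]
#
#   quanter._quant(examples)             # modelopt max calibration + FP8 quantization
#   quanter._save_quantized("/out/ckpt") # exports a TensorRT-LLM checkpoint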