transformers/llm/export/llmexport.py (963 lines of code) (raw):

import os import json import glob import base64 import warnings import argparse warnings.filterwarnings("ignore") os.environ['TOKENIZERS_PARALLELISM'] = 'false' os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python' import onnx import torch from typing import Optional, List from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer from utils.spinner import spinner_run from utils.custom_op import FakeLinear from utils.onnx_rebuilder import OnnxRebuilder from utils.mnn_converter import MNNConveter from utils.awq_quantizer import AwqQuantizer from utils.model_mapper import ModelMapper from utils.transformers import Embedding, Rotary, Decoder, Lm class LlmExporter(torch.nn.Module): ''' Base class for all llm model export. Inherits from [`torch.nn.Module`]. ''' def __init__(self, args): super().__init__() self.init_from_args(args) self.load_model(args.path) def init_from_args(self, args): self.visual = None self.audio = None self.talker = None self.args = args self.max_length = 1024 self.stop_ids = [] self.dst_name = 'llm' # load config from args self.onnx_path = os.path.join(self.args.dst_path, 'onnx') if self.args.tokenizer_path is None: self.args.tokenizer_path = self.args.path if args.lm_quant_bit is None: self.args.lm_quant_bit = self.args.quant_bit # init export dst dir if not os.path.exists(self.args.dst_path): os.makedirs(self.args.dst_path) if not os.path.exists(self.onnx_path): os.makedirs(self.onnx_path) def load_pretrained(self, model_path: str): self.tokenizer = AutoTokenizer.from_pretrained(self.args.tokenizer_path, trust_remote_code=True, use_fast=False) if 'Qwen2.5-Omni' in model_path: from transformers import Qwen2_5OmniForConditionalGeneration self.model = Qwen2_5OmniForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto").eval() elif 'Qwen2.5-VL' in model_path or 'Qwen2___5-VL' in model_path: from transformers import Qwen2_5_VLForConditionalGeneration self.model = 
Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path, torch_dtype='auto').eval() elif 'Qwen2-VL' in model_path: from transformers import Qwen2VLForConditionalGeneration self.model = Qwen2VLForConditionalGeneration.from_pretrained(model_path, torch_dtype='auto').eval() elif 'Qwen2-Audio' in model_path: from transformers import Qwen2AudioForConditionalGeneration self.audio = Qwen2AudioForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto") self.model = self.audio.language_model elif 'Llama-3.2' in model_path and 'Vision' in model_path: from transformers import MllamaForConditionalGeneration self.model = MllamaForConditionalGeneration.from_pretrained(model_path, torch_dtype='auto').eval() elif 'Llama' in model_path or 'Yi' in model_path: from transformers import LlamaForCausalLM self.model = LlamaForCausalLM.from_pretrained(model_path, torch_dtype='auto', trust_remote_code=True).eval() elif 'InternVL' in model_path: self.model = AutoModel.from_pretrained(model_path, torch_dtype=torch.float32, trust_remote_code=True).eval() else: try: self.model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype='auto', trust_remote_code=True).eval() except: self.model = AutoModel.from_pretrained(model_path, torch_dtype='auto', trust_remote_code=True).eval() self.config = self.model.config if self.args.lora_path is not None and not self.args.lora_split: from peft import PeftModel adapter = PeftModel.from_pretrained(self.model, model_id=self.args.lora_path) self.model = adapter.merge_and_unload(progressbar=True) @staticmethod def has_attr(obj, attr): return hasattr(obj, attr) and getattr(obj, attr) is not None @spinner_run(f'load pretrained model ', True) def load_model(self, model_path): self.load_pretrained(model_path) self.attention_mask_type = 'float' # load tokenizer info self.stop_ids.append(self.tokenizer.eos_token_id) if hasattr(self.tokenizer, 'im_end_id'): self.stop_ids.append(self.tokenizer.im_end_id) try: eot_id = 
self.tokenizer.encode('<|eot_id|>') if len(eot_id) == 1: self.stop_ids.append(eot_id[0]) # gemma/gemma-2 eot_id = self.tokenizer.encode('<end_of_turn>') if len(eot_id) == 2 and eot_id[0] == 2: self.stop_ids.append(eot_id[1]) except: pass if hasattr(self.model, 'generation_config') and self.model.generation_config is not None: eos_token_id = self.model.generation_config.eos_token_id from collections.abc import Iterable if isinstance(eos_token_id, int): self.stop_ids.append(eos_token_id) elif isinstance(eos_token_id, Iterable): for id in eos_token_id: self.stop_ids.append(id) self.stop_ids = [stop_id for stop_id in self.stop_ids if stop_id is not None] self.stop_ids = list(set(self.stop_ids)) model_mapper = ModelMapper() self.tie_word_embeddings = self.args.tie_embed and (hasattr(self.config, 'tie_word_embeddings') and self.config.tie_word_embeddings) self.model_type, self.model_map = model_mapper.get_map(self.config) if self.args.awq: self.model.float() if self.args.export is not None: # set norm's weight as float for export def visit_module(module): if not isinstance(module, torch.nn.Linear) and hasattr(module, 'weight'): module.float() for name, child in module.named_children(): visit_module(child) visit_module(self.model) # print(self.model_type, self.model_map) # print(self.config, self.model_type, self.model_map, self.model) # print(self.model.model.layers[0].input_layernorm.weight); exit(0) # load config info ModelMapper.do_map(self, self.config, self.model_map['config']) if not hasattr(self, 'num_key_value_heads') or self.num_key_value_heads is None: self.num_key_value_heads = self.num_attention_heads if not hasattr(self, 'rope_theta') or self.rope_theta is None: self.rope_theta = 10000.0 if not hasattr(self, 'rope_ratio') or self.rope_ratio is None: self.rope_ratio = 1.0 if not hasattr(self, 'head_dim') or self.head_dim is None: if isinstance(self.num_attention_heads, list): self.head_dim = [self.hidden_size // atten_head for atten_head in 
self.num_attention_heads] else: self.head_dim = self.hidden_size // self.num_attention_heads # some export info if isinstance(self.num_attention_heads, list): self.past_kv_shape = [self.num_hidden_layers, 2, 1, 0, self.num_key_value_heads[0], self.head_dim] else: self.past_kv_shape = [self.num_hidden_layers, 2, 1, 0, self.num_key_value_heads, self.head_dim] self.block_dynamic_axes = { "inputs_embeds" : { 0: "seq_len" }, "attention_mask" : { 2: "seq_len", 3: "seq_len" }, "position_ids" : { 0: "seq_len" }, "past_key_values" : { 1: "history_len" } } self.model_dynamic_axes = { "input_ids" : { 0: "seq_len" }, "attention_mask" : { 2: "seq_len", 3: "seq_len" }, "position_ids" : { 1: "seq_len" }, "past_key_values" : { 3: "history_len" } } prompt_template = self.build_prompt_template() self.llm_config = { 'hidden_size' : self.hidden_size, 'layer_nums' : self.num_hidden_layers, 'attention_mask': self.attention_mask_type, 'key_value_shape': self.past_kv_shape[1:], "bos": prompt_template['bos'], "system_prompt_template": prompt_template['system'].format(content='%s'), 'user_prompt_template': prompt_template['user'].format(content='%s'), 'assistant_prompt_template': prompt_template['assistant'].format(content='%s'), 'is_visual': False } # load modules ModelMapper.do_map(self, self.model, self.model_map['model']) # rebuild modules if self.lm_ is None: out_features, in_features = self.embed_.weight.shape self.lm_ = torch.nn.Linear(in_features, out_features) self.lm_.weight = self.embed_.weight elif not isinstance(self.lm_, torch.nn.Linear): # for Baichuan2 weight = self.lm_.weight out_features, in_features = weight.shape self.lm_ = torch.nn.Linear(in_features, out_features) self.lm_.weight = weight self.lm_.bias.data = torch.zeros(out_features, dtype=weight.dtype) if self.embed_.weight is self.lm_.weight: import copy embed_copy = copy.deepcopy(self.embed_) self.embed = Embedding(embed_copy, self) else: self.embed = Embedding(self.embed_, self) # Rotary self.rotary = Rotary(self) 
self.blocks = [] for block in self.blocks_.children(): layer_id = len(self.blocks) self.blocks.append(Decoder(block, layer_id, self)) self.lm = Lm(self.lm_, self.final_layernorm_, self) # visual model if self.visual is not None: if self.args.export is not None: self.visual.float() from utils.vision import Vision self.visual = Vision.get_vision(self.model_type)(self.visual, self) if hasattr(self, 'audio') and self.audio is not None: from utils.audio import Audio self.audio = Audio.get_audio(self.audio.config.model_type)(self.audio, self) else: self.audio = None # talker model if hasattr(self, 'talker') and self.talker is not None and \ hasattr(self, 'token2wav') and self.token2wav is not None: from utils.talker import Talker self.talker = Talker.get_talker(self.model_type)(self.talker, self.token2wav, self) return model_path def get_attention_mask(self) -> torch.Tensor: if self.model_type == 'chatglm': return self.chatglm_attention_mask() if self.token_len: return torch.zeros([1, 1, 1, self.seq_len], dtype=torch.float32) return (1 - torch.tril(torch.ones([1, 1, self.seq_len, self.seq_len]))) * torch.finfo(torch.float32).min def get_position_ids(self, input_ids = None) -> torch.Tensor: if self.visual is not None and hasattr(self.visual, 'get_position_ids') and callable(getattr(self.visual, 'get_position_ids')): return self.visual.get_position_ids(input_ids, self.seq_len, self.token_len) if self.model_type == 'chatglm': return self.chatglm_position_ids() if self.token_len: return torch.tensor([[self.seq_len - 1]], dtype=torch.int) return torch.arange(self.seq_len, dtype=torch.int).unsqueeze(0) def chatglm_attention_mask(self): if self.token_len: return torch.zeros([1]).bool().reshape([1, 1, 1, 1]) attention_mask = torch.zeros([self.seq_len, self.seq_len], dtype=torch.bool) for i in range(self.seq_len - 1): attention_mask[i][-1] = True attention_mask = attention_mask.reshape([1, 1, self.seq_len, self.seq_len]) return attention_mask def chatglm_position_ids(self): if 
self.token_len: return torch.tensor([self.context_len, self.token_len + 1]).reshape([1, 2, 1]) position_ids_0 = torch.arange(self.seq_len, dtype=torch.int) position_ids_1 = torch.zeros(self.seq_len, dtype=torch.int) position_ids_0[-1] = position_ids_0[-2] position_ids_1[-1] = 1 position_ids = torch.stack([position_ids_0, position_ids_1]).view(1, 2, -1) return position_ids def visual_embed(self, input_ids): return self.visual.embed(input_ids) def audio_embed(self, input_ids): return self.audio.embed(input_ids) def embedding(self, input_ids): if self.visual is not None and self.token_len == 0: input_embeds = self.visual_embed(input_ids) elif self.audio is not None and self.token_len == 0: input_embeds = self.audio_embed(input_ids) else: input_embeds = self.embed(input_ids) return input_embeds def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, position_ids: torch.Tensor, past_key_values: Optional[List[torch.Tensor]] = None, logits_index: int = -1, cross_attention_states: Optional[torch.Tensor] = None, cross_attention_mask: Optional[torch.Tensor] = None, ): hidden_states = input_ids # llm forward without embedding if self.model_type == 'gemma': normalizer = torch.tensor(self.hidden_size**0.5, dtype=hidden_states.dtype) hidden_states = hidden_states * normalizer presents = [None for i in range(self.num_hidden_layers)] rotary_pos_emb = self.rotary(position_ids) if self.args.test and rotary_pos_emb.dtype != hidden_states.dtype: rotary_pos_emb = rotary_pos_emb.type(hidden_states.dtype) for i in range(self.num_hidden_layers): if self.blocks[i].cross_decoder and cross_attention_states is None: continue hidden_states, kv = self.blocks[i](hidden_states, rotary_pos_emb, attention_mask, past_key_values[i]) presents[i] = kv talker_embeds = self.final_layernorm_(hidden_states) + input_ids.permute([1, 0, 2]) if hasattr(self, 'talker') and self.talker is not None: self.talker.add_talker_embeds(talker_embeds) logits = self.lm(hidden_states, logits_index) if 
presents[0].shape == presents[-1].shape and None not in presents: presents = torch.stack(presents) self.seq_len += 1 self.token_len += 1 return logits, presents, talker_embeds # some test functions def build_prompt_template(self): template = { 'bos': '', 'system': '{content}', 'user': '{content}', 'assistant': '{content}', } if self.model_type == 'baichuan': template['user'] = '<reserved_106>{content}' template['assistant'] = '<reserved_107>{content}' if self.model_type == 'chatglm': template['user'] = '{content}[gMASK]<sop>' if self.model_type == 'chatglm2' and 'codegeex' not in self.args.path: template['user'] = '[Round 1]\n\n问:{content}\n\n' template['assistant'] = '答:{content}\n\n' if 'chatglm3' in self.args.path or 'glm-4' in self.args.path: template['bos'] = '[gMASK]<sop>' template['system'] = '<|system|>\n{content}\n' template['user'] = '<|user|>\n{content}\n' template['assistant'] = '<|assistant|>\n{content}\n' if self.model_type == 'llama': if 'deepseek' in self.args.path: template['bos'] = '<|begin_of_sentence|>' template['system'] = '{content}\n' template['user'] = '\nUser: {content}\n' template['assistant'] = '\nAssistant: {content}\n<|end_of_sentence|>' if 'Llama-2' in self.args.path: template['bos'] = '[INST] ' template['system'] = "<<SYS>>\n{content}\n<</SYS>>\n\n" template['user'] = '{content} [/INST]' template['assistant'] = "{content}</s>"; if 'Llama-3' in self.args.path: template['system'] = '<|start_header_id|>system<|end_header_id|>\n\n{content}<|eot_id|>' template['user'] = '<|start_header_id|>user<|end_header_id|>\n\n{content}<|eot_id|>' template['assistant'] = '<|start_header_id|>assistant<|end_header_id|>\n\n{content}<|eot_id|>' if 'TinyLlama' in self.args.path: template['bos'] = '<s>' template['system'] = '<|system|>\n{content}</s>\n' template['user'] = '<|user|>\n{content}</s>\n' template['assistant'] = '<|assistant|>\n{content}</s>\n' if 'Yi' in self.args.path or 'SmolLM2' in self.args.path: template['system'] = 
'<|im_start|>system\n{content}<|im_end|>\n' template['user'] = '<|im_start|>user\n{content}<|im_end|>\n' template['assistant'] = '<|im_start|>assistant\n{content}<|im_end|>\n' if self.model_type == 'gemma2': template['bos'] = '<bos>' template['system'] = '<start_of_turn>system\n{content}<end_of_turn>\n' template['user'] = '<start_of_turn>user\n{content}<end_of_turn>\n' template['assistant'] = '<start_of_turn>model\n{content}<end_of_turn>\n' if self.model_type == 'gemma': template['bos'] = '<bos>' if self.model_type == 'internlm': template['user'] = '<|User|>:{content}<eoh>\n' template['assistant'] = '<|Bot|>:{content}<eoh>\n' if self.model_type == 'phi-msft': template['user'] = 'Instruct: {content}\n' template['assistant'] = 'Output:{content}\n' if self.model_type == 'openelm': template['bos'] = '<s>' if self.model_type == 'internvl_chat': if 'Qwen' in self.config.llm_config._name_or_path: print("[DEBUG] Use qwen prompt template") template['system'] = '<|im_start|>system\n{content}<|im_end|>\n' template['user'] = '<|im_start|>user\n{content}<|im_end|>\n' template['assistant'] = '<|im_start|>assistant\n{content}<|im_end|>\n' if 'qwen' in self.model_type: template['system'] = '<|im_start|>system\n{content}<|im_end|>\n' template['user'] = '<|im_start|>user\n{content}<|im_end|>\n' template['assistant'] = '<|im_start|>assistant\n{content}<|im_end|>\n' if 'DeepSeek' in self.args.path: template['bos'] = '<|begin_of_sentence|>' template['user'] = '<|User|>{content}' template['assistant'] = '<|Assistant|>{content}<|end_of_sentence|>' return template def build_prompt(self, messages): template = self.build_prompt_template() prompt = template['bos'] for item in messages: role, content = item['role'], item['content'] if '{content}' in template[role]: prompt += template[role].format(content=content) else: prompt += role + '\n' + content +'\n' assistant_prefix = template['assistant'].split('{content}')[0] return prompt + assistant_prefix def str_to_ids(self, prompt): if 
self.visual is not None: return self.visual.str_to_ids(prompt) if self.audio is not None: return self.audio.str_to_ids(prompt) input_ids = self.tokenizer(prompt, return_tensors="pt")['input_ids'] return input_ids def id_to_str(self, token_id): try: word = self.tokenizer.decode(int(token_id)) except: def contains_replacement(text): return '\uFFFD' in text def decode_id(token_id): return self.tokenizer.convert_tokens_to_string( self.tokenizer._convert_id_to_token(int(token_id))) def decode_ids(token_ids): return self.tokenizer.convert_tokens_to_string( self.tokenizer.convert_ids_to_tokens(token_ids)) word = decode_id(int(token_id)) # Smollm tokenizer will produce half chinese character, using buffer to decode if contains_replacement(word): self.decode_buffer.append(token_id) buffer_txt = decode_ids(self.decode_buffer) if not contains_replacement(buffer_txt): word = buffer_txt self.decode_buffer.clear() else: word = '' return word @torch.no_grad() def response(self, query): # self.imitate_quant() self.decode_buffer = [] messages = [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": query} ] prompt = self.build_prompt(messages) input_ids = self.str_to_ids(prompt) if self.visual is not None: cross_attention_states = self.visual.cross_attention_states cross_attention_mask = self.visual.cross_attention_mask else: cross_attention_states = None cross_attention_mask = None self.seq_len = input_ids.numel() self.context_len = self.seq_len - 2 self.token_len = 0 past_key_values = [None for i in range(self.num_hidden_layers)] token_id = input_ids while self.token_len < self.max_length: attention_mask = self.get_attention_mask() position_ids = self.get_position_ids(token_id) input_ids = self.embedding(token_id) logits, past_key_values, _ = self.forward(input_ids, attention_mask, position_ids, past_key_values, cross_attention_states, cross_attention_mask) token_id = torch.argmax(logits[:,-1,:]) if token_id in self.stop_ids: print("", 
end='\n') break word = self.id_to_str(token_id) print(word, end="", flush=True) if hasattr(self, 'talker') and self.talker is not None: self.talker.generate() @spinner_run(f'export embedding to ') def export_embed(self): import ctypes if hasattr(self, 'word_embeddings'): # embedding model's embed tensor_data = self.word_embeddings.weight.data.bfloat16() else: tensor_data = self.embed.embed.weight.data.bfloat16() data_ptr = tensor_data.untyped_storage().data_ptr() buffer = (ctypes.c_byte * (tensor_data.numel() * 2)).from_address(data_ptr) embedding_file = f'{self.args.dst_path}/embeddings_bf16.bin' with open(embedding_file, 'wb') as f: f.write(buffer) return embedding_file @spinner_run(f'export config to ') def export_config(self, mnn_config = False): config_json = f'{self.args.dst_path}/llm_config.json' with open(config_json, 'w', encoding='utf-8') as f: json.dump(self.llm_config, f, ensure_ascii=False, indent=4) if not mnn_config: return config_json with open(f'{self.args.dst_path}/config.json', 'w', encoding='utf-8') as f: config = { "llm_model": f"{self.dst_name}.mnn", "llm_weight": f"{self.dst_name}.mnn.weight", "backend_type": "cpu", "thread_num": 4, "precision": "low", "memory": "low", "system_prompt": "You are a helpful assistant.", } if self.talker is not None: config['system_prompt'] = "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech." 
config['talker_max_new_tokens'] = 2048 config['talker_speaker'] = "Chelsie" config['dit_steps'] = 5 config['dit_solver'] = 1 if self.visual is not None or self.audio is not None: config['mllm'] = { 'backend_type': "cpu", "thread_num": 4, "precision": "normal", "memory": "low" } json.dump(config, f, ensure_ascii=False, indent=4) return config_json def imitate_quant(self): def quant_dequant(linear, quant_bit = self.args.quant_bit, quant_block = self.args.quant_block): weight = linear.weight.data oc, ic = weight.shape if quant_block == 0: block_size = ic else: block_size = quant_block block_num = ic // block_size weight = weight.reshape(oc, block_num, block_size) max_val = torch.max(weight, axis=-1, keepdims=True).values min_val = torch.min(weight, axis=-1, keepdims=True).values offset = 1 << (quant_bit - 1) clip_max = offset - 1 clip_min = -offset scale = (max_val - min_val) / (clip_max - clip_min) q_weight = torch.round((weight - min_val) / scale) + clip_min q_weight = torch.clip(q_weight, clip_min, clip_max) dq_weight = (q_weight - clip_min) * scale + min_val dq_weight = dq_weight.reshape(oc, ic).float() linear.weight.data = dq_weight return linear with torch.no_grad(): for i in range(self.num_hidden_layers): for name, child in self.blocks[i].self_attn.named_children(): if isinstance(child, torch.nn.Linear): setattr(self.blocks[i].self_attn, name, quant_dequant(child)) for name, child in self.blocks[i].mlp.named_children(): if isinstance(child, torch.nn.Linear): setattr(self.blocks[i].mlp, name, quant_dequant(child)) self.lm.lm = quant_dequant(self.lm.lm) def unload_param(self): self.unloaded_ops = {} def build_faker(real, name): faker = FakeLinear(real.in_features, real.out_features, real.bias is not None, name) self.unloaded_ops[name] = real return faker # replace linear with fakelinear to save export memory and time with torch.no_grad(): for i in range(self.num_hidden_layers): # different kv cache shape in different layers if isinstance(self.num_attention_heads, 
list): self.blocks[i].self_attn.export_fused_attn = True for name, child in self.blocks[i].self_attn.named_children(): if isinstance(child, torch.nn.Linear): setattr(self.blocks[i].self_attn, name, build_faker(child, f'/layers.{i}/self_attn/{name}/Linear')) for name, child in self.blocks[i].mlp.named_children(): if isinstance(child, torch.nn.Linear): setattr(self.blocks[i].mlp, name, build_faker(child, f'/layers.{i}/mlp/{name}/Linear')) self.lm.lm = build_faker(self.lm.lm, f'/lm/lm_head/Linear') @spinner_run(f'export model weight to ') def onnx_load_param(self, onnx_path): return OnnxRebuilder(onnx_path, self.unloaded_ops).rebuild() @spinner_run(f'slim the graph of ') def slim_onnx(self, onnx_model): import onnxslim model = onnxslim.slim(onnx_model) onnx.save(model, onnx_model) return onnx_model @spinner_run(f'export onnx model to ') def export_onnx(self): # unload linear weight to save export memory self.unload_param() model = self self.seq_len = 3 self.token_len = 0 input_ids = torch.arange(3, dtype=torch.long) attention_mask = self.get_attention_mask() position_ids = self.get_position_ids(input_ids) onnx_model = f'{self.onnx_path}/{self.dst_name}.onnx' # For export onnx, don't need image or audio's embedding input_ids = self.embed(input_ids) past_key_values = torch.zeros(self.past_kv_shape) logits_index = torch.tensor([-1], dtype=torch.int32) if hasattr(self, 'talker') and self.talker is not None: output_names = ['logits', 'presents', 'talker_embeds'] else: output_names = ['logits', 'presents'] # export to onnx torch.onnx.export( model, (input_ids, attention_mask, position_ids, past_key_values, logits_index), onnx_model, input_names=[ 'input_ids', 'attention_mask', 'position_ids', 'past_key_values', 'logits_index' ], output_names=output_names, dynamic_axes=self.model_dynamic_axes, do_constant_folding=True, verbose=False, opset_version=15) return onnx_model def awq_quant(self): self.awq_quantizer = AwqQuantizer(self) self.awq_quantizer.quantize() 
self.is_awq_quantized = True def export_vision(self): if self.visual is None: return vision_onnx = self.visual.export(self.onnx_path) if self.mnn_converter: self.mnn_converter.export(vision_onnx, self.visual.quant_bit) def export_audio(self): if self.audio is None: return audio_onnx = self.audio.export(self.onnx_path) if self.mnn_converter: self.mnn_converter.export(audio_onnx, self.audio.quant_bit) def export_talker(self): if self.talker is None: return talker_onnx = self.talker.export(self.onnx_path) predit_onnx, dit_onnx, bigvgan_onnx = self.talker.token2wav.export(self.onnx_path) if self.mnn_converter: self.mnn_converter.export(talker_onnx, self.talker.quant_bit) self.mnn_converter.export(predit_onnx, self.talker.token2wav.quant_bit) self.mnn_converter.export(dit_onnx, self.talker.token2wav.quant_bit) self.mnn_converter.export(bigvgan_onnx, self.talker.token2wav.quant_bit) def export_language(self): # export_embedding if self.mnn_converter and self.tie_word_embeddings: pass # mnn tie_word_embeddings need't export embedding else: self.export_embed() # export transformer onnx_model = self.export_onnx() if self.args.onnx_slim: self.slim_onnx(onnx_model) if self.mnn_converter: MNNConveter(self, self.unloaded_ops).export(onnx_model) else: self.onnx_load_param(onnx_model) def export(self, export_type): if self.args.awq: self.awq_quant() export_mnn = export_type == 'mnn' self.mnn_converter = MNNConveter(self) if export_mnn else None self.export_talker() self.export_vision() self.export_audio() self.export_language() self.export_tokenizer() self.export_config(export_mnn) if export_mnn: # delete onnx file try: for file in glob.glob(f'{self.onnx_path}/*'): os.remove(file) os.rmdir(self.onnx_path) except Exception as e: print(f"remove onnx error: {e}") @spinner_run(f'export tokenizer to ') def export_tokenizer(self): # load tokenizer file tokenizer_model = os.path.join(self.args.tokenizer_path, 'tokenizer.model') ice_text_model = os.path.join(self.args.tokenizer_path, 
'ice_text.model') try: import sentencepiece as spm if os.path.exists(tokenizer_model): self.sp_model = spm.SentencePieceProcessor(tokenizer_model) elif os.path.exists(ice_text_model): self.sp_model = spm.SentencePieceProcessor(ice_text_model) else: self.sp_model = None except: self.sp_model = None merge_file = os.path.join(self.args.path, 'merges.txt') if os.path.exists(merge_file): self.merge_txt = merge_file else: self.merge_txt = None # TOKENIZER MAGIC NUMBER MAGIC_NUMBER = 430 # TOKENIZER TYPE SENTENCEPIECE = 0; TIKTOIKEN = 1; BERT = 2; HUGGINGFACE = 3 def write_line(fp, *args): for arg in args: for token in arg: fp.write(str(token) + ' ') fp.write('\n') def write_header(fp, type, speicals, prefix = []): fp.write(f'{MAGIC_NUMBER} {type}\n') fp.write(f'{len(speicals)} {len(self.stop_ids)} {len(prefix)}\n') write_line(fp, speicals, self.stop_ids, prefix) file_path = os.path.join(self.args.dst_path, "tokenizer.txt") special_list = list(self.tokenizer.added_tokens_decoder.keys()) if hasattr(self.tokenizer, 'special_tokens'): for k, v in self.tokenizer.special_tokens.items(): special_list.append(v) if hasattr(self.tokenizer, 'gmask_token_id'): special_list.append(self.tokenizer.gmask_token_id) if hasattr(self.model, 'generation_config') and self.model.generation_config is not None: generation_config = self.model.generation_config if hasattr(generation_config, 'user_token_id'): special_list.append(generation_config.user_token_id) if hasattr(generation_config, 'assistant_token_id'): special_list.append(generation_config.assistant_token_id) vocab_list = [] prefix_list = [] if hasattr(self.tokenizer, 'get_prefix_tokens'): prefix_list = self.tokenizer.get_prefix_tokens() if len(prefix_list) == 0: try: test_txt = 'A' ids = self.tokenizer.encode(test_txt) get_txt = self.tokenizer.decode(ids[-1]) if len(ids) > 1 and get_txt == test_txt: prefix_list += ids[:-1] except: pass if self.sp_model is not None: # senetencepiece NORMAL = 1; UNKNOWN = 2; CONTROL = 3 USER_DEFINED = 4; 
UNUSED = 5; BYTE = 6 for i in range(self.sp_model.GetPieceSize()): token = self.sp_model.IdToPiece(i) score = self.sp_model.GetScore(i) token_type = NORMAL if self.sp_model.IsUnknown(i): token_type = UNKNOWN elif self.sp_model.IsControl(i): token_type = CONTROL elif self.sp_model.IsUnused(i): token_type = UNUSED elif self.sp_model.IsByte(i): token_type = BYTE if self.args.path == 'Chatglm_6b': if '<n>' in token: token = '\n' if '<|tab|>' in token: token = '\t' if '<|blank_' in token: token = ' ' * int(token[8:token.find('|>')]) if '▁' in token: token = token.replace('▁', ' ') token_encode = base64.b64encode(token.encode("utf-8")).decode("utf8") vocab_list.append(f'{token_encode} {score} {token_type}\n') with open(file_path, "w", encoding="utf8") as fp: write_header(fp, SENTENCEPIECE, special_list, prefix_list) fp.write(f'{len(vocab_list)}\n') for vocab in vocab_list: fp.write(vocab) elif hasattr(self.tokenizer, 'mergeable_ranks'): # tikton vocab_list = [] for k, v in self.tokenizer.mergeable_ranks.items(): line = base64.b64encode(k).decode("utf8") + "\n" vocab_list.append(line) if hasattr(self.tokenizer, 'special_tokens'): for k, v in self.tokenizer.special_tokens.items(): line = base64.b64encode(k.encode("utf-8")).decode("utf8") + "\n" vocab_list.append(line) if hasattr(self.tokenizer, 'added_tokens_decoder'): for k, v in self.tokenizer.added_tokens_decoder.items(): line = base64.b64encode(v.__str__().encode("utf-8")).decode("utf8") + "\n" vocab_list.append(line) with open(file_path, "w", encoding="utf8") as fp: write_header(fp, TIKTOIKEN, special_list, prefix_list) fp.write(f'{len(vocab_list)}\n') for vocab in vocab_list: fp.write(vocab) elif self.merge_txt is not None: # huggingface tokenizer merge_list = [] vocab = self.tokenizer.get_vocab() special_list = list(self.tokenizer.added_tokens_decoder.keys()) vocab_list = ['<unk>' for i in range(len(vocab))] # load vocab for k, v in vocab.items(): vocab_list[int(v)] = k # load merge with open(self.merge_txt, 'rt') 
class EmbeddingExporter(LlmExporter):
    '''
    Exporter for sentence-embedding encoders.

    Two encoder flavours are handled, selected by `config.model_type`:
    `'bert'` (learned absolute position embeddings, see `bge_forward`) and
    `'new'` (rotary position embeddings, see `gte_forward`).
    '''

    def __init__(self, args):
        super().__init__(args)
        # Set after the base initialiser so the exported files are named
        # `embedding.*`.
        self.dst_name = 'embedding'

    def word_embed(self, input_ids):
        '''Look up word embeddings for `input_ids` flattened to a single row.'''
        return self.word_embeddings(input_ids.view(1, -1))

    def bge_forward(self, inputs_embeds, position_ids, attention_mask):
        '''Encode with learned absolute position embeddings (`'bert'` models).

        Returns the L2-normalised hidden state taken at index 0 of dimension 1.
        '''
        # bert absolute position
        inputs_embeds = inputs_embeds.reshape(1, -1, self.hidden_size)
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = inputs_embeds + position_embeddings + self.token_type_embeddings
        hidden_states = self.embedding_layernorm(embeddings)
        for block in self.blocks:
            hidden_states = block(hidden_states, attention_mask)[0]
        # Index 0 along dim 1 -- presumably the leading [CLS] token; this
        # assumes the blocks keep the (1, seq, hidden) layout.
        sentence_embeddings = hidden_states[:, 0]
        return torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)

    def gte_forward(self, inputs_embeds, position_ids, attention_mask):
        '''Encode with rotary position embeddings (`'new'` models).

        Returns the L2-normalised hidden state taken at index 0 of dimension 1.
        '''
        # rope position
        inputs_embeds = inputs_embeds.reshape(1, -1, self.hidden_size)
        freqs = position_ids.float().reshape(-1, 1) * self.inv_freq
        emb = torch.cat((freqs, freqs), dim=-1)
        # cos and sin tables stacked on a new leading axis, then two
        # singleton axes inserted for the layers to broadcast against.
        rope_embeds = torch.stack([emb.cos(), emb.sin()]).unsqueeze(-2).unsqueeze(1)
        # The layers receive an additive bias rather than a 0/1 mask.
        attention_bias = 1 - attention_mask.float()
        hidden_states = self.embedding_layernorm(inputs_embeds + self.token_type_embeddings)
        for block in self.blocks:
            hidden_states = block(hidden_states, attention_bias, rope_embeds)[0]
        sentence_embeddings = hidden_states[:, 0]
        return torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)

    def forward(self, inputs_embeds, position_ids, attention_mask):
        '''Dispatch to the forward pass matching `self.model_type`.

        Raises:
            RuntimeError: if the model type is neither `'bert'` nor `'new'`.
        '''
        if self.model_type == 'bert':
            return self.bge_forward(inputs_embeds, position_ids, attention_mask)
        if self.model_type == 'new':
            return self.gte_forward(inputs_embeds, position_ids, attention_mask)
        raise RuntimeError(f'Not support embedding model: {self.model_type}!')

    def response(self, query):
        '''Tokenize `query`, run the encoder and return the sentence embedding.'''
        self.eval()
        input_ids = self.tokenizer(query)['input_ids']
        # get_position_ids() / get_attention_mask() read self.seq_len.
        self.seq_len = len(input_ids)
        input_ids = torch.tensor(input_ids)
        position_ids = self.get_position_ids()
        attention_mask = self.get_attention_mask()
        inputs_embeds = self.word_embed(input_ids)
        return self.forward(inputs_embeds, position_ids, attention_mask)

    @spinner_run('load pretrained model ')
    def load_model(self, model_path):
        '''Load tokenizer, config and encoder from `model_path`.

        Caches the sub-modules used by the forward passes and prepares the
        export configuration. Returns `model_path`.
        '''
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        self.config = AutoConfig.from_pretrained(model_path)
        self.config._attn_implementation = 'eager'
        # `AutoModel.from_config` only instantiates the architecture with
        # freshly initialised weights; the checkpoint has to be loaded with
        # `from_pretrained`, otherwise the exported model is random.
        self.model = AutoModel.from_pretrained(model_path, config=self.config)
        transformer = self.model.encoder
        self.model_type = self.config.model_type
        self.lm_ = self.model.pooler
        self.embed_ = self.model.embeddings
        self.word_embeddings = self.embed_.word_embeddings
        # Row 0 of the token-type table is added to every position.
        self.token_type_embeddings = self.embed_.token_type_embeddings.weight.data[0]
        self.embedding_layernorm = self.embed_.LayerNorm
        if hasattr(self.embed_, 'position_embeddings'):
            self.position_embeddings = self.embed_.position_embeddings
        self.hidden_size = self.word_embeddings.weight.shape[-1]
        self.blocks = transformer.layer
        if self.model_type == 'new':
            self.inv_freq = self.embed_.rotary_emb.inv_freq
        # some wrapper
        self.stop_ids = []
        self.num_hidden_layers = len(self.blocks)
        self.embed = self.embed_
        self.lm = self.lm_
        # some config for export
        self.model_dynamic_axes = {
            "input_ids": {1: "seq_len"},
            "position_ids": {1: "seq_len"},
            "attention_mask": {3: "seq_len"},
        }
        self.attention_mask_type = 'int'
        self.llm_config = {
            'hidden_size': self.hidden_size,
            'layer_nums': self.num_hidden_layers,
            'attention_mask': self.attention_mask_type,
            'key_value_shape': [],
            "prompt_template": self.build_prompt('%s'),
            'is_visual': False,
        }
        return model_path

    @spinner_run('export onnx model to ')
    def export_onnx(self):
        '''Trace the encoder with a 3-token dummy input and write the ONNX file.

        Returns the path of the written ONNX model.
        '''
        model = self.eval()
        self.seq_len = 3
        input_ids = torch.arange(3, dtype=torch.long)
        position_ids = self.get_position_ids()
        attention_mask = self.get_attention_mask()
        inputs_embeds = self.word_embed(input_ids)
        onnx_model = f'{self.onnx_path}/{self.dst_name}.onnx'
        # NOTE: the first graph input is named 'input_ids' although the tensor
        # fed in is the embedding output; the name is kept unchanged because
        # downstream consumers may rely on it.
        torch.onnx.export(
            model,
            (inputs_embeds, position_ids, attention_mask),
            onnx_model,
            input_names=['input_ids', 'position_ids', 'attention_mask'],
            output_names=['sentence_embeddings'],
            dynamic_axes=self.model_dynamic_axes,
            do_constant_folding=True,
            opset_version=15)
        return onnx_model

    def export(self, export_type):
        '''Export tokenizer, config, embedding and the ONNX model.

        The ONNX model is additionally converted to MNN when `export_type`
        contains `'mnn'`.
        '''
        export_mnn = 'mnn' in export_type
        self.export_tokenizer()
        self.export_config(export_mnn)
        self.export_embed()
        onnx_model = self.export_onnx()
        if self.args.onnx_slim:
            self.slim_onnx(onnx_model)
        if export_mnn:
            MNNConveter(onnx_model, None, self).export()

    def build_prompt(self, content):
        '''Wrap `content` in the model's special tokens.

        Returns `None` for model types other than `'bert'` and `'new'`.
        '''
        if self.model_type == 'bert':
            return f'[CLS]{content}[SEP]'
        if self.model_type == 'new':
            return f'<s> {content}</s>'
        return None

    def get_position_ids(self) -> torch.Tensor:
        '''Return positions `0..seq_len-1` as a `(1, seq_len)` long tensor.'''
        return torch.arange(self.seq_len, dtype=torch.long).unsqueeze(0)

    def get_attention_mask(self) -> torch.Tensor:
        '''Return an all-ones `(1, 1, 1, seq_len)` long tensor.'''
        return torch.ones([1, 1, 1, self.seq_len], dtype=torch.long)
def export(path,
           type = None,
           tokenizer_path = None,
           lora_path = None,
           gptq_path = None,
           dst_path = './model',
           export = 'onnx',
           onnx_slim = False,
           quant_bit = 4,
           quant_block = 128,
           lm_quant_bit = None,
           mnnconvert = None,
           ppl = False,
           awq = False,
           sym = False,
           tie_embed = False,
           lora_split = False):
    '''
    Programmatic entry point: build an exporter for `path` and run it.

    The keyword arguments are packed into an `argparse.Namespace` under the
    same names and handed to the exporter.

    Args:
        path: model path; also decides which exporter class is used.
        type: model type forwarded to the exporter.
        tokenizer_path: tokenizer path forwarded to the exporter.
        lora_path: lora path forwarded to the exporter.
        gptq_path: gptq path forwarded to the exporter.
        dst_path: output directory.
        export: export target string passed to the exporter's `export()`.
        onnx_slim: whether to slim the exported ONNX model.
        quant_bit: quantization bit width.
        quant_block: quantization block size.
        lm_quant_bit: lm_head quantization bit width.
        mnnconvert: path of a local MNNConvert binary.
        ppl, awq, sym, tie_embed, lora_split: boolean export switches
            forwarded unchanged.
    '''
    args = argparse.Namespace(
        path=path,
        type=type,
        tokenizer_path=tokenizer_path,
        lora_path=lora_path,
        gptq_path=gptq_path,
        dst_path=dst_path,
        export=export,
        onnx_slim=onnx_slim,
        quant_bit=quant_bit,
        quant_block=quant_block,
        lm_quant_bit=lm_quant_bit,
        mnnconvert=mnnconvert,
        ppl=ppl,
        awq=awq,
        sym=sym,
        tie_embed=tie_embed,
        lora_split=lora_split,
    )
    # Embedding models are recognised by name; 'gte' is accepted as well as
    # 'bge' so that this matches the selection done by the CLI entry point.
    if 'gte' in path or 'bge' in path:
        llm_exporter = EmbeddingExporter(args)
    else:
        llm_exporter = LlmExporter(args)
    # export
    llm_exporter.export(export)
def main():
    '''Command-line entry point: parse arguments, build the exporter, then
    optionally run a test query and/or an export.'''
    parser = argparse.ArgumentParser(description='llm_exporter', formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('--path', type=str, required=True,
                        help='path(`str` or `os.PathLike`):\nCan be either:'
                        '\n\t- A string, the *model id* of a pretrained model like `THUDM/chatglm-6b`. [TODO]'
                        '\n\t- A path to a *directory* clone from repo like `../chatglm-6b`.')
    parser.add_argument('--type', type=str, default=None,
                        help='type(`str`, *optional*):'
                        '\n\tThe pretrain llm model type.'
                        )
    parser.add_argument('--tokenizer_path', type=str, default=None, help='tokenizer path, defaut is `None` mean using `--path` value.')
    parser.add_argument('--lora_path', type=str, default=None, help='lora path, defaut is `None` mean not apply lora.')
    parser.add_argument('--gptq_path', type=str, default=None, help='gptq path, defaut is `None` mean not apply gptq.')
    parser.add_argument('--dst_path', type=str, default='./model', help='export onnx/mnn model to path, defaut is `./model`.')
    parser.add_argument('--verbose', action='store_true', help='Whether or not to print verbose.')
    parser.add_argument('--test', type=str, help='test model inference with query `TEST`.')
    parser.add_argument('--export', type=str, default=None, help='export model to an onnx/mnn model.')
    parser.add_argument('--onnx_slim', action='store_true', help='Whether or not to use onnx-slim.')
    parser.add_argument('--quant_bit', type=int, default=4, help='mnn quant bit, 4 or 8, default is 4.')
    # The help text previously claimed the default was 0; the actual default is 128.
    parser.add_argument('--quant_block', type=int, default=128, help='mnn quant block, 0 mean channel-wise, default is 128.')
    parser.add_argument('--lm_quant_bit', type=int, default=None, help='mnn lm_head quant bit, 4 or 8, default is `quant_bit`.')
    parser.add_argument('--mnnconvert', type=str, default='../../../build/MNNConvert', help='local mnnconvert path, if invalid, using pymnn.')
    parser.add_argument('--ppl', action='store_true', help='Whether or not to get all logits of input tokens.')
    parser.add_argument('--awq', action='store_true', help='Whether or not to use awq quant.')
    parser.add_argument('--sym', action='store_true', help='Whether or not to using symmetric quant (without zeropoint), defualt is False.')
    parser.add_argument('--tie_embed', action='store_true', help='Whether or not to using tie_embedding, defualt is False.')
    parser.add_argument('--lora_split', action='store_true', help='Whether or not export lora split, defualt is False.')

    args = parser.parse_args()

    # Embedding models are recognised by a 'gte' or 'bge' substring in the path.
    if 'gte' in args.path or 'bge' in args.path:
        llm_exporter = EmbeddingExporter(args)
    else:
        llm_exporter = LlmExporter(args)

    # some actions
    if args.test is not None:
        llm_exporter.response(args.test)

    if args.export is not None:
        llm_exporter.export(args.export)


if __name__ == '__main__':
    main()