transformers/llm/export/llmexport.py
import os
import json
import glob
import base64
import warnings
import argparse
warnings.filterwarnings("ignore")
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'
import onnx
import torch
from typing import Optional, List
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer
from utils.spinner import spinner_run
from utils.custom_op import FakeLinear
from utils.onnx_rebuilder import OnnxRebuilder
from utils.mnn_converter import MNNConveter
from utils.awq_quantizer import AwqQuantizer
from utils.model_mapper import ModelMapper
from utils.transformers import Embedding, Rotary, Decoder, Lm
class LlmExporter(torch.nn.Module):
'''
Base class for all LLM model exporters. Inherits from [`torch.nn.Module`].
'''
def __init__(self, args):
super().__init__()
self.init_from_args(args)
self.load_model(args.path)
def init_from_args(self, args):
self.visual = None
self.audio = None
self.talker = None
self.args = args
self.max_length = 1024
self.stop_ids = []
self.dst_name = 'llm'
# load config from args
self.onnx_path = os.path.join(self.args.dst_path, 'onnx')
if self.args.tokenizer_path is None:
self.args.tokenizer_path = self.args.path
if args.lm_quant_bit is None:
self.args.lm_quant_bit = self.args.quant_bit
# init export dst dir
if not os.path.exists(self.args.dst_path):
os.makedirs(self.args.dst_path)
if not os.path.exists(self.onnx_path):
os.makedirs(self.onnx_path)
def load_pretrained(self, model_path: str):
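# Multimodal and some special model families need their dedicated HF classes; anything else falls back to AutoModelForCausalLM/AutoModel.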
self.tokenizer = AutoTokenizer.from_pretrained(self.args.tokenizer_path, trust_remote_code=True, use_fast=False)
if 'Qwen2.5-Omni' in model_path:
from transformers import Qwen2_5OmniForConditionalGeneration
self.model = Qwen2_5OmniForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto").eval()
elif 'Qwen2.5-VL' in model_path or 'Qwen2___5-VL' in model_path:
from transformers import Qwen2_5_VLForConditionalGeneration
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path, torch_dtype='auto').eval()
elif 'Qwen2-VL' in model_path:
from transformers import Qwen2VLForConditionalGeneration
self.model = Qwen2VLForConditionalGeneration.from_pretrained(model_path, torch_dtype='auto').eval()
elif 'Qwen2-Audio' in model_path:
from transformers import Qwen2AudioForConditionalGeneration
self.audio = Qwen2AudioForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto")
self.model = self.audio.language_model
elif 'Llama-3.2' in model_path and 'Vision' in model_path:
from transformers import MllamaForConditionalGeneration
self.model = MllamaForConditionalGeneration.from_pretrained(model_path, torch_dtype='auto').eval()
elif 'Llama' in model_path or 'Yi' in model_path:
from transformers import LlamaForCausalLM
self.model = LlamaForCausalLM.from_pretrained(model_path, torch_dtype='auto', trust_remote_code=True).eval()
elif 'InternVL' in model_path:
self.model = AutoModel.from_pretrained(model_path, torch_dtype=torch.float32, trust_remote_code=True).eval()
else:
try:
self.model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype='auto', trust_remote_code=True).eval()
except:
self.model = AutoModel.from_pretrained(model_path, torch_dtype='auto', trust_remote_code=True).eval()
self.config = self.model.config
if self.args.lora_path is not None and not self.args.lora_split:
from peft import PeftModel
adapter = PeftModel.from_pretrained(self.model, model_id=self.args.lora_path)
self.model = adapter.merge_and_unload(progressbar=True)
@staticmethod
def has_attr(obj, attr):
return hasattr(obj, attr) and getattr(obj, attr) is not None
@spinner_run(f'load pretrained model ', True)
def load_model(self, model_path):
self.load_pretrained(model_path)
self.attention_mask_type = 'float'
# load tokenizer info
self.stop_ids.append(self.tokenizer.eos_token_id)
if hasattr(self.tokenizer, 'im_end_id'):
self.stop_ids.append(self.tokenizer.im_end_id)
try:
eot_id = self.tokenizer.encode('<|eot_id|>')
if len(eot_id) == 1:
self.stop_ids.append(eot_id[0])
# gemma/gemma-2
eot_id = self.tokenizer.encode('<end_of_turn>')
if len(eot_id) == 2 and eot_id[0] == 2:
self.stop_ids.append(eot_id[1])
except:
pass
if hasattr(self.model, 'generation_config') and self.model.generation_config is not None:
eos_token_id = self.model.generation_config.eos_token_id
from collections.abc import Iterable
if isinstance(eos_token_id, int):
self.stop_ids.append(eos_token_id)
elif isinstance(eos_token_id, Iterable):
for id in eos_token_id:
self.stop_ids.append(id)
self.stop_ids = [stop_id for stop_id in self.stop_ids if stop_id is not None]
self.stop_ids = list(set(self.stop_ids))
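# ModelMapper translates each model family's config/module names into the unified attribute names used below.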
model_mapper = ModelMapper()
self.tie_word_embeddings = self.args.tie_embed and (hasattr(self.config, 'tie_word_embeddings') and self.config.tie_word_embeddings)
self.model_type, self.model_map = model_mapper.get_map(self.config)
if self.args.awq:
self.model.float()
if self.args.export is not None:
# set norm's weight as float for export
def visit_module(module):
if not isinstance(module, torch.nn.Linear) and hasattr(module, 'weight'):
module.float()
for name, child in module.named_children():
visit_module(child)
visit_module(self.model)
# print(self.model_type, self.model_map)
# print(self.config, self.model_type, self.model_map, self.model)
# print(self.model.model.layers[0].input_layernorm.weight); exit(0)
# load config info
ModelMapper.do_map(self, self.config, self.model_map['config'])
if not hasattr(self, 'num_key_value_heads') or self.num_key_value_heads is None:
self.num_key_value_heads = self.num_attention_heads
if not hasattr(self, 'rope_theta') or self.rope_theta is None:
self.rope_theta = 10000.0
if not hasattr(self, 'rope_ratio') or self.rope_ratio is None:
self.rope_ratio = 1.0
if not hasattr(self, 'head_dim') or self.head_dim is None:
if isinstance(self.num_attention_heads, list):
self.head_dim = [self.hidden_size // atten_head for atten_head in self.num_attention_heads]
else:
self.head_dim = self.hidden_size // self.num_attention_heads
# some export info
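# past kv layout: [layer, k/v, batch(=1), history_len(=0 for the dummy export), kv_heads, head_dim]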
if isinstance(self.num_attention_heads, list):
self.past_kv_shape = [self.num_hidden_layers, 2, 1, 0, self.num_key_value_heads[0], self.head_dim]
else:
self.past_kv_shape = [self.num_hidden_layers, 2, 1, 0, self.num_key_value_heads, self.head_dim]
self.block_dynamic_axes = {
"inputs_embeds" : { 0: "seq_len" },
"attention_mask" : { 2: "seq_len", 3: "seq_len" },
"position_ids" : { 0: "seq_len" },
"past_key_values" : { 1: "history_len" }
}
self.model_dynamic_axes = {
"input_ids" : { 0: "seq_len" },
"attention_mask" : { 2: "seq_len", 3: "seq_len" },
"position_ids" : { 1: "seq_len" },
"past_key_values" : { 3: "history_len" }
}
prompt_template = self.build_prompt_template()
self.llm_config = {
'hidden_size' : self.hidden_size,
'layer_nums' : self.num_hidden_layers,
'attention_mask': self.attention_mask_type,
'key_value_shape': self.past_kv_shape[1:],
"bos": prompt_template['bos'],
"system_prompt_template": prompt_template['system'].format(content='%s'),
'user_prompt_template': prompt_template['user'].format(content='%s'),
'assistant_prompt_template': prompt_template['assistant'].format(content='%s'),
'is_visual': False
}
# load modules
ModelMapper.do_map(self, self.model, self.model_map['model'])
# rebuild modules
if self.lm_ is None:
# tied word embeddings: build lm_head from the embedding weight and zero the (otherwise randomly initialized) bias
out_features, in_features = self.embed_.weight.shape
self.lm_ = torch.nn.Linear(in_features, out_features)
self.lm_.weight = self.embed_.weight
self.lm_.bias.data = torch.zeros(out_features, dtype=self.embed_.weight.dtype)
elif not isinstance(self.lm_, torch.nn.Linear):
# for Baichuan2
weight = self.lm_.weight
out_features, in_features = weight.shape
self.lm_ = torch.nn.Linear(in_features, out_features)
self.lm_.weight = weight
self.lm_.bias.data = torch.zeros(out_features, dtype=weight.dtype)
if self.embed_.weight is self.lm_.weight:
import copy
embed_copy = copy.deepcopy(self.embed_)
self.embed = Embedding(embed_copy, self)
else:
self.embed = Embedding(self.embed_, self)
# Rotary
self.rotary = Rotary(self)
self.blocks = []
for block in self.blocks_.children():
layer_id = len(self.blocks)
self.blocks.append(Decoder(block, layer_id, self))
self.lm = Lm(self.lm_, self.final_layernorm_, self)
# visual model
if self.visual is not None:
if self.args.export is not None:
self.visual.float()
from utils.vision import Vision
self.visual = Vision.get_vision(self.model_type)(self.visual, self)
if hasattr(self, 'audio') and self.audio is not None:
from utils.audio import Audio
self.audio = Audio.get_audio(self.audio.config.model_type)(self.audio, self)
else:
self.audio = None
# talker model
if hasattr(self, 'talker') and self.talker is not None and \
hasattr(self, 'token2wav') and self.token2wav is not None:
from utils.talker import Talker
self.talker = Talker.get_talker(self.model_type)(self.talker, self.token2wav, self)
return model_path
def get_attention_mask(self) -> torch.Tensor:
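# prefill (token_len == 0): additive lower-triangular causal mask; decode: a single zero row that attends to all cached positions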
if self.model_type == 'chatglm':
return self.chatglm_attention_mask()
if self.token_len:
return torch.zeros([1, 1, 1, self.seq_len], dtype=torch.float32)
return (1 - torch.tril(torch.ones([1, 1, self.seq_len, self.seq_len]))) * torch.finfo(torch.float32).min
def get_position_ids(self, input_ids = None) -> torch.Tensor:
if self.visual is not None and hasattr(self.visual, 'get_position_ids') and callable(getattr(self.visual, 'get_position_ids')):
return self.visual.get_position_ids(input_ids, self.seq_len, self.token_len)
if self.model_type == 'chatglm':
return self.chatglm_position_ids()
if self.token_len:
return torch.tensor([[self.seq_len - 1]], dtype=torch.int)
return torch.arange(self.seq_len, dtype=torch.int).unsqueeze(0)
def chatglm_attention_mask(self):
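# ChatGLM uses a boolean mask (True = masked): during prefill every context position is blocked from attending to the last position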
if self.token_len:
return torch.zeros([1]).bool().reshape([1, 1, 1, 1])
attention_mask = torch.zeros([self.seq_len, self.seq_len], dtype=torch.bool)
for i in range(self.seq_len - 1):
attention_mask[i][-1] = True
attention_mask = attention_mask.reshape([1, 1, self.seq_len, self.seq_len])
return attention_mask
def chatglm_position_ids(self):
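# ChatGLM 2D position ids: row 0 holds absolute positions, row 1 the block (generation) position; decode steps use [context_len, token_len + 1]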
if self.token_len:
return torch.tensor([self.context_len, self.token_len + 1]).reshape([1, 2, 1])
position_ids_0 = torch.arange(self.seq_len, dtype=torch.int)
position_ids_1 = torch.zeros(self.seq_len, dtype=torch.int)
position_ids_0[-1] = position_ids_0[-2]
position_ids_1[-1] = 1
position_ids = torch.stack([position_ids_0, position_ids_1]).view(1, 2, -1)
return position_ids
def visual_embed(self, input_ids):
return self.visual.embed(input_ids)
def audio_embed(self, input_ids):
return self.audio.embed(input_ids)
def embedding(self, input_ids):
if self.visual is not None and self.token_len == 0:
input_embeds = self.visual_embed(input_ids)
elif self.audio is not None and self.token_len == 0:
input_embeds = self.audio_embed(input_ids)
else:
input_embeds = self.embed(input_ids)
return input_embeds
def forward(self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
position_ids: torch.Tensor,
past_key_values: Optional[List[torch.Tensor]] = None,
logits_index: int = -1,
cross_attention_states: Optional[torch.Tensor] = None,
cross_attention_mask: Optional[torch.Tensor] = None,
):
hidden_states = input_ids # despite the name, input_ids are already input embeddings; embedding is applied outside forward
if self.model_type == 'gemma':
normalizer = torch.tensor(self.hidden_size**0.5, dtype=hidden_states.dtype)
hidden_states = hidden_states * normalizer
presents = [None for i in range(self.num_hidden_layers)]
rotary_pos_emb = self.rotary(position_ids)
if self.args.test and rotary_pos_emb.dtype != hidden_states.dtype:
rotary_pos_emb = rotary_pos_emb.type(hidden_states.dtype)
for i in range(self.num_hidden_layers):
if self.blocks[i].cross_decoder and cross_attention_states is None:
continue
hidden_states, kv = self.blocks[i](hidden_states, rotary_pos_emb, attention_mask, past_key_values[i])
presents[i] = kv
talker_embeds = self.final_layernorm_(hidden_states) + input_ids.permute([1, 0, 2])
if hasattr(self, 'talker') and self.talker is not None:
self.talker.add_talker_embeds(talker_embeds)
logits = self.lm(hidden_states, logits_index)
# only stack the per-layer kv caches when every layer produced one of the same shape
if None not in presents and presents[0].shape == presents[-1].shape:
presents = torch.stack(presents)
self.seq_len += 1
self.token_len += 1
return logits, presents, talker_embeds
# some test functions
def build_prompt_template(self):
template = {
'bos': '',
'system': '{content}',
'user': '{content}',
'assistant': '{content}',
}
if self.model_type == 'baichuan':
template['user'] = '<reserved_106>{content}'
template['assistant'] = '<reserved_107>{content}'
if self.model_type == 'chatglm':
template['user'] = '{content}[gMASK]<sop>'
if self.model_type == 'chatglm2' and 'codegeex' not in self.args.path:
template['user'] = '[Round 1]\n\n问:{content}\n\n'
template['assistant'] = '答:{content}\n\n'
if 'chatglm3' in self.args.path or 'glm-4' in self.args.path:
template['bos'] = '[gMASK]<sop>'
template['system'] = '<|system|>\n{content}\n'
template['user'] = '<|user|>\n{content}\n'
template['assistant'] = '<|assistant|>\n{content}\n'
if self.model_type == 'llama':
if 'deepseek' in self.args.path:
template['bos'] = '<|begin_of_sentence|>'
template['system'] = '{content}\n'
template['user'] = '\nUser: {content}\n'
template['assistant'] = '\nAssistant: {content}\n<|end_of_sentence|>'
if 'Llama-2' in self.args.path:
template['bos'] = '[INST] '
template['system'] = "<<SYS>>\n{content}\n<</SYS>>\n\n"
template['user'] = '{content} [/INST]'
template['assistant'] = "{content}</s>";
if 'Llama-3' in self.args.path:
template['system'] = '<|start_header_id|>system<|end_header_id|>\n\n{content}<|eot_id|>'
template['user'] = '<|start_header_id|>user<|end_header_id|>\n\n{content}<|eot_id|>'
template['assistant'] = '<|start_header_id|>assistant<|end_header_id|>\n\n{content}<|eot_id|>'
if 'TinyLlama' in self.args.path:
template['bos'] = '<s>'
template['system'] = '<|system|>\n{content}</s>\n'
template['user'] = '<|user|>\n{content}</s>\n'
template['assistant'] = '<|assistant|>\n{content}</s>\n'
if 'Yi' in self.args.path or 'SmolLM2' in self.args.path:
template['system'] = '<|im_start|>system\n{content}<|im_end|>\n'
template['user'] = '<|im_start|>user\n{content}<|im_end|>\n'
template['assistant'] = '<|im_start|>assistant\n{content}<|im_end|>\n'
if self.model_type == 'gemma2':
template['bos'] = '<bos>'
template['system'] = '<start_of_turn>system\n{content}<end_of_turn>\n'
template['user'] = '<start_of_turn>user\n{content}<end_of_turn>\n'
template['assistant'] = '<start_of_turn>model\n{content}<end_of_turn>\n'
if self.model_type == 'gemma':
template['bos'] = '<bos>'
if self.model_type == 'internlm':
template['user'] = '<|User|>:{content}<eoh>\n'
template['assistant'] = '<|Bot|>:{content}<eoh>\n'
if self.model_type == 'phi-msft':
template['user'] = 'Instruct: {content}\n'
template['assistant'] = 'Output:{content}\n'
if self.model_type == 'openelm':
template['bos'] = '<s>'
if self.model_type == 'internvl_chat':
if 'Qwen' in self.config.llm_config._name_or_path:
print("[DEBUG] Use qwen prompt template")
template['system'] = '<|im_start|>system\n{content}<|im_end|>\n'
template['user'] = '<|im_start|>user\n{content}<|im_end|>\n'
template['assistant'] = '<|im_start|>assistant\n{content}<|im_end|>\n'
if 'qwen' in self.model_type:
template['system'] = '<|im_start|>system\n{content}<|im_end|>\n'
template['user'] = '<|im_start|>user\n{content}<|im_end|>\n'
template['assistant'] = '<|im_start|>assistant\n{content}<|im_end|>\n'
if 'DeepSeek' in self.args.path:
template['bos'] = '<|begin_of_sentence|>'
template['user'] = '<|User|>{content}'
template['assistant'] = '<|Assistant|>{content}<|end_of_sentence|>'
return template
def build_prompt(self, messages):
template = self.build_prompt_template()
prompt = template['bos']
for item in messages:
role, content = item['role'], item['content']
if '{content}' in template[role]:
prompt += template[role].format(content=content)
else:
prompt += role + '\n' + content +'\n'
assistant_prefix = template['assistant'].split('{content}')[0]
return prompt + assistant_prefix
def str_to_ids(self, prompt):
if self.visual is not None:
return self.visual.str_to_ids(prompt)
if self.audio is not None:
return self.audio.str_to_ids(prompt)
input_ids = self.tokenizer(prompt, return_tensors="pt")['input_ids']
return input_ids
def id_to_str(self, token_id):
try:
word = self.tokenizer.decode(int(token_id))
except:
def contains_replacement(text): return '\uFFFD' in text
def decode_id(token_id):
return self.tokenizer.convert_tokens_to_string(
self.tokenizer._convert_id_to_token(int(token_id)))
def decode_ids(token_ids):
return self.tokenizer.convert_tokens_to_string(
self.tokenizer.convert_ids_to_tokens(token_ids))
word = decode_id(int(token_id))
# SmolLM's tokenizer can emit partial multi-byte characters (e.g. half of a Chinese character); buffer ids until they decode cleanly
if contains_replacement(word):
self.decode_buffer.append(token_id)
buffer_txt = decode_ids(self.decode_buffer)
if not contains_replacement(buffer_txt):
word = buffer_txt
self.decode_buffer.clear()
else:
word = ''
return word
@torch.no_grad()
def response(self, query):
# self.imitate_quant()
self.decode_buffer = []
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": query}
]
prompt = self.build_prompt(messages)
input_ids = self.str_to_ids(prompt)
if self.visual is not None:
cross_attention_states = self.visual.cross_attention_states
cross_attention_mask = self.visual.cross_attention_mask
else:
cross_attention_states = None
cross_attention_mask = None
self.seq_len = input_ids.numel()
self.context_len = self.seq_len - 2
self.token_len = 0
past_key_values = [None for i in range(self.num_hidden_layers)]
token_id = input_ids
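# greedy decoding: the whole prompt is fed on the first step, afterwards one token at a time with the growing kv cache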
while self.token_len < self.max_length:
attention_mask = self.get_attention_mask()
position_ids = self.get_position_ids(token_id)
input_ids = self.embedding(token_id)
logits, past_key_values, _ = self.forward(input_ids,
attention_mask,
position_ids,
past_key_values,
cross_attention_states=cross_attention_states,
cross_attention_mask=cross_attention_mask)
token_id = torch.argmax(logits[:,-1,:])
if token_id in self.stop_ids:
print("", end='\n')
break
word = self.id_to_str(token_id)
print(word, end="", flush=True)
if hasattr(self, 'talker') and self.talker is not None:
self.talker.generate()
@spinner_run(f'export embedding to ')
def export_embed(self):
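# dump the embedding table as a raw bfloat16 binary (2 bytes per element) into embeddings_bf16.bin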
import ctypes
if hasattr(self, 'word_embeddings'):
# embedding model's embed
tensor_data = self.word_embeddings.weight.data.bfloat16()
else:
tensor_data = self.embed.embed.weight.data.bfloat16()
data_ptr = tensor_data.untyped_storage().data_ptr()
buffer = (ctypes.c_byte * (tensor_data.numel() * 2)).from_address(data_ptr)
embedding_file = f'{self.args.dst_path}/embeddings_bf16.bin'
with open(embedding_file, 'wb') as f:
f.write(buffer)
return embedding_file
@spinner_run(f'export config to ')
def export_config(self, mnn_config = False):
config_json = f'{self.args.dst_path}/llm_config.json'
with open(config_json, 'w', encoding='utf-8') as f:
json.dump(self.llm_config, f, ensure_ascii=False, indent=4)
if not mnn_config:
return config_json
with open(f'{self.args.dst_path}/config.json', 'w', encoding='utf-8') as f:
config = {
"llm_model": f"{self.dst_name}.mnn",
"llm_weight": f"{self.dst_name}.mnn.weight",
"backend_type": "cpu",
"thread_num": 4,
"precision": "low",
"memory": "low",
"system_prompt": "You are a helpful assistant.",
}
if self.talker is not None:
config['system_prompt'] = "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech."
config['talker_max_new_tokens'] = 2048
config['talker_speaker'] = "Chelsie"
config['dit_steps'] = 5
config['dit_solver'] = 1
if self.visual is not None or self.audio is not None:
config['mllm'] = {
'backend_type': "cpu",
"thread_num": 4,
"precision": "normal",
"memory": "low"
}
json.dump(config, f, ensure_ascii=False, indent=4)
return config_json
def imitate_quant(self):
def quant_dequant(linear, quant_bit = self.args.quant_bit, quant_block = self.args.quant_block):
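# fake-quantize in place: block-wise asymmetric min/max quantization to quant_bit bits followed by dequantization, so response() reflects quantized accuracy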
weight = linear.weight.data
oc, ic = weight.shape
if quant_block == 0:
block_size = ic
else:
block_size = quant_block
block_num = ic // block_size
weight = weight.reshape(oc, block_num, block_size)
max_val = torch.max(weight, axis=-1, keepdims=True).values
min_val = torch.min(weight, axis=-1, keepdims=True).values
offset = 1 << (quant_bit - 1)
clip_max = offset - 1
clip_min = -offset
scale = (max_val - min_val) / (clip_max - clip_min)
q_weight = torch.round((weight - min_val) / scale) + clip_min
q_weight = torch.clip(q_weight, clip_min, clip_max)
dq_weight = (q_weight - clip_min) * scale + min_val
dq_weight = dq_weight.reshape(oc, ic).float()
linear.weight.data = dq_weight
return linear
with torch.no_grad():
for i in range(self.num_hidden_layers):
for name, child in self.blocks[i].self_attn.named_children():
if isinstance(child, torch.nn.Linear):
setattr(self.blocks[i].self_attn, name, quant_dequant(child))
for name, child in self.blocks[i].mlp.named_children():
if isinstance(child, torch.nn.Linear):
setattr(self.blocks[i].mlp, name, quant_dequant(child))
self.lm.lm = quant_dequant(self.lm.lm)
def unload_param(self):
self.unloaded_ops = {}
def build_faker(real, name):
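# stash the real linear under its onnx node name; the FakeLinear placeholder keeps only shape/bias info for tracing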
faker = FakeLinear(real.in_features, real.out_features, real.bias is not None, name)
self.unloaded_ops[name] = real
return faker
# replace linear with fakelinear to save export memory and time
with torch.no_grad():
for i in range(self.num_hidden_layers):
# different kv cache shape in different layers
if isinstance(self.num_attention_heads, list):
self.blocks[i].self_attn.export_fused_attn = True
for name, child in self.blocks[i].self_attn.named_children():
if isinstance(child, torch.nn.Linear):
setattr(self.blocks[i].self_attn, name, build_faker(child, f'/layers.{i}/self_attn/{name}/Linear'))
for name, child in self.blocks[i].mlp.named_children():
if isinstance(child, torch.nn.Linear):
setattr(self.blocks[i].mlp, name, build_faker(child, f'/layers.{i}/mlp/{name}/Linear'))
self.lm.lm = build_faker(self.lm.lm, f'/lm/lm_head/Linear')
@spinner_run(f'export model weight to ')
def onnx_load_param(self, onnx_path):
return OnnxRebuilder(onnx_path, self.unloaded_ops).rebuild()
@spinner_run(f'slim the graph of ')
def slim_onnx(self, onnx_model):
import onnxslim
model = onnxslim.slim(onnx_model)
onnx.save(model, onnx_model)
return onnx_model
@spinner_run(f'export onnx model to ')
def export_onnx(self):
# unload linear weight to save export memory
self.unload_param()
model = self
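# trace with a dummy 3-token prompt; model_dynamic_axes mark seq_len/history_len as dynamic in the exported graph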
self.seq_len = 3
self.token_len = 0
input_ids = torch.arange(3, dtype=torch.long)
attention_mask = self.get_attention_mask()
position_ids = self.get_position_ids(input_ids)
onnx_model = f'{self.onnx_path}/{self.dst_name}.onnx'
# For onnx export, the image/audio embeddings are not needed
input_ids = self.embed(input_ids)
past_key_values = torch.zeros(self.past_kv_shape)
logits_index = torch.tensor([-1], dtype=torch.int32)
if hasattr(self, 'talker') and self.talker is not None:
output_names = ['logits', 'presents', 'talker_embeds']
else:
output_names = ['logits', 'presents']
# export to onnx
torch.onnx.export(
model, (input_ids, attention_mask, position_ids, past_key_values, logits_index),
onnx_model,
input_names=[
'input_ids', 'attention_mask', 'position_ids', 'past_key_values', 'logits_index'
],
output_names=output_names,
dynamic_axes=self.model_dynamic_axes,
do_constant_folding=True,
verbose=False,
opset_version=15)
return onnx_model
def awq_quant(self):
self.awq_quantizer = AwqQuantizer(self)
self.awq_quantizer.quantize()
self.is_awq_quantized = True
def export_vision(self):
if self.visual is None:
return
vision_onnx = self.visual.export(self.onnx_path)
if self.mnn_converter: self.mnn_converter.export(vision_onnx, self.visual.quant_bit)
def export_audio(self):
if self.audio is None:
return
audio_onnx = self.audio.export(self.onnx_path)
if self.mnn_converter: self.mnn_converter.export(audio_onnx, self.audio.quant_bit)
def export_talker(self):
if self.talker is None:
return
talker_onnx = self.talker.export(self.onnx_path)
predit_onnx, dit_onnx, bigvgan_onnx = self.talker.token2wav.export(self.onnx_path)
if self.mnn_converter:
self.mnn_converter.export(talker_onnx, self.talker.quant_bit)
self.mnn_converter.export(predit_onnx, self.talker.token2wav.quant_bit)
self.mnn_converter.export(dit_onnx, self.talker.token2wav.quant_bit)
self.mnn_converter.export(bigvgan_onnx, self.talker.token2wav.quant_bit)
def export_language(self):
# export_embedding
if self.mnn_converter and self.tie_word_embeddings:
pass # with tie_word_embeddings, mnn does not need a separate embedding export
else:
self.export_embed()
# export transformer
onnx_model = self.export_onnx()
if self.args.onnx_slim:
self.slim_onnx(onnx_model)
if self.mnn_converter:
MNNConveter(self, self.unloaded_ops).export(onnx_model)
else:
self.onnx_load_param(onnx_model)
def export(self, export_type):
if self.args.awq:
self.awq_quant()
export_mnn = export_type == 'mnn'
self.mnn_converter = MNNConveter(self) if export_mnn else None
self.export_talker()
self.export_vision()
self.export_audio()
self.export_language()
self.export_tokenizer()
self.export_config(export_mnn)
if export_mnn:
# delete onnx file
try:
for file in glob.glob(f'{self.onnx_path}/*'):
os.remove(file)
os.rmdir(self.onnx_path)
except Exception as e:
print(f"remove onnx error: {e}")
@spinner_run(f'export tokenizer to ')
def export_tokenizer(self):
# load tokenizer file
tokenizer_model = os.path.join(self.args.tokenizer_path, 'tokenizer.model')
ice_text_model = os.path.join(self.args.tokenizer_path, 'ice_text.model')
try:
import sentencepiece as spm
if os.path.exists(tokenizer_model):
self.sp_model = spm.SentencePieceProcessor(tokenizer_model)
elif os.path.exists(ice_text_model):
self.sp_model = spm.SentencePieceProcessor(ice_text_model)
else:
self.sp_model = None
except:
self.sp_model = None
merge_file = os.path.join(self.args.path, 'merges.txt')
if os.path.exists(merge_file):
self.merge_txt = merge_file
else:
self.merge_txt = None
# TOKENIZER MAGIC NUMBER
MAGIC_NUMBER = 430
# TOKENIZER TYPE
SENTENCEPIECE = 0; TIKTOKEN = 1; BERT = 2; HUGGINGFACE = 3
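# tokenizer.txt layout:
#   line 1: MAGIC_NUMBER tokenizer_type
#   line 2: counts of special / stop / prefix ids
#   line 3: the special, stop and prefix ids themselves
#   then the vocab size (plus merge count for huggingface) followed by the vocab entries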
def write_line(fp, *args):
for arg in args:
for token in arg:
fp.write(str(token) + ' ')
fp.write('\n')
def write_header(fp, type, specials, prefix = []):
fp.write(f'{MAGIC_NUMBER} {type}\n')
fp.write(f'{len(specials)} {len(self.stop_ids)} {len(prefix)}\n')
write_line(fp, specials, self.stop_ids, prefix)
file_path = os.path.join(self.args.dst_path, "tokenizer.txt")
special_list = list(self.tokenizer.added_tokens_decoder.keys())
if hasattr(self.tokenizer, 'special_tokens'):
for k, v in self.tokenizer.special_tokens.items():
special_list.append(v)
if hasattr(self.tokenizer, 'gmask_token_id'):
special_list.append(self.tokenizer.gmask_token_id)
if hasattr(self.model, 'generation_config') and self.model.generation_config is not None:
generation_config = self.model.generation_config
if hasattr(generation_config, 'user_token_id'):
special_list.append(generation_config.user_token_id)
if hasattr(generation_config, 'assistant_token_id'):
special_list.append(generation_config.assistant_token_id)
vocab_list = []
prefix_list = []
if hasattr(self.tokenizer, 'get_prefix_tokens'):
prefix_list = self.tokenizer.get_prefix_tokens()
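# fall back to detecting implicit prefix tokens (e.g. a BOS id) by encoding a single character and keeping any extra leading ids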
if len(prefix_list) == 0:
try:
test_txt = 'A'
ids = self.tokenizer.encode(test_txt)
get_txt = self.tokenizer.decode(ids[-1])
if len(ids) > 1 and get_txt == test_txt:
prefix_list += ids[:-1]
except:
pass
if self.sp_model is not None:
# sentencepiece
NORMAL = 1; UNKNOWN = 2; CONTROL = 3
USER_DEFINED = 4; UNUSED = 5; BYTE = 6
for i in range(self.sp_model.GetPieceSize()):
token = self.sp_model.IdToPiece(i)
score = self.sp_model.GetScore(i)
token_type = NORMAL
if self.sp_model.IsUnknown(i):
token_type = UNKNOWN
elif self.sp_model.IsControl(i):
token_type = CONTROL
elif self.sp_model.IsUnused(i):
token_type = UNUSED
elif self.sp_model.IsByte(i):
token_type = BYTE
if self.args.path == 'Chatglm_6b':
if '<n>' in token: token = '\n'
if '<|tab|>' in token: token = '\t'
if '<|blank_' in token: token = ' ' * int(token[8:token.find('|>')])
if '▁' in token: token = token.replace('▁', ' ')
token_encode = base64.b64encode(token.encode("utf-8")).decode("utf8")
vocab_list.append(f'{token_encode} {score} {token_type}\n')
with open(file_path, "w", encoding="utf8") as fp:
write_header(fp, SENTENCEPIECE, special_list, prefix_list)
fp.write(f'{len(vocab_list)}\n')
for vocab in vocab_list:
fp.write(vocab)
elif hasattr(self.tokenizer, 'mergeable_ranks'):
# tiktoken
vocab_list = []
for k, v in self.tokenizer.mergeable_ranks.items():
line = base64.b64encode(k).decode("utf8") + "\n"
vocab_list.append(line)
if hasattr(self.tokenizer, 'special_tokens'):
for k, v in self.tokenizer.special_tokens.items():
line = base64.b64encode(k.encode("utf-8")).decode("utf8") + "\n"
vocab_list.append(line)
if hasattr(self.tokenizer, 'added_tokens_decoder'):
for k, v in self.tokenizer.added_tokens_decoder.items():
line = base64.b64encode(v.__str__().encode("utf-8")).decode("utf8") + "\n"
vocab_list.append(line)
with open(file_path, "w", encoding="utf8") as fp:
write_header(fp, TIKTOKEN, special_list, prefix_list)
fp.write(f'{len(vocab_list)}\n')
for vocab in vocab_list:
fp.write(vocab)
elif self.merge_txt is not None:
# huggingface tokenizer
merge_list = []
vocab = self.tokenizer.get_vocab()
special_list = list(self.tokenizer.added_tokens_decoder.keys())
vocab_list = ['<unk>' for i in range(len(vocab))]
# load vocab
for k, v in vocab.items():
vocab_list[int(v)] = k
# load merge
with open(self.merge_txt, 'rt') as merge:
for line in merge.readlines():
merge_list.append(line)
# write to tokenizer.txt
with open(file_path, "w", encoding="utf8") as fp:
write_header(fp, HUGGINGFACE, special_list)
fp.write(f'{len(vocab_list)} {len(merge_list)}\n')
for v in vocab_list:
fp.write(v + '\n')
for m in merge_list:
fp.write(m)
else:
# tiktoken or bert
if 'bert' in type(self.tokenizer).__name__.lower():
tokenizer_type = BERT
else:
tokenizer_type = TIKTOKEN
# bert tokenizer
def unicode_to_byte(u: int):
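# invert the GPT-2 style byte-to-unicode remapping (non-printable bytes 0-32, 127-160 and 173 were shifted to code points 256+), plus full-width '｜' and '▁' special cases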
if u >= 256 and u <= 288:
return u - 256
if u >= 289 and u <= 322:
return u - 162
if u == 323:
return 173
if u == 65372: # |
return 124
if u == 9601: # _
return 95
return u
vocab = self.tokenizer.get_vocab()
vocab_list = ['<unk>' for i in range(len(vocab))]
for k, v in vocab.items():
try:
vocab_list[int(v)] = bytes([unicode_to_byte(ord(c)) for c in k])
except:
vocab_list[int(v)] = k.encode('utf-8')
special_list = list(self.tokenizer.added_tokens_decoder.keys())
with open(file_path, "w", encoding="utf8") as fp:
write_header(fp, tokenizer_type, special_list)
fp.write(f'{len(vocab_list)}\n')
for v in vocab_list:
line = base64.b64encode(v).decode("utf8") + "\n"
fp.write(line)
return file_path
class EmbeddingExporter(LlmExporter):
def __init__(self, args):
super().__init__(args)
self.dst_name = 'embedding'
def word_embed(self, input_ids):
return self.word_embeddings(input_ids.view(1, -1))
def bge_forward(self, inputs_embeds, position_ids, attention_mask):
# bert absolute position
inputs_embeds = inputs_embeds.reshape(1, -1, self.hidden_size)
position_embeddings = self.position_embeddings(position_ids)
embeddings = inputs_embeds + position_embeddings + self.token_type_embeddings
hidden_states = self.embedding_layernorm(embeddings)
for i in range(self.num_hidden_layers):
hidden_states = self.blocks[i](hidden_states, attention_mask)[0]
sentence_embeddings = hidden_states[:, 0]
sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
return sentence_embeddings
def gte_forward(self, inputs_embeds, position_ids, attention_mask):
# rope position
inputs_embeds = inputs_embeds.reshape(1, -1, self.hidden_size)
freqs = position_ids.float().reshape(-1, 1) * self.inv_freq
emb = torch.cat((freqs, freqs), dim=-1)
rope_embeds = torch.stack([emb.cos(), emb.sin()]).unsqueeze(-2).unsqueeze(1)
attention_bias = 1 - attention_mask.float()
hidden_states = self.embedding_layernorm(inputs_embeds + self.token_type_embeddings)
for i in range(self.num_hidden_layers):
hidden_states = self.blocks[i](hidden_states, attention_bias, rope_embeds)[0]
sentence_embeddings = hidden_states[:, 0]
sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
return sentence_embeddings
def forward(self, inputs_embeds, position_ids, attention_mask):
if self.model_type == 'bert':
return self.bge_forward(inputs_embeds, position_ids, attention_mask)
if self.model_type == 'new':
return self.gte_forward(inputs_embeds, position_ids, attention_mask)
raise RuntimeError(f'Not support embedding model: {self.model_type}!')
def response(self, query):
self.eval()
input_ids = self.tokenizer(query)['input_ids']
self.seq_len = len(input_ids)
input_ids = torch.tensor(input_ids)
position_ids = self.get_position_ids()
attention_mask = self.get_attention_mask()
inputs_embeds = self.word_embed(input_ids)
res = self.forward(inputs_embeds, position_ids, attention_mask)
# print(res)
return res
@spinner_run(f'load pretrained model ')
def load_model(self, model_path):
self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
self.config = AutoConfig.from_pretrained(model_path)
self.config._attn_implementation = 'eager'
self.model = AutoModel.from_config(self.config)
transformer = self.model.encoder
self.model_type = self.config.model_type
self.lm_ = self.model.pooler
self.embed_ = self.model.embeddings
self.word_embeddings = self.embed_.word_embeddings
self.token_type_embeddings = self.embed_.token_type_embeddings.weight.data[0]
self.embedding_layernorm = self.embed_.LayerNorm
if hasattr(self.embed_, 'position_embeddings'):
self.position_embeddings = self.embed_.position_embeddings
self.hidden_size = self.word_embeddings.weight.shape[-1]
self.blocks = transformer.layer
if self.model_type == 'new':
self.inv_freq = self.embed_.rotary_emb.inv_freq
# some wrapper
self.stop_ids = []
self.num_hidden_layers = len(self.blocks)
self.embed = self.embed_
self.lm = self.lm_
# some config for export
self.model_dynamic_axes = {
"input_ids" : { 1: "seq_len" },
"position_ids" : { 1: "seq_len" },
"attention_mask" : { 3: "seq_len" }
}
self.attention_mask_type = 'int'
self.llm_config = {
'hidden_size' : self.hidden_size,
'layer_nums' : self.num_hidden_layers,
'attention_mask': self.attention_mask_type,
'key_value_shape': [],
"prompt_template": self.build_prompt('%s'),
'is_visual': False
}
return model_path
@spinner_run(f'export onnx model to ')
def export_onnx(self):
model = self.eval()
self.seq_len = 3
input_ids = torch.arange(3, dtype=torch.long)
position_ids = self.get_position_ids()
attention_mask = self.get_attention_mask()
inputs_embeds = self.word_embed(input_ids)
onnx_model = f'{self.onnx_path}/{self.dst_name}.onnx'
torch.onnx.export(
model, (inputs_embeds, position_ids, attention_mask),
onnx_model,
input_names=[
'input_ids',
'position_ids',
'attention_mask'
],
output_names=['sentence_embeddings'],
dynamic_axes=self.model_dynamic_axes,
do_constant_folding=True,
opset_version=15)
return onnx_model
def export(self, export_type):
export_mnn = 'mnn' in export_type
self.export_tokenizer()
self.export_config(export_mnn)
self.export_embed()
onnx_model = self.export_onnx()
if self.args.onnx_slim:
self.slim_onnx(onnx_model)
if export_mnn:
MNNConveter(self).export(onnx_model)
def build_prompt(self, content):
if self.model_type == 'bert':
return f'[CLS]{content}[SEP]'
if self.model_type == 'new':
return f'<s> {content}</s>'
def get_position_ids(self) -> torch.Tensor:
return torch.arange(self.seq_len, dtype=torch.long).unsqueeze(0)
def get_attention_mask(self) -> torch.Tensor:
return torch.ones([1, 1, 1, self.seq_len], dtype=torch.long)
def export(path,
type = None,
tokenizer_path = None,
lora_path = None,
gptq_path = None,
dst_path = './model',
export = 'onnx',
onnx_slim = False,
quant_bit = 4,
quant_block = 128,
lm_quant_bit = None,
mnnconvert = None,
ppl = False,
awq = False,
sym = False,
tie_embed = False,
lora_split = False):
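# Programmatic counterpart of the CLI in main(): pack the keyword arguments into an argparse.Namespace and run the exporter.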
args = argparse.Namespace()
for k, v in {
'path': path,
'type': type,
'tokenizer_path': tokenizer_path,
'lora_path': lora_path,
'gptq_path': gptq_path,
'dst_path': dst_path,
'export': export,
'onnx_slim': onnx_slim,
'quant_bit': quant_bit,
'quant_block': quant_block,
'lm_quant_bit': lm_quant_bit,
'mnnconvert': mnnconvert,
'ppl': ppl,
'test': None, # forward() reads args.test, so provide a default for programmatic export
'awq': awq,
'sym': sym,
'tie_embed': tie_embed,
'lora_split': lora_split
}.items():
setattr(args, k, v)
if 'gte' in path or 'bge' in path:
llm_exporter = EmbeddingExporter(args)
else:
llm_exporter = LlmExporter(args)
# export
llm_exporter.export(export)
def main():
parser = argparse.ArgumentParser(description='llm_exporter', formatter_class=argparse.RawTextHelpFormatter)
parser.add_argument('--path', type=str, required=True,
help='path(`str` or `os.PathLike`):\nCan be either:'
'\n\t- A string, the *model id* of a pretrained model like `THUDM/chatglm-6b`. [TODO]'
'\n\t- A path to a *directory* clone from repo like `../chatglm-6b`.')
parser.add_argument('--type', type=str, default=None,
help='type(`str`, *optional*):'
'\n\tThe pretrained llm model type.'
)
parser.add_argument('--tokenizer_path', type=str, default=None, help='tokenizer path, default is `None`, meaning the `--path` value is used.')
parser.add_argument('--lora_path', type=str, default=None, help='lora path, default is `None`, meaning no lora is applied.')
parser.add_argument('--gptq_path', type=str, default=None, help='gptq path, default is `None`, meaning no gptq is applied.')
parser.add_argument('--dst_path', type=str, default='./model', help='directory to export the onnx/mnn model to, default is `./model`.')
parser.add_argument('--verbose', action='store_true', help='Whether or not to print verbose output.')
parser.add_argument('--test', type=str, help='test model inference with query `TEST`.')
parser.add_argument('--export', type=str, default=None, help='export model to an onnx/mnn model.')
parser.add_argument('--onnx_slim', action='store_true', help='Whether or not to use onnx-slim.')
parser.add_argument('--quant_bit', type=int, default=4, help='mnn quant bit, 4 or 8, default is 4.')
parser.add_argument('--quant_block', type=int, default=128, help='mnn quant block size, default is 128; 0 means channel-wise.')
parser.add_argument('--lm_quant_bit', type=int, default=None, help='mnn lm_head quant bit, 4 or 8, default is `quant_bit`.')
parser.add_argument('--mnnconvert', type=str, default='../../../build/MNNConvert', help='local MNNConvert path; if invalid, fall back to pymnn.')
parser.add_argument('--ppl', action='store_true', help='Whether or not to get all logits of input tokens.')
parser.add_argument('--awq', action='store_true', help='Whether or not to use awq quant.')
parser.add_argument('--sym', action='store_true', help='Whether or not to use symmetric quantization (without zero point), default is False.')
parser.add_argument('--tie_embed', action='store_true', help='Whether or not to use tie_embedding, default is False.')
parser.add_argument('--lora_split', action='store_true', help='Whether or not to export the lora weights separately, default is False.')
args = parser.parse_args()
model_path = args.path
model_type = args.type
if 'gte' in model_path or 'bge' in model_path:
llm_exporter = EmbeddingExporter(args)
else:
llm_exporter = LlmExporter(args)
# some actions
if args.test is not None:
llm_exporter.response(args.test)
if args.export is not None:
llm_exporter.export(args.export)
if __name__ == '__main__':
main()