maga_transformer/tools/fake_model_base.py (189 lines of code) (raw):

"""Command-line helpers for writing a fake (or down-sized copied) checkpoint.

The module resolves the checkpoint tensor names for a model type through
``ModelFactory``, builds a ``name -> shape`` map from the requested sizes,
fills it with ``generate_fake_model`` / ``copy_from_model`` and saves the
result with ``torch.save`` next to an optional ``config.json``.
"""

import argparse
import json
import os
from typing import Callable, Dict, List, Optional

import torch

from maga_transformer.tools.fake_util import generate_fake_model, copy_from_model
from maga_transformer.utils.util import load_ckpt
from maga_transformer.utils.model_weight import W
from maga_transformer.model_factory import ModelFactory


def default_save_config_func(model_type, dest_path: str, layer: int, head: int,
                             head_kv: int, head_size: int, ffn_hidden_size: int,
                             ffn_inter_padding_size: int, vocab_size: int):
    """Write a default ``config.json`` for the fake model into ``dest_path``.

    Args:
        model_type: value stored under ``"model_type"``.
        dest_path: existing directory that receives ``config.json``.
        layer: stored as ``"num_layers"``.
        head: stored as ``"num_attention_heads"``; also used for ``"hidden_size"``.
        head_kv: stored as ``"multi_query_group_num"``.
        head_size: stored as ``"kv_channels"``; ``hidden_size = head * head_size``.
        ffn_hidden_size: stored as ``"ffn_hidden_size"``.
        ffn_inter_padding_size: stored as ``"ffn_inter_padding_size"``.
        vocab_size: stored as ``"padded_vocab_size"``.
    """
    config = {
        "model_type": model_type,
        "add_bias_linear": False,
        "add_qkv_bias": True,
        "apply_query_key_layer_scaling": True,
        "apply_residual_connection_post_layernorm": False,
        "attention_dropout": 0.0,
        "attention_softmax_in_fp32": True,
        "bias_dropout_fusion": True,
        "ffn_hidden_size": ffn_hidden_size,
        "ffn_inter_padding_size": ffn_inter_padding_size,
        "fp32_residual_connection": False,
        "hidden_dropout": 0.0,
        "hidden_size": head * head_size,
        "kv_channels": head_size,
        "layernorm_epsilon": 1e-05,
        "multi_query_attention": True,
        "multi_query_group_num": head_kv,
        "num_attention_heads": head,
        "num_layers": layer,
        "original_rope": True,
        "padded_vocab_size": vocab_size,
        "post_layer_norm": True,
        "rmsnorm": True,
        "seq_length": 32768,
        "use_cache": True,
        "torch_dtype": "float16",
        "tie_word_embeddings": False,
        "eos_token_id": 2,
        "pad_token_id": 0
    }
    # save to config.json; the context manager guarantees the handle is
    # flushed and closed even if serialization fails.
    with open(os.path.join(dest_path, 'config.json'), 'w') as config_file:
        json.dump(config, config_file, indent=2)


def fake_model_impl(model_type: str,
                    save_config_func: Optional[Callable[..., None]],
                    post_rewrite_func: Optional[Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]]],
                    dest_path: str,
                    layer_num: int,
                    head_num: int,
                    head_kv_num: int,
                    head_size: int,
                    ffn_hidden_size: int,
                    ffn_inter_padding_size: int,
                    ffn_gate_active: bool,
                    ffn_w1_w3_independ: bool,
                    vocab_size: int,
                    input_model: str | None):
    """Build a fake checkpoint for ``model_type`` and save it under ``dest_path``.

    Args:
        model_type: key passed to ``ModelFactory.get_weight_cls``.
        save_config_func: optional callable invoked with the same positional
            arguments as ``default_save_config_func``; skipped when falsy.
        post_rewrite_func: optional callable that receives the generated
            parameter dict and returns the dict that is actually saved.
        dest_path: output directory; created when it does not exist.
        layer_num: number of layers to emit per-layer tensors for.
        head_num: number of attention heads.
        head_kv_num: number of key/value heads.
        head_size: size of one head; ``hidden_size = head_num * head_size``.
        ffn_hidden_size: inner size of the feed-forward block.
        ffn_inter_padding_size: only forwarded to ``save_config_func``.
        ffn_gate_active: when true the FFN has a gate projection.
        ffn_w1_w3_independ: with an active gate, store ``w1``/``w3`` as two
            tensors instead of one tensor with ``2 * ffn_hidden_size`` rows.
        vocab_size: number of rows of the embedding / lm_head tensors.
        input_model: optional checkpoint path; when given, tensors are taken
            from it via ``copy_from_model`` instead of being generated.

    Raises:
        AssertionError: if no tensor name could be resolved for the model type.
    """
    hidden_size = head_num * head_size
    # Width of the fused QKV projection: query part plus two kv-head parts.
    qkv_hidden_size = hidden_size + head_kv_num * head_size * 2

    model_weight_info = ModelFactory.get_weight_cls(model_type)(
        hidden_size=hidden_size, inter_size=ffn_hidden_size * 2,
        num_heads=head_num, num_heads_kv=head_kv_num, tp_size=1, int8_mode=0, num_layers=layer_num)
    # NOTE(review): these private switches are set before querying the weight
    # info; presumably they influence the checkpoint names - confirm against
    # the weight class.
    model_weight_info._lm_head = False
    model_weight_info._transformer_prefix = False
    weight_info = model_weight_info.get_weight_info()
    print(weight_info)

    # Logical weight -> checkpoint tensor name. "" means "this model type does
    # not declare that weight"; such entries are dropped from shape_map below.
    weight_dict = {
        W.embedding: "",
        W.lm_head: "",
        W.pre_decoder_ln_gamma: "",
        W.pre_decoder_ln_beta: "",
        W.final_ln_gamma: "",
        W.final_ln_beta: "",
    }
    for weight_item in weight_info.weights:
        if weight_item.name in weight_dict and len(weight_item.weights) > 0:
            weight_dict[weight_item.name] = weight_item.weights[0].name

    shape_map: Dict[str, List[int]] = {
        weight_dict[W.embedding]: [vocab_size, hidden_size],
        weight_dict[W.lm_head]: [vocab_size, hidden_size],
        weight_dict[W.pre_decoder_ln_gamma]: [hidden_size],
        weight_dict[W.pre_decoder_ln_beta]: [hidden_size],
        weight_dict[W.final_ln_gamma]: [hidden_size],
        weight_dict[W.final_ln_beta]: [hidden_size],
    }

    layer_weight_dict = {
        W.pre_ln_gamma: "", W.pre_ln_beta: "",
        W.attn_qkv_w: "", W.attn_qkv_b: "",
        W.attn_ln_gamma: "", W.attn_ln_beta: "",
        W.attn_o_w: "", W.attn_o_b: "",
        W.ffn_w1: "", W.ffn_b1: "",
        W.ffn_ln_gamma: "", W.ffn_ln_beta: "",
        W.ffn_w2: "", W.ffn_b2: "",
        W.ffn_w3: "", W.ffn_b3: "",
        W.post_ln_gamma: "", W.post_ln_beta: "",
    }
    for layer_weight_item in weight_info.layer_weights:
        if layer_weight_item.name in layer_weight_dict and len(layer_weight_item.weights) > 0:
            layer_weight_dict[layer_weight_item.name] = layer_weight_item.weights[0].name

    # With an active gate and a merged w1/w3, w1 holds both projections and is
    # therefore twice as tall; in every other case it has ffn_hidden_size rows.
    merged_gate = ffn_gate_active and not ffn_w1_w3_independ
    ffn_w1_rows = ffn_hidden_size * 2 if merged_gate else ffn_hidden_size

    # Ordered (weight, shape) table for one layer; the order is the insertion
    # order into shape_map and is kept stable on purpose.
    per_layer_shapes = [
        (W.pre_ln_gamma, [hidden_size]),
        (W.pre_ln_beta, [hidden_size]),
        (W.attn_qkv_w, [qkv_hidden_size, hidden_size]),
        (W.attn_qkv_b, [qkv_hidden_size]),
        (W.attn_ln_gamma, [hidden_size]),
        (W.attn_ln_beta, [hidden_size]),
        (W.attn_o_w, [hidden_size, hidden_size]),
        (W.attn_o_b, [hidden_size]),
        (W.ffn_w1, [ffn_w1_rows, hidden_size]),
        (W.ffn_b1, [ffn_w1_rows]),
    ]
    if ffn_gate_active and ffn_w1_w3_independ:
        per_layer_shapes += [
            (W.ffn_w3, [ffn_hidden_size, hidden_size]),
            (W.ffn_b3, [ffn_hidden_size]),
        ]
    per_layer_shapes += [
        (W.ffn_ln_gamma, [ffn_hidden_size]),
        (W.ffn_ln_beta, [ffn_hidden_size]),
        (W.ffn_w2, [hidden_size, ffn_hidden_size]),
        (W.ffn_b2, [hidden_size]),
        (W.post_ln_gamma, [hidden_size]),
        (W.post_ln_beta, [hidden_size]),
    ]

    for i in range(layer_num):
        for weight_key, shape in per_layer_shapes:
            # Per-layer names are templates with an ``{i}`` placeholder; each
            # entry gets its own list so later mutation cannot leak across layers.
            shape_map[layer_weight_dict[weight_key].format(i=str(i))] = list(shape)

    # Weights the model type does not declare all collapsed onto the "" key.
    if "" in shape_map:
        del shape_map[""]
    assert len(shape_map) > 0, f"no weight names resolved for model type {model_type!r}"
    print("shape_map = ", shape_map)

    file_name = f"fake_{model_type}_{layer_num}_{head_num}_{head_kv_num}_{head_size}_{ffn_hidden_size}_{vocab_size}"
    if input_model:
        model = load_ckpt(input_model)
        model_weight_info.process_meta(model)
        new_params = copy_from_model(shape_map, model)
        file_name += "_copy"
    else:
        new_params = generate_fake_model(shape_map)
    file_name += ".pt"

    if not os.path.exists(dest_path):
        print(f'{dest_path} not exist, creating...')
        # exist_ok guards against another process creating it in between.
        os.makedirs(dest_path, exist_ok=True)

    # save config
    print("saving config.json...")
    if save_config_func:
        save_config_func(model_type, dest_path, layer_num, head_num, head_kv_num,
                         head_size, ffn_hidden_size, ffn_inter_padding_size, vocab_size)

    # save model
    print("saving model...")
    file_name = os.path.join(dest_path, file_name)
    if post_rewrite_func:
        new_params = post_rewrite_func(new_params)
    print(new_params)
    torch.save(new_params, file_name)
    print(f"save finished, save path: {dest_path}")


class DefaultModelConfig:
    """Default values for the command-line options registered by ``fake_model``."""

    layer_num: int = 0
    head_num: int = 0
    head_kv_num: int = 0
    head_size: int = 0
    ffn_hidden_size: int = 0
    ffn_inter_padding_size: int = 0
    # Gate is on and w1/w3 are merged unless a caller overrides these.
    ffn_gate_active: bool = True
    ffn_w1_w3_independ: bool = False
    vocab_size: int = 0


def _str_to_bool(value) -> bool:
    """Parse a command-line boolean.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy. This converter accepts the usual
    spellings (case-insensitive) and rejects everything else.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ("true", "t", "yes", "y", "1", "on"):
        return True
    # The empty string stays false, as it was with ``bool("")``.
    if normalized in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def fake_model(model_type: str, default_values: DefaultModelConfig, save_config_func = None, post_rewrite_func = None):
    """Parse the command line and generate a fake model via ``fake_model_impl``.

    Args:
        model_type: model type forwarded to ``fake_model_impl``.
        default_values: supplies the default of every size/flag option.
        save_config_func: optional config writer forwarded unchanged.
        post_rewrite_func: optional parameter rewriter forwarded unchanged.

    Boolean options take an explicit value, e.g. ``--ffn_gate_active false``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--path', '-p', help='saved path', required=True)
    parser.add_argument('--layer', '-l', help='layer number', default=default_values.layer_num, type=int)
    parser.add_argument('--head', '-d', help='head number', default=default_values.head_num, type=int)
    parser.add_argument('--head_kv', '-k', help='kv head number', default=default_values.head_kv_num, type=int)
    parser.add_argument('--head_size', '-s', help='head size', default=default_values.head_size, type=int)
    parser.add_argument('--ffn_hidden_size', '-f', help='ffn hidden size', default=default_values.ffn_hidden_size, type=int)
    parser.add_argument('--ffn_inter_padding_size', '-e', help='ffn inter padding size', default=default_values.ffn_inter_padding_size, type=int)
    parser.add_argument('--ffn_gate_active', '-g', help='ffn gate active', default=default_values.ffn_gate_active, type=_str_to_bool)
    parser.add_argument('--ffn_w1_w3_independ', '-r', help='ffn w1 w3 weight independ', default=default_values.ffn_w1_w3_independ, type=_str_to_bool)
    parser.add_argument('--vocab', '-v', help='vocab size', default=default_values.vocab_size, type=int)
    parser.add_argument('--input', '-i', help='input model', default=None, type=str)
    args = parser.parse_args()

    fake_model_impl(model_type,
                    save_config_func,
                    post_rewrite_func,
                    args.path,
                    args.layer,
                    args.head,
                    args.head_kv,
                    args.head_size,
                    args.ffn_hidden_size,
                    args.ffn_inter_padding_size,
                    args.ffn_gate_active,
                    args.ffn_w1_w3_independ,
                    args.vocab,
                    args.input)