maga_transformer/tools/fake_qwen.py (73 lines of code) (raw):
import os
import json
from maga_transformer.tools.fake_model_base import *
def save_config_func(model_type,
                     dest_path: str,
                     layer: int,
                     head: int,
                     head_kv: int,
                     head_size: int,
                     ffn_hidden_size: int,
                     ffn_inter_padding_size: int,
                     vocab_size: int) -> None:
    """Write a QWen-style ``config.json`` into ``dest_path``.

    Most fields are fixed constants; only the model dimensions below are
    taken from the arguments.

    Args:
        model_type: Accepted for signature compatibility; not used here.
        dest_path: Existing directory that receives ``config.json``.
        layer: Value written as ``n_layer``.
        head: Value written as ``n_head``; also used to derive ``n_embd``.
        head_kv: Accepted for signature compatibility; not used here.
        head_size: Value written as ``kv_channels``; ``n_embd`` is
            ``head * head_size``.
        ffn_hidden_size: Value written as ``ffn_hidden_size`` and, currently,
            also as ``ffn_inter_padding_size``.
        ffn_inter_padding_size: Accepted but currently not written to the
            config (see the review note in the body).
        vocab_size: Value written as ``vocab_size``.

    Raises:
        OSError: If ``config.json`` cannot be opened for writing, e.g.
            because ``dest_path`` does not exist.
    """
    config = {
        "activation": "swiglu",
        "apply_residual_connection_post_layernorm": False,
        "architectures": [
            "QWenLMHeadModel"
        ],
        "attn_pdrop": 0.0,
        "auto_map": {
            "AutoConfig": "configuration_qwen.QWenConfig",
            "AutoModelForCausalLM": "modeling_qwen.QWenLMHeadModel"
        },
        "bf16": True,
        "bias_dropout_fusion": True,
        "bos_token_id": 151643,
        "embd_pdrop": 0.0,
        "eos_token_id": 151643,
        "ffn_hidden_size": ffn_hidden_size,
        # NOTE(review): this key is filled from ffn_hidden_size, not from the
        # ffn_inter_padding_size argument. Kept as-is to preserve behavior;
        # confirm against the config loader whether that is intended.
        "ffn_inter_padding_size": ffn_hidden_size,
        "fp16": False,
        "fp32": False,
        "initializer_range": 0.02,
        "kv_channels": head_size,
        "layer_norm_epsilon": 1e-06,
        "n_embd": head * head_size,
        "n_head": head,
        "n_layer": layer,
        "n_positions": 6144,
        "no_bias": True,
        "padded_vocab_size": 151936,
        "params_dtype": "torch.bfloat16",
        "pos_emb": "rotary",
        "resid_pdrop": 0.1,
        "rotary_emb_base": 10000,
        "rotary_pct": 1.0,
        "scale_attn_weights": True,
        "seq_length": 2048,
        "tie_word_embeddings": False,
        "tokenizer_type": "QWenTokenizer",
        "torch_dtype": "bfloat16",
        "transformers_version": "4.39.3",
        "use_cache": True,
        "use_dynamic_ntk": True,
        "use_flash_attn": True,
        "use_logn_attn": True,
        "vocab_size": vocab_size
    }
    # Use a context manager so the file is flushed and closed deterministically
    # instead of relying on garbage collection of an anonymous handle.
    config_path = os.path.join(dest_path, 'config.json')
    with open(config_path, 'w') as config_file:
        json.dump(config, config_file, indent=2)
def fake_qwen():
    """Configure a small QWen model and pass it to ``fake_model``.

    Starts from ``DefaultModelConfig()``, overrides the layer/head/FFN/vocab
    fields below, and calls ``fake_model`` with the ``"qwen_7b"`` model type
    and ``save_config_func`` as the config writer.
    """
    cfg = DefaultModelConfig()
    cfg.layer_num = 2
    cfg.head_num = 2
    cfg.head_kv_num = 2
    cfg.head_size = 128

    # Both FFN sizes share the same value: 4 * head_size * head_num.
    ffn_width = 4 * cfg.head_size * cfg.head_num
    cfg.ffn_hidden_size = ffn_width
    cfg.ffn_inter_padding_size = ffn_width

    cfg.ffn_gate_active = True
    cfg.ffn_w1_w3_independ = True
    cfg.vocab_size = 151936

    fake_model("qwen_7b", cfg, save_config_func)
# Script entry point: generate the fake QWen model only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    fake_qwen()