# maga_transformer/models/starcoder.py

from typing import Any, Dict, List

from maga_transformer.utils.util import get_config_from_path
from maga_transformer.config.gpt_init_model_parameters import GptInitModelParameters
from maga_transformer.utils.model_weight import W, \
    CkptWeightInfo, identity, transpose, WeightStyle
from maga_transformer.models.base_model import BaseModel
from transformers.models.gpt2.tokenization_gpt2_fast import GPT2TokenizerFast
from maga_transformer.model_factory_register import register_model
from maga_transformer.model_loader.weight_module import WeightModule, AtomicWeight
from maga_transformer.model_loader.ffn_weight import FfnAtomicWeight, FfnConfig, FfnWeight
from maga_transformer.model_loader.attn_weight import AttnAtomicWeight
from maga_transformer.model_loader.model_weight_info import ModelWeightInfo, ModelDeployWeightInfo


class StarcoderWeightInfo(ModelDeployWeightInfo):
    def _process_meta(self, meta_dicts, weight_keys):
        # Detect the FP8 checkpoint layout from the tensor names present in the metadata:
        # HF-style names mean a TransformerEngine export, TRT-style names mean a TRT engine.
        for meta_dict in meta_dicts:
            if self._quant_algo.isFp8() and 'transformer.h.0.attn.c_proj.weight' in meta_dict:
                self.weight_style = WeightStyle.TRANSFORMER_ENGINE
            elif self._quant_algo.isFp8() and 'transformer.layers.0.attention.dense.weight' in meta_dict:
                self.weight_style = WeightStyle.TRT_ENGINE

    def _get_weight_info(self):
        # Embedding tensor names differ between HF-style and TRT-engine checkpoints.
        if self.weight_style != WeightStyle.TRT_ENGINE:
            embedding_tensor_name = 'transformer.wte.weight'
            positional_tensor_name = 'transformer.wpe.weight'
        else:
            embedding_tensor_name = 'transformer.vocab_embedding.weight'
            positional_tensor_name = 'transformer.position_embedding.weight'

        weights = [
            AtomicWeight(W.embedding, [CkptWeightInfo(embedding_tensor_name, identity)], identity),
            AtomicWeight(W.lm_head, [CkptWeightInfo('lm_head.weight', identity)], identity),
            AtomicWeight(W.positional_embedding, [CkptWeightInfo(positional_tensor_name, identity)], identity),
            AtomicWeight(W.final_ln_gamma, [CkptWeightInfo('transformer.ln_f.weight', identity)], identity),
            AtomicWeight(W.final_ln_beta, [CkptWeightInfo('transformer.ln_f.bias', identity)], identity),
        ]
        # TODO(luoli.hn) lm_head gemm uses fp16, maybe can use fp8 gemm

        layer_weights: List[List[WeightModule]] = []
        for layer in range(self._num_layers):
            w = self._get_hf_layer_weight_info(layer)
            layer_weights.append(w)
        return ModelWeightInfo(layer_weights=layer_weights, weights=weights)

    def _get_hf_layer_weight_info(self, layer_id: int) -> List[WeightModule]:
        attn_config = self.attn_config
        ffn_config = self.ffn_config
        ffn_w2_config = FfnConfig(
            is_gated_activation=self._is_gated_activation,
            inter_padding_size=self._inter_padding_size,
            is_moe=False,
            need_ffn_act_scale=self.need_ffn_act_scale
        )
        layer_weights = [
            AtomicWeight(W.pre_ln_beta,
                         [CkptWeightInfo('transformer.h.{i}.ln_1.bias', identity)], identity),
            AtomicWeight(W.pre_ln_gamma,
                         [CkptWeightInfo('transformer.h.{i}.ln_1.weight', identity)], identity),
            AttnAtomicWeight(W.attn_qkv_w,
                             [CkptWeightInfo('transformer.h.{i}.attn.c_attn.weight', identity)],
                             transpose, config=attn_config),
            AttnAtomicWeight(W.attn_qkv_b,
                             [CkptWeightInfo('transformer.h.{i}.attn.c_attn.bias', identity)],
                             identity, config=attn_config),
            AttnAtomicWeight(W.attn_o_w,
                             [CkptWeightInfo('transformer.h.{i}.attn.c_proj.weight', identity)],
                             transpose, config=attn_config),
            AttnAtomicWeight(W.attn_o_b,
                             [CkptWeightInfo('transformer.h.{i}.attn.c_proj.bias', identity)],
                             identity, config=attn_config),
            FfnWeight(sub_weights=[
                FfnAtomicWeight(W.ffn_w3,
                                [CkptWeightInfo('transformer.h.{i}.mlp.c_fc.weight', identity)],
                                transpose, config=ffn_config),
                FfnAtomicWeight(W.ffn_b3,
                                [CkptWeightInfo('transformer.h.{i}.mlp.c_fc.bias', identity)],
                                identity, config=ffn_config),
                FfnAtomicWeight(W.ffn_w2,
                                [CkptWeightInfo('transformer.h.{i}.mlp.c_proj.weight', identity)],
                                transpose, config=ffn_w2_config),
                FfnAtomicWeight(W.ffn_b2,
                                [CkptWeightInfo('transformer.h.{i}.mlp.c_proj.bias', identity)],
                                identity, config=ffn_w2_config)
            ], config=ffn_config),
            AtomicWeight(W.post_ln_beta,
                         [CkptWeightInfo('transformer.h.{i}.ln_2.bias', identity)], identity),
            AtomicWeight(W.post_ln_gamma,
                         [CkptWeightInfo('transformer.h.{i}.ln_2.weight', identity)], identity),
        ]
        return layer_weights


StarcoderTokenizer = GPT2TokenizerFast


class StarCoder(BaseModel):
    @classmethod
    def get_tokenizer(cls, config: GptInitModelParameters):
        return StarcoderTokenizer.from_pretrained(config.tokenizer_path)

    @staticmethod
    def get_weight_cls():
        return StarcoderWeightInfo

    @staticmethod
    def from_huggingface(ckpt_path: str, config_json: Dict[str, Any]):
        model_type = config_json['model_type']
        config = GptInitModelParameters(
            head_num=config_json['n_head'],
            size_per_head=config_json['n_embd'] // config_json['n_head'],
            layer_num=config_json['n_layer'],
            max_seq_len=config_json.get('n_positions', 8192),
            vocab_size=config_json['vocab_size'],
        )
        if model_type != 'gpt_bigcode':
            raise BaseException(f'model type is not starcoder: {model_type}')
        config.head_num_kv = 1  # gpt_bigcode uses multi-query attention (a single KV head)
        config.layernorm_eps = config_json['layer_norm_epsilon']
        config.inter_size = config_json['n_inner']
        config.special_tokens.eos_token_id = config_json['eos_token_id']
        config.special_tokens.bos_token_id = config_json['bos_token_id']
        # config.activation_type = config_json['activation_function']
        config.has_positional_encoding = True
        config.has_post_decoder_layernorm = True
        config.tie_word_embeddings = config_json.get('tie_word_embeddings', False)
        return config

    @classmethod
    def _create_config(cls, ckpt_path: str):
        config_dict = get_config_from_path(ckpt_path)
        if config_dict:
            config = StarCoder.from_huggingface(ckpt_path, config_dict)
        else:
            # Fall back to StarCoder-15B defaults when no config.json is found.
            config = GptInitModelParameters(
                head_num=48,
                head_num_kv=1,
                size_per_head=128,
                inter_size=4 * 6144,
                layer_num=40,
                max_seq_len=8192,
                vocab_size=49152,
                has_positional_encoding=True,
                has_post_decoder_layernorm=True)
            config.special_tokens.bos_token_id = 0
            config.special_tokens.eos_token_id = 0
        return config

    @classmethod
    def _load_quant_config(cls, ckpt_path: str, config: GptInitModelParameters):
        super(StarCoder, cls)._load_quant_config(ckpt_path, config)
        # FFN activation scales are only needed for AWQ-quantized checkpoints.
        config.need_ffn_act_scale = config.quant_algo.isAwq()


register_model('gpt_bigcode', StarCoder, ['GPTBigCodeForCausalLM'])
register_model('wizardcoder', StarCoder)
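

# --- Illustrative usage sketch (not part of the original module; added for clarity). ---
# A minimal sketch that only exercises names defined in this file: _create_config()
# builds a GptInitModelParameters (falling back to the hard-coded StarCoder-15B defaults
# when no config.json is found next to the checkpoint), and get_weight_cls() returns the
# StarcoderWeightInfo used to map checkpoint tensors onto W.* weight slots. The checkpoint
# path below is hypothetical.
if __name__ == '__main__':
    ckpt_path = '/path/to/starcoder_ckpt'  # hypothetical path; replace with a real checkpoint dir
    config = StarCoder._create_config(ckpt_path)
    print(config.head_num, config.size_per_head, config.layer_num)
    print(StarCoder.get_weight_cls().__name__)  # -> 'StarcoderWeightInfo'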