maga_transformer/models/falcon.py:

import os
import json
import functools

from maga_transformer.config.gpt_init_model_parameters import GptInitModelParameters
from maga_transformer.utils.model_weight import W, CkptWeightInfo, identity, transpose, qkv_gather
from maga_transformer.model_loader.model_weight_info import ModelWeightInfo, ModelDeployWeightInfo
from maga_transformer.model_loader.weight_module import AtomicWeight
from maga_transformer.model_loader.ffn_weight import FfnAtomicWeight, FfnWeight
from maga_transformer.model_loader.attn_weight import AttnAtomicWeight
from maga_transformer.models.base_model import BaseModel
from maga_transformer.model_factory_register import register_model


class FalconWeightInfo(ModelDeployWeightInfo):
    def _process_meta(self, meta_dicts, weight_keys):
        # Falcon-40B checkpoints carry separate ln_attn/ln_mlp layernorms,
        # while Falcon-7B checkpoints use a single input_layernorm per block.
        if 'transformer.h.0.ln_attn.weight' in weight_keys:
            self.falcon_40b = True
        elif 'transformer.h.0.input_layernorm.weight' in weight_keys:
            self.falcon_40b = False

    def _get_weight_info(self):
        attn_config = self.attn_config
        ffn_config = self.ffn_config

        # Global (non-layer) weights: embeddings, lm head, final layernorm.
        weights = [
            AtomicWeight(W.embedding, [CkptWeightInfo('transformer.word_embeddings.weight', identity)], identity),
            AtomicWeight(W.lm_head, [CkptWeightInfo('lm_head.weight', identity)], identity),
            AtomicWeight(W.final_ln_gamma, [CkptWeightInfo('transformer.ln_f.weight', identity)], identity),
            AtomicWeight(W.final_ln_beta, [CkptWeightInfo('transformer.ln_f.bias', identity)], identity),
        ]

        # Per-layer weights shared by both Falcon variants.
        layer_weights = [
            AttnAtomicWeight(W.attn_o_w, [CkptWeightInfo('transformer.h.{i}.self_attention.dense.weight', identity)],
                             transpose, config=attn_config),
            FfnWeight(sub_weights=[
                FfnAtomicWeight(W.ffn_w3, [CkptWeightInfo('transformer.h.{i}.mlp.dense_h_to_4h.weight', identity)],
                                transpose, config=ffn_config),
                FfnAtomicWeight(W.ffn_w2, [CkptWeightInfo('transformer.h.{i}.mlp.dense_4h_to_h.weight', identity)],
                                transpose, config=ffn_config)],
                config=ffn_config)
        ]

        if self.falcon_40b:
            # Falcon-40B: the fused QKV weight is regathered per KV head
            # (grouped-query layout), and each block has two layernorms.
            layer_weights.extend([
                AttnAtomicWeight(W.attn_qkv_w,
                                 [CkptWeightInfo('transformer.h.{i}.self_attention.query_key_value.weight', identity)],
                                 functools.partial(qkv_gather, dim0=self._hidden_size,
                                                   head_num=self._head_num, head_num_kv=self._head_num_kv),
                                 config=attn_config),
                AtomicWeight(W.pre_ln_beta, [CkptWeightInfo('transformer.h.{i}.ln_mlp.bias', identity)], identity),
                AtomicWeight(W.pre_ln_gamma, [CkptWeightInfo('transformer.h.{i}.ln_mlp.weight', identity)], identity),
                AtomicWeight(W.pre_attn_ln_beta, [CkptWeightInfo('transformer.h.{i}.ln_attn.bias', identity)], identity),
                AtomicWeight(W.pre_attn_ln_gamma, [CkptWeightInfo('transformer.h.{i}.ln_attn.weight', identity)], identity),
            ])
        else:
            # Falcon-7B: the QKV weight is simply transposed and each block
            # has a single input_layernorm.
            layer_weights.extend([
                AttnAtomicWeight(W.attn_qkv_w,
                                 [CkptWeightInfo('transformer.h.{i}.self_attention.query_key_value.weight', identity)],
                                 transpose, config=attn_config),
                AtomicWeight(W.pre_ln_beta, [CkptWeightInfo('transformer.h.{i}.input_layernorm.bias', identity)], identity),
                AtomicWeight(W.pre_ln_gamma, [CkptWeightInfo('transformer.h.{i}.input_layernorm.weight', identity)], identity),
            ])

        return ModelWeightInfo(layer_weights=layer_weights, weights=weights)


class Falcon(BaseModel):
    @staticmethod
    def get_weight_cls():
        return FalconWeightInfo

    @classmethod
    def _create_config(cls, ckpt_path: str):
        # Build GptInitModelParameters from the HuggingFace-style config.json,
        # accepting both the old (n_head/n_layer/n_head_kv) and new
        # (num_attention_heads/num_hidden_layers/num_kv_heads) key names.
        config_path = os.path.join(ckpt_path, 'config.json')
        with open(config_path) as f:
            config_json = json.load(f)
        head_num = config_json.get('n_head', config_json.get('num_attention_heads'))
        config = GptInitModelParameters(
            head_num=head_num,
            head_num_kv=config_json.get('n_head_kv', config_json.get('num_kv_heads', 1)),
            size_per_head=config_json['hidden_size'] // head_num,
            inter_size=config_json['hidden_size'] * 4,
            layer_num=config_json.get('n_layer', config_json.get('num_hidden_layers')),
            max_seq_len=2048,
            vocab_size=config_json['vocab_size'],
            activation_type='gelu-none-approximate',
            has_post_decoder_layernorm=True,
            rotary_embedding_style=1,
            ckpt_path=ckpt_path)
        config.special_tokens.bos_token_id = config_json['bos_token_id']
        config.special_tokens.eos_token_id = config_json['eos_token_id']
        config.rotary_embedding_dim = config.size_per_head
        config.tie_word_embeddings = config_json.get('tie_word_embeddings', False)
        return config


register_model('falcon', Falcon, ["FalconForCausalLM"])
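For reference, the sketch below shows one way to exercise Falcon._create_config against a minimal, Falcon-7B-style config.json. It is a minimal sketch, assuming the maga_transformer package is importable; the field values are illustrative and not taken from a real checkpoint.

# Minimal sketch (assumption: maga_transformer is installed and importable).
# Writes an illustrative config.json to a temp directory and builds a
# GptInitModelParameters from it via Falcon._create_config.
import json
import os
import tempfile

from maga_transformer.models.falcon import Falcon

fake_config = {
    "num_attention_heads": 71,   # illustrative values, not from a real checkpoint
    "num_hidden_layers": 32,
    "hidden_size": 4544,
    "vocab_size": 65024,
    "bos_token_id": 11,
    "eos_token_id": 11,
}

with tempfile.TemporaryDirectory() as ckpt_path:
    with open(os.path.join(ckpt_path, "config.json"), "w") as f:
        json.dump(fake_config, f)
    config = Falcon._create_config(ckpt_path)
    # size_per_head = hidden_size // head_num = 4544 // 71 = 64,
    # and rotary_embedding_dim is set equal to size_per_head.
    print(config.size_per_head, config.rotary_embedding_dim)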