# maga_transformer/models/llama.py
import os
import logging
import json
import math
from typing import Any, Dict, List
from transformers.models.llama.tokenization_llama import LlamaTokenizer as LlamaTokenizerOrigin
from maga_transformer.distribute.worker_info import ParallelInfo, g_parallel_info
from maga_transformer.config.gpt_init_model_parameters import GptInitModelParameters
from maga_transformer.models.llama_weight import LlamaWeightInfo, GemmaWeightInfo
from maga_transformer.models.base_model import BaseModel
from maga_transformer.model_factory_register import register_model


def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256):
    # Round int(ffn_dim_multiplier * 8n / 3) up to the next multiple of `multiple_of`.
    return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of)
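# Worked example: for LLaMA-7B, n = 4096 gives int(8 * 4096 / 3) = 10922,
# which rounds up to 43 * 256 = 11008, the published LLaMA-7B intermediate size.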


class LlamaTokenizer(LlamaTokenizerOrigin):
    def convert_tokens_to_string(self, tokens: List[str]):
        # Short-circuit the empty case before delegating to the upstream tokenizer.
        if len(tokens) == 0:
            return ""
        return super().convert_tokens_to_string(tokens)


class Llama(BaseModel):
    @staticmethod
    def get_mscale(scale: float):
        # YaRN attention-scaling factor: 1.0 for scale <= 1, else 0.1 * ln(scale) + 1.0.
        if scale <= 1:
            return 1.0
        return 0.1 * math.log(scale) + 1.0
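    # E.g. a YaRN scale factor of 4.0 yields 0.1 * ln(4) + 1.0 ~= 1.139.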

    @staticmethod
    def get_weight_cls():
        return LlamaWeightInfo

    @classmethod
    def _create_config(cls, ckpt_path: str):
        config = GptInitModelParameters(
            head_num=0,
            size_per_head=0,
            layer_num=0,
            max_seq_len=0,
            vocab_size=0,
            ckpt_path=ckpt_path,
            activation_type='SiGLU',
            norm_type='rmsnorm',
            rotary_embedding_dim=128,
            rotary_embedding_style=1,
            has_post_decoder_layernorm=True,
        )
        # huggingface checkpoint
        config_path = os.path.join(ckpt_path, 'config.json')
        # llama-int8 checkpoint
        param_path = os.path.join(ckpt_path, 'params.json')
        if os.path.exists(config_path):
            with open(config_path) as reader:
                content = reader.read()
                content = content.replace("LlamaForCausalLM", "LLaMAForCausalLM")
                config_json = json.loads(content)
                Llama.from_huggingface(config, config_json)
        elif os.path.exists(param_path):
            logging.info("llama config.json not found, falling back to params.json")
            with open(param_path) as reader:
                param_json = json.loads(reader.read())
                Llama.from_params(config, param_json)
        else:
            raise Exception("llama parameters from unknown source: neither config.json nor params.json found")
        return config

    @staticmethod
    def from_huggingface(config, config_json: Dict[str, Any]):
        config.head_num = config_json['num_attention_heads']
        config.head_num_kv = config_json.get('num_key_value_heads', config.head_num)
        config.hidden_size = config_json['hidden_size']
        config.size_per_head = config_json['hidden_size'] // config_json['num_attention_heads']
        config.size_per_head = config_json.get('head_dim', config.size_per_head)
        config.layer_num = config_json['num_hidden_layers']
        config.max_seq_len = config_json.get('max_sequence_length', 2048)
        config.vocab_size = config_json['vocab_size']
        config.layernorm_eps = config_json.get('rms_norm_eps', config_json.get('layer_norm_eps', 1e-05))
        config.inter_size = config_json['intermediate_size']
        config.rotary_embedding_base = config_json.get('rope_theta', 10000)
        config.rotary_embedding_dim = config.size_per_head
        config.tie_word_embeddings = config_json.get('tie_word_embeddings', False)
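        # rope_scaling dispatch, mapping the HF "type"/"rope_type" field to the
        # internal rotary_embedding_style codes used below: linear keeps the
        # default style and only sets a scale, dynamic (commonly the NTK-aware
        # variant) -> 3, yarn -> 5, llama3 -> 6.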
        rope_scaling = config_json.get('rope_scaling')
        if rope_scaling is not None:
            rope_type = rope_scaling.get('type', rope_scaling.get('rope_type'))
            if rope_type == 'linear':
                config.rotary_embedding_scale = rope_scaling['factor']
                config.org_embedding_max_pos = config_json.get('max_position_embeddings', 2048)
            elif rope_type == 'dynamic':
                config.rotary_embedding_style = 3
            elif rope_type == 'yarn':
                config.rotary_embedding_style = 5
                config.rotary_embedding_scale = rope_scaling['factor']
                config.rotary_factor1 = rope_scaling.get('beta_slow', 1)
                config.rotary_factor2 = rope_scaling.get('beta_fast', 32)
                config.org_embedding_max_pos = rope_scaling['original_max_position_embeddings']
                config.rotary_embedding_mscale = Llama.get_mscale(config.rotary_embedding_scale)
            elif rope_type == 'llama3':
                config.rotary_embedding_style = 6
                config.rotary_embedding_scale = rope_scaling['factor']
                config.rotary_factor1 = rope_scaling['low_freq_factor']
                config.rotary_factor2 = rope_scaling['high_freq_factor']
                config.org_embedding_max_pos = rope_scaling['original_max_position_embeddings']
            else:
                raise Exception(f"unsupported rope_scaling {rope_scaling}")
        # config.activation_type = config_json.get("hidden_act", config.activation_type)
        config.special_tokens.bos_token_id = config_json['bos_token_id']
        eos_token_id = config_json['eos_token_id']
        # the openai endpoint gets the correct eos token id list from the tokenizer
        if isinstance(eos_token_id, list):
            config.special_tokens.eos_token_id = eos_token_id[0]
        else:
            config.special_tokens.eos_token_id = eos_token_id
        config.use_logn_attn = config_json.get("use_logn_attn", False)
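
    # For reference, from_huggingface indexes these config.json keys directly
    # and will raise if they are absent: num_attention_heads, hidden_size,
    # num_hidden_layers, vocab_size, intermediate_size, bos_token_id,
    # eos_token_id; all other fields fall back to defaults via .get().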

    @staticmethod
    def from_params(config: GptInitModelParameters, params_json: Dict[str, Any]):
        config.head_num = params_json['n_heads']
        config.head_num_kv = params_json.get('n_kv_heads', config.head_num)
        config.size_per_head = params_json['dim'] // params_json['n_heads']
        config.layer_num = params_json['n_layers']
        config.max_seq_len = 2048
        config.vocab_size = 32000
        config.layernorm_eps = params_json['norm_eps']
        config.inter_size = compute_intermediate_size(
            params_json['dim'],
            params_json.get("ffn_dim_multiplier", 1),
            params_json['multiple_of'])
        config.special_tokens.bos_token_id = 1
        config.special_tokens.eos_token_id = 2
        config.rotary_embedding_dim = config.size_per_head
        config.tie_word_embeddings = params_json.get('tie_word_embeddings', False)
        return config
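
    # A params.json from the original LLaMA-7B release looks roughly like:
    #   {"dim": 4096, "multiple_of": 256, "n_heads": 32, "n_layers": 32,
    #    "norm_eps": 1e-06, "vocab_size": -1}
    # (vocab_size is ignored here and hardcoded to 32000 above).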

    @classmethod
    def get_tokenizer(cls, config: GptInitModelParameters):
        tokenizer_config_file = os.path.join(config.tokenizer_path, "tokenizer_config.json")
        if os.path.exists(tokenizer_config_file):
            logging.info("loading tokenizer via the base class")
            return super().get_tokenizer(config)
        else:
            logging.info("loading LlamaTokenizer")
            return LlamaTokenizer.from_pretrained(config.tokenizer_path)


class Baichuan(Llama):
    @classmethod
    def _create_config(cls, ckpt_path: str):
        config = Llama._create_config(ckpt_path)
        if config.layer_num == 40:  # Baichuan-13B uses ALiBi instead of rotary embeddings
            config.rotary_embedding_style = 0
            config.rotary_embedding_dim = 0
            config.use_attention_linear_bias = True
        config.special_tokens.bos_token_id = -1
        config.special_tokens.user.token_ids = [195]
        config.special_tokens.user.eos_token_ids = []
        config.special_tokens.assistant.token_ids = [196]
        config.special_tokens.assistant.eos_token_ids = [config.special_tokens.eos_token_id]
        return config


class Baichuan2(Baichuan):
    @classmethod
    def _create_config(cls, ckpt_path: str):
        config = Baichuan._create_config(ckpt_path)
        # Baichuan2's NormHead: the lm_head weight is normalized before the
        # output projection.
        config.normalize_lm_head_weight = True
        return config


class Gemma(Llama):
    def __init__(self, config: GptInitModelParameters):
        if os.environ.get("ENABLE_OPENSOURCE_FMHA", None) != "OFF":
            logging.warning("open-source FMHA does not support head_dim 256, so it is disabled for the Gemma model")
            os.environ["ENABLE_OPENSOURCE_FMHA"] = "OFF"
        super().__init__(config)

    @staticmethod
    def get_weight_cls():
        return GemmaWeightInfo

    @classmethod
    def _create_config(cls, ckpt_path: str):
        config = Llama._create_config(ckpt_path)
        config.has_post_decoder_layernorm = True
        # Gemma scales input embeddings by sqrt(hidden_size).
        config.input_embedding_scalar = (config.hidden_size ** 0.5)
        config.rotary_embedding_dim = config.size_per_head
        config.activation_type = 'gated-gelu'
        return config


class Cohere(Llama):
    @classmethod
    def _create_config(cls, ckpt_path: str):
        config = Llama._create_config(ckpt_path)
        config.rotary_embedding_style = 0
        config.norm_type = 'layernorm'
        config.qk_norm = True
        return config
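

# Each registration maps a model-factory name (and, where given, the matching
# HuggingFace architecture strings) to the class that should load it.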
register_model('internlm', Llama, ["InternLMForCausalLM"])
register_model('internlm2', Llama, ["InternLM2ForCausalLM"])
register_model('llama', Llama, ["LlamaForCausalLM", "YiForCausalLM"])
register_model('xverse', Llama, ["XverseForCausalLM"])
register_model('aquila', Llama, ["AquilaModel"])
register_model('mistral', Llama, ["MistralForCausalLM"])
register_model('baichuan', Baichuan, ["BaichuanForCausalLM"])
register_model('baichuan2', Baichuan2)
register_model('gemma', Gemma, ["GemmaForCausalLM"])
register_model('cohere', Cohere, ["CohereForCausalLM"])
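
# Minimal usage sketch (the checkpoint path is a placeholder, not a real file):
#   config = Llama._create_config("/path/to/llama/ckpt")
#   tokenizer = Llama.get_tokenizer(config)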