# maga_transformer/models/gpt_neox.py
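"""GPT-NeoX model definitions for maga_transformer.

Each model class builds a GptInitModelParameters object either from a
HuggingFace-style config.json found under the checkpoint path or from
hard-coded fallback defaults. The classes are registered with the model
factory at import time via register_model().
"""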
from typing import Any, Dict
from maga_transformer.config.gpt_init_model_parameters import GptInitModelParameters
from maga_transformer.utils.util import get_config_from_path
from maga_transformer.models.gpt_neox_weight import GPTNeoxWeight, GPTNeox13BWeight
from maga_transformer.models.base_model import BaseModel
from maga_transformer.model_factory_register import register_model


class GPTNeox(BaseModel):
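    """GPT-NeoX family model, registered for the HuggingFace GPTNeoXForCausalLM architecture."""
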
    @staticmethod
    def get_weight_cls():
        return GPTNeoxWeight

    @classmethod
    def _create_config(cls, ckpt_path: str):
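        """Build GptInitModelParameters for a GPT-NeoX checkpoint.

        Prefers the HuggingFace config.json found under ckpt_path; otherwise
        falls back to the hard-coded 40-layer defaults below.
        """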
        config_dict = get_config_from_path(ckpt_path)
        if config_dict:
            config = GPTNeox.from_huggingface(config_dict)
        else:
            config = GptInitModelParameters(
                head_num=40,
                head_num_kv=40,
                size_per_head=128,
                layer_num=40,
                max_seq_len=4096,
                vocab_size=250752,
                eos_token_id=2,
                inter_size=20480,
                inter_padding_size=20480)
            config.rotary_embedding_dim = 128
            config.rotary_embedding_style = 1
            config.has_pre_decoder_layernorm = False
            config.has_post_decoder_layernorm = True
            config.norm_type = 'layernorm'
            config.use_norm_input_residual = True
        config.ckpt_path = ckpt_path
        return config

    @staticmethod
    def from_huggingface(config_json: Dict[str, Any]):
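        """Translate a HuggingFace GPT-NeoX config dict into GptInitModelParameters.

        Keys consumed below: num_attention_heads, hidden_size, num_hidden_layers,
        vocab_size, layer_norm_eps, intermediate_size, bos/eos_token_id, rotary_pct,
        rope_scaling, max_position_embeddings and tie_word_embeddings. Example input
        (values are illustrative only, loosely based on EleutherAI's gpt-neox-20b):

            {"num_attention_heads": 64, "hidden_size": 6144, "num_hidden_layers": 44,
             "vocab_size": 50432, "rotary_pct": 0.25, "layer_norm_eps": 1e-5, ...}
        """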
        config = GptInitModelParameters(head_num=40,
                                        size_per_head=128,
                                        layer_num=40,
                                        max_seq_len=4096,
                                        vocab_size=250752)
        config.head_num = config_json['num_attention_heads']
        config.head_num_kv = config.head_num
        config.size_per_head = config_json['hidden_size'] // config_json['num_attention_heads']
        config.layer_num = config_json['num_hidden_layers']
        config.vocab_size = config_json['vocab_size']
        config.layernorm_eps = config_json['layer_norm_eps']
        config.inter_size = config_json['intermediate_size']
        config.inter_padding_size = config.inter_size
        config.special_tokens.bos_token_id = config_json['bos_token_id']
        config.special_tokens.eos_token_id = config_json['eos_token_id']
        config.rotary_embedding_dim = int(config.size_per_head * config_json.get('rotary_pct', 1.0))
        config.rotary_embedding_style = 1
        if config_json.get('rope_scaling', None):
            config.rotary_embedding_style = 3
            config.rotary_embedding_scale = config_json['rope_scaling']['factor']
            config.org_embedding_max_pos = config_json.get('max_position_embeddings', 2048)
        config.has_pre_decoder_layernorm = False
        config.has_post_decoder_layernorm = True
        config.norm_type = 'layernorm'
        config.use_norm_input_residual = True
        config.tie_word_embeddings = config_json.get('tie_word_embeddings', False)
        return config


class GPTNeox13B(GPTNeox):
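    """GPTNeox variant for a 13B checkpoint: rmsnorm, fixed eos_token_id=2, own weight layout."""
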
    @staticmethod
    def get_weight_cls():
        return GPTNeox13BWeight

    @classmethod
    def _create_config(cls, ckpt_path: str):
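        """Same flow as GPTNeox._create_config, but applies rmsnorm, the rotary defaults
        and eos_token_id=2 after either branch."""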
        config_dict = get_config_from_path(ckpt_path)
        if config_dict:
            config = GPTNeox13B.from_huggingface(config_dict)
        else:
            config = GptInitModelParameters(
                head_num=40,
                head_num_kv=40,
                size_per_head=128,
                layer_num=40,
                max_seq_len=4096,
                vocab_size=250752,
                inter_size=20480,
                inter_padding_size=20480)
        config.ckpt_path = ckpt_path
        config.rotary_embedding_dim = 128
        config.rotary_embedding_style = 1
        config.has_pre_decoder_layernorm = False
        config.has_post_decoder_layernorm = True
        config.norm_type = 'rmsnorm'
        config.special_tokens.eos_token_id = 2
        return config

    @staticmethod
    def from_huggingface(config_json: Dict[str, Any]):
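        """Like GPTNeox.from_huggingface, but only honours rope_scaling of type 'dynamic'
        and leaves the norm/rotary settings to _create_config."""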
        config = GptInitModelParameters(head_num=40,
                                        size_per_head=128,
                                        layer_num=40,
                                        max_seq_len=4096,
                                        vocab_size=250752)
        config.head_num = config_json['num_attention_heads']
        config.head_num_kv = config.head_num
        config.size_per_head = config_json['hidden_size'] // config_json['num_attention_heads']
        config.layer_num = config_json['num_hidden_layers']
        config.vocab_size = config_json['vocab_size']
        config.layernorm_eps = config_json['layer_norm_eps']
        config.inter_size = config_json['intermediate_size']
        config.inter_padding_size = config.inter_size
        config.special_tokens.bos_token_id = config_json['bos_token_id']
        config.special_tokens.eos_token_id = config_json['eos_token_id']
        if config_json.get('rope_scaling', None):
            if config_json['rope_scaling']['type'] == 'dynamic':
                config.rotary_embedding_style = 3
                config.rotary_embedding_scale = config_json['rope_scaling']['factor']
                config.org_embedding_max_pos = config_json.get('max_position_embeddings', 2048)
        return config


register_model('gpt_neox', GPTNeox, ["GPTNeoXForCausalLM"])
register_model('gpt_neox_13b', GPTNeox13B)
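
# Hedged usage sketch: building a config directly from this module. Real inference
# normally goes through the model factory that the register_model() calls above feed
# into; the checkpoint path below is a placeholder.
#
#   config = GPTNeox._create_config("/path/to/gpt-neox/checkpoint")
#   print(config.head_num, config.size_per_head, config.layer_num)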