# maga_transformer/model_factory.py
import os
import sys
import json
import logging
from typing import Any, Dict, Type, Union, Optional

import torch

CUR_PATH = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(CUR_PATH, ".."))

from maga_transformer.models.base_model import BaseModel, ModelConfig
from maga_transformer.models.propose_model.propose_model import ProposeModel
from maga_transformer.async_decoder_engine.async_model import AsyncModel
from maga_transformer.tools.api.hf_model_helper import get_model_info_from_hf
from maga_transformer.config.gpt_init_model_parameters import GptInitModelParameters, ConfigMode
from maga_transformer.utils.dump_config_utils import dump_model_to_table
from maga_transformer.utils.fuser import fetch_remote_file_to_local
from maga_transformer.utils.util import check_with_info
from maga_transformer.models.multimodal.multimodal_mixin import MultiModalMixin
from maga_transformer.distribute.worker_info import g_parallel_info
from maga_transformer.utils.weight_type import WEIGHT_TYPE, get_weight_type_from_env, get_propose_weight_type_from_env
from maga_transformer.model_factory_register import _model_factory

class ModelFactory:
    @staticmethod
    def get_config_json(ckpt_path: str) -> Dict[str, Any]:
        check_with_info(os.path.isdir(ckpt_path), f"{ckpt_path} check os.path.isdir failed")
        config_json_path = os.path.join(ckpt_path, 'config.json')
        check_with_info(os.path.isfile(config_json_path), f"{config_json_path} check os.path.isfile failed")
        with open(config_json_path, "r", encoding="utf-8") as reader:
            return json.load(reader)
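
    # Hedged usage sketch: reading a HuggingFace-style config.json out of a
    # local checkpoint directory. The path below is hypothetical.
    #
    #   config = ModelFactory.get_config_json("/path/to/ckpt")
    #   hidden_size = config.get("hidden_size")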

    @staticmethod
    def get_weight_cls(model_type: str):
        model_cls = _model_factory[model_type]
        return model_cls.get_weight_cls()

    @staticmethod
    def get_model_cls(model_type: str):
        return _model_factory[model_type]
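
    # Hedged sketch: both helpers are thin lookups into the global registry
    # populated by maga_transformer.model_factory_register. "llama" is a
    # hypothetical registered model_type; a missing key raises KeyError.
    #
    #   model_cls = ModelFactory.get_model_cls("llama")
    #   weight_cls = ModelFactory.get_weight_cls("llama")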

    @staticmethod
    def create_gpt_init_config(model_config: ModelConfig):
        if model_config.model_type not in _model_factory:
            raise Exception(f"model type {model_config.model_type} not registered!")
        model_cls = _model_factory[model_config.model_type]
        config: GptInitModelParameters = model_cls.create_config(
            model_config, parallel_info=g_parallel_info, config_mode=ConfigMode.SimpleMode)
        config.model_name = model_cls.__name__
        if issubclass(model_cls, MultiModalMixin):
            config.is_multimodal = True
        return model_cls, config

    @staticmethod
    def _create_model(model_config: ModelConfig):
        if model_config.model_type not in _model_factory:
            raise Exception(f"model type {model_config.model_type} not registered!")
        model_cls = _model_factory[model_config.model_type]
        config: GptInitModelParameters = model_cls.create_config(model_config)
        model = model_cls.from_config(config)
        dump_model_to_table(ModelFactory.model_config_json(model_cls, model_config, config))
        return model

    @staticmethod
    def _create_sp_model(score_model_gpt_config: GptInitModelParameters, model_config: ModelConfig):
        model = None
        if model_config.sp_type in ("vanilla", "mtp"):
            if model_config.model_type not in _model_factory:
                raise Exception(f"model type {model_config.model_type} not registered!")
            if model_config.model_type in ("deepseek-v3-mtp", "mixtbstars-mtp"):
                logging.warning(f"propose model type is {model_config.model_type}, forcing sp_type to mtp")
                model_config.sp_type = "mtp"
            model_cls = _model_factory[model_config.model_type]
            # the propose model's max_seq_len must equal the score model's max_seq_len
            model_config.max_seq_len = score_model_gpt_config.max_seq_len
            config: GptInitModelParameters = model_cls.create_config(model_config)
            gpt_model = model_cls.from_config(config)
            dump_model_to_table(ModelFactory.model_config_json(model_cls, model_config, config))
            model = ProposeModel(model_config.sp_type, model_config.gen_num_per_circle, gpt_model)
        elif model_config.sp_type == "deterministic":
            model = ProposeModel(model_config.sp_type, model_config.gen_num_per_circle)
        elif model_config.sp_type == "eagle":
            raise NotImplementedError
        return model

    # TODO: remove model_config, get all info from gpt_config
    @staticmethod
    def model_config_json(model_cls: Type[Any], model_config: ModelConfig, config: GptInitModelParameters) -> Dict[str, Any]:
        config_json = {
            "model_type": model_cls.__name__,
            "act_type": str(model_config.act_type),
            "max_seq_len": config.max_seq_len,
            "use_sparse_head": config.is_sparse_head,
            "use_multi_task_prompt": config.multi_task_prompt,
            "lora_infos": config.lora_infos,
        }
        return config_json

    @staticmethod
    def from_model_config(model_config: ModelConfig, propose_model_config: Optional[ModelConfig] = None) -> AsyncModel:
        model = ModelFactory._create_model(model_config)
        if model_config.model_type == 'fake_model' or model.config.vit_separation == 1:
            return model
        propose_model = None if propose_model_config is None else ModelFactory._create_sp_model(model.config, propose_model_config)
        if propose_model:
            logging.info("set enable_speculative_decoding")
            model.config.enable_speculative_decoding = True
        model = AsyncModel(model, propose_model)
        if propose_model:
            logging.info("create propose model done")
        logging.info("create rpc model done")
        return model
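
    # Hedged usage sketch: building a model (typically wrapped in AsyncModel)
    # straight from a ModelConfig. Field values here are hypothetical; unset
    # fields keep their defaults.
    #
    #   cfg = ModelConfig(model_type="llama", ckpt_path="/path/to/ckpt",
    #                     tokenizer_path="/path/to/ckpt")
    #   model = ModelFactory.from_model_config(cfg)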

    @staticmethod
    def from_huggingface(model_path_or_name: str, revision: Optional[str] = None, model_config: ModelConfig = ModelConfig()):
        model_path, model_type = get_model_info_from_hf(model_path_or_name, revision)
        new_model_config = model_config._replace(model_type=model_type, ckpt_path=model_path, tokenizer_path=model_path)
        return ModelFactory.from_model_config(new_model_config)
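
    # Hedged usage sketch: loading directly from a HuggingFace repo id.
    # get_model_info_from_hf resolves the local path and the registered
    # model_type; the repo id below is a hypothetical example.
    #
    #   model = ModelFactory.from_huggingface("Qwen/Qwen2-7B-Instruct")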

    @staticmethod
    def create_normal_model_config():
        model_type = os.environ["MODEL_TYPE"]
        ckpt_path = os.environ["CHECKPOINT_PATH"]
        tokenizer_path = os.environ.get("TOKENIZER_PATH", ckpt_path)
        lora_infos = os.environ.get("LORA_INFO", "{}")
        max_seq_len = int(os.environ.get("MAX_SEQ_LEN", "0"))
        seq_size_per_block = int(os.environ.get("SEQ_SIZE_PER_BLOCK", "8"))

        tokenizer_path = fetch_remote_file_to_local(tokenizer_path)
        ckpt_path = fetch_remote_file_to_local(ckpt_path)

        extra_data_path = os.environ.get('EXTRA_DATA_PATH', "")
        if extra_data_path:
            extra_data_path = fetch_remote_file_to_local(extra_data_path)
            os.environ['LOCAL_EXTRA_DATA_PATH'] = extra_data_path

        ptuning_path = None
        if 'PTUNING_PATH' in os.environ:
            ptuning_path = fetch_remote_file_to_local(os.environ['PTUNING_PATH'])

        lora_infos = json.loads(lora_infos)
        for lora_name, lora_path in lora_infos.items():
            lora_infos[lora_name] = fetch_remote_file_to_local(lora_path)

        logging.info(f"load model from tokenizer_path: {tokenizer_path}, ckpt_path: {ckpt_path}, "
                     f"lora_infos: {lora_infos}, ptuning_path: {ptuning_path}")

        weight_type: WEIGHT_TYPE = get_weight_type_from_env(os.environ)
        act_type = weight_type if weight_type in [WEIGHT_TYPE.FP16, WEIGHT_TYPE.BF16] else WEIGHT_TYPE.FP16
        # TODO(xinfei.sxf) fix this
        act_type_env = os.environ.get("ACT_TYPE")
        if act_type_env:
            act_type = WEIGHT_TYPE.from_str(act_type_env)

        model_config = ModelConfig(model_type=model_type,
                                   ckpt_path=ckpt_path,
                                   tokenizer_path=tokenizer_path,
                                   weight_type=weight_type,
                                   act_type=act_type,
                                   max_seq_len=max_seq_len,
                                   seq_size_per_block=seq_size_per_block,
                                   lora_infos=lora_infos,
                                   ptuning_path=ptuning_path)
        return model_config
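
    # Hedged sketch of the environment this method reads. MODEL_TYPE and
    # CHECKPOINT_PATH are required; the rest fall back to defaults. Values
    # below are hypothetical.
    #
    #   os.environ["MODEL_TYPE"] = "llama"
    #   os.environ["CHECKPOINT_PATH"] = "/path/to/ckpt"
    #   os.environ["MAX_SEQ_LEN"] = "4096"
    #   os.environ["LORA_INFO"] = '{"my_lora": "/path/to/lora"}'
    #   cfg = ModelFactory.create_normal_model_config()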

    @staticmethod
    def create_propose_model_config(normal_model_config: ModelConfig):
        propose_model_config = None
        sp_type = os.environ.get("SP_TYPE", None)
        if sp_type in ("vanilla", "mtp"):
            logging.info(f"use {sp_type} speculative model")
            propose_model_type = os.environ.get("SP_MODEL_TYPE", None)
            gen_num_per_circle = int(os.environ.get('GEN_NUM_PER_CIRCLE', '5'))
            propose_ckpt_path = fetch_remote_file_to_local(os.environ['SP_CHECKPOINT_PATH'])
            logging.info(f"load propose model from ckpt_path: {propose_ckpt_path}")
            propose_weight_type = get_propose_weight_type_from_env(os.environ)
            propose_act_type = propose_weight_type if propose_weight_type in [WEIGHT_TYPE.FP16, WEIGHT_TYPE.BF16] else WEIGHT_TYPE.FP16
            sp_act_type_env = os.environ.get("SP_ACT_TYPE")
            if sp_act_type_env:
                propose_act_type = WEIGHT_TYPE.from_str(sp_act_type_env)
            propose_model_config = ModelConfig(model_type=propose_model_type,
                                               ckpt_path=propose_ckpt_path,
                                               tokenizer_path=normal_model_config.tokenizer_path,
                                               lora_infos=None,
                                               weight_type=propose_weight_type,
                                               act_type=propose_act_type,
                                               max_seq_len=normal_model_config.max_seq_len,
                                               gen_num_per_circle=gen_num_per_circle,
                                               sp_type=sp_type)
        elif sp_type == "deterministic":
            logging.info("use deterministic speculative model")
            gen_num_per_circle = int(os.environ.get('GEN_NUM_PER_CIRCLE', '5'))
            propose_model_config = ModelConfig(sp_type=sp_type,
                                               gen_num_per_circle=gen_num_per_circle)
        elif sp_type == "eagle":
            logging.info("use eagle speculative model")
            raise NotImplementedError
        return propose_model_config
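
    # Hedged sketch: enabling vanilla speculative decoding via environment
    # variables (names taken from the code above; values hypothetical).
    #
    #   os.environ["SP_TYPE"] = "vanilla"
    #   os.environ["SP_MODEL_TYPE"] = "llama"
    #   os.environ["SP_CHECKPOINT_PATH"] = "/path/to/draft_ckpt"
    #   os.environ["GEN_NUM_PER_CIRCLE"] = "5"
    #   sp_cfg = ModelFactory.create_propose_model_config(cfg)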

    @staticmethod
    def load_default_generate_config(model: Union[BaseModel, AsyncModel]):
        if 'GENERATION_CONFIG_PATH' in os.environ:
            config_path = os.path.join(os.environ['GENERATION_CONFIG_PATH'], 'generation_config.json')
            with open(config_path) as f:
                model.default_generate_config.update(json.load(f))
            logging.info(f"load generate config: {config_path}:\n"
                         f"{json.dumps(model.default_generate_config.model_dump(), indent=4)}")
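
    # Hedged sketch: GENERATION_CONFIG_PATH points at a directory containing a
    # generation_config.json whose keys override the model defaults. The keys
    # shown are hypothetical examples of HuggingFace-style generation options.
    #
    #   os.environ["GENERATION_CONFIG_PATH"] = "/path/to/config_dir"
    #   # /path/to/config_dir/generation_config.json:
    #   #   {"top_k": 1, "max_new_tokens": 512}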

    @staticmethod
    def create_from_env():
        normal_model_config = ModelFactory.create_normal_model_config()
        propose_model_config = ModelFactory.create_propose_model_config(normal_model_config)
        model = ModelFactory.from_model_config(normal_model_config, propose_model_config)
        ModelFactory.load_default_generate_config(model)
        return model

    @staticmethod
    def create_from_module(ref_module: torch.nn.Module):
        normal_model_config = ModelFactory.create_normal_model_config()
        normal_model_config.add_ref_module(ref_module)
        model = ModelFactory.from_model_config(normal_model_config)
        ModelFactory.load_default_generate_config(model)
        return model

    @staticmethod
    def create_from_dict(ref_dict: Dict[str, torch.Tensor]):
        normal_model_config = ModelFactory.create_normal_model_config()
        normal_model_config.add_ref_dict(ref_dict)
        model = ModelFactory.from_model_config(normal_model_config)
        ModelFactory.load_default_generate_config(model)
        return model
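
# Hedged end-to-end sketch: the typical entry point is create_from_env, which
# chains the env-driven steps above. Env values are hypothetical.
#
#   os.environ["MODEL_TYPE"] = "llama"
#   os.environ["CHECKPOINT_PATH"] = "/path/to/ckpt"
#   model = ModelFactory.create_from_env()  # typically an AsyncModel wrapper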