# maga_transformer/model_factory_register.py
import os
import json
import logging
from typing import Any, Dict, Type, List, Optional
import sys
# Make the package's parent directory importable so sibling packages resolve
# when this module is run outside the normal package layout.
CUR_PATH = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(str(CUR_PATH), ".."))
# Registry mapping an internal ft model-type name to its model class.
_model_factory: Dict[str, Type[Any]] = {}

def register_model(name: str,
                   model_type: Any,
                   support_architectures: Optional[List[str]] = None,
                   support_hf_repos: Optional[List[str]] = None):
    """Register a model class under an internal ft model-type name.

    Also registers any HuggingFace architecture strings and repo names that
    should resolve to this model type.

    Args:
        name: internal ft model-type name (registry key).
        model_type: model class to associate with ``name``.
        support_architectures: HF ``architectures`` strings mapping to ``name``.
        support_hf_repos: HF repo names mapping to ``name``.

    Raises:
        Exception: if ``name`` is already registered with a different class.
    """
    # NOTE: defaults are None (not []) to avoid shared mutable default args.
    global _model_factory
    if name in _model_factory and _model_factory[name] != model_type:
        raise Exception(
            f"try register model {name} with type {_model_factory[name]} and {model_type}, conflict!")
    _model_factory[name] = model_type
    for architecture in (support_architectures or []):
        register_hf_architecture(architecture, name)
    for repo in (support_hf_repos or []):
        register_hf_repo(repo, name)
_hf_architecture_2_ft = {}
def register_hf_architecture(name: str, model_type: str):
global _hf_architecture_2_ft
if name in _hf_architecture_2_ft and _hf_architecture_2_ft[name] != model_type:
raise Exception(f"try register model {name} with type {_hf_architecture_2_ft[name]} and {model_type}, confict!")
logging.debug(f"registerhf_architecture: {name} -> {model_type}")
_hf_architecture_2_ft[name] = model_type
_hf_repo_2_ft = {}
def register_hf_repo(name: str, model_type: str):
global _hf_repo_2_ft
if name in _hf_repo_2_ft and _hf_repo_2_ft[name] != model_type:
raise Exception(f"try register model {name} with type {_hf_repo_2_ft[name]} and {model_type}, confict!")
logging.debug(f"register_hf_repo: {name} -> {model_type}")
_hf_repo_2_ft[name] = model_type
class ModelDict:
    """Static lookup helpers resolving HuggingFace repos / architectures /
    config dicts to internal ft model-type names."""

    @staticmethod
    def get_ft_model_type_by_hf_repo(repo: str) -> Optional[str]:
        """Return the ft model type registered for a HF repo, or None."""
        model_type = _hf_repo_2_ft.get(repo, None)
        logging.debug(f"get hf_repo model type: {repo}, {model_type}")
        return model_type

    @staticmethod
    def get_ft_model_type_by_hf_architectures(architecture: str) -> Optional[str]:
        """Return the ft model type registered for a HF architecture, or None."""
        model_type = _hf_architecture_2_ft.get(architecture, None)
        # fixed log typo: was "architectur"
        logging.debug(f"get architecture model type: {architecture}, {model_type}")
        return model_type

    @staticmethod
    def get_ft_model_type_by_config(config: Dict[str, Any]) -> Optional[str]:
        """Infer the ft model type from a parsed HF ``config.json`` dict.

        Some architecture strings are shared by several model generations, so
        they are disambiguated via other config fields before falling back to
        the registered architecture map.

        Returns:
            The ft model-type name, or None if the config lists no
            architectures.
        """
        architectures = config.get('architectures', [])
        if not architectures:
            logging.warning(f"config have no architectures: {config}")
            return None
        architecture = architectures[0]
        # hack for ChatGLMModel: chatglm and chatglm2 use same architecture
        if architecture in ['ChatGLMModel', 'ChatGLMForConditionalGeneration']:
            _name_or_path = config.get('_name_or_path', '')
            # only the first-generation chatglm lacks multi-query attention
            if not config.get('multi_query_attention', False) or 'chatglm-6b' in _name_or_path:
                return 'chatglm'
            elif 'chatglm3' in _name_or_path:
                return 'chatglm3'
            elif 'glm-4-' in _name_or_path:
                return 'chatglm4'
            elif 'glm-4v' in _name_or_path:
                return 'chatglm4v'
            else:
                return 'chatglm2'
        if architecture == 'QWenLMHeadModel':
            # a truthy 'visual' section marks the multimodal qwen_vl family
            if config.get('visual'):
                if config['visual'].get('layers'):
                    return 'qwen_vl'
                else:
                    return 'qwen_vl_1b8'
        if architecture == 'BaichuanForCausalLM':
            # baichuan2 is distinguished by its enlarged vocabulary
            vocab_size = config.get('vocab_size', 64000)
            if vocab_size == 125696:
                return 'baichuan2'
            else:
                return 'baichuan'
        if architecture == 'GPTNeoXForCausalLM':
            # the 13b variant ships a distinct vocabulary size
            vocab_size = config.get('vocab_size', 50432)
            if vocab_size == 250752:
                return 'gpt_neox_13b'
            else:
                return 'gpt_neox'
        return ModelDict.get_ft_model_type_by_hf_architectures(architecture)