maga_transformer/model_factory_register.py (82 lines of code) (raw):

import os
import json
import logging
from typing import Any, Dict, Type, List, Optional
import sys

CUR_PATH = os.path.dirname(os.path.abspath(__file__))
# Append the parent of this file's directory to the import search path.
sys.path.append(os.path.join(str(CUR_PATH), ".."))

# Registries filled in by the register_* functions below.
_model_factory: Dict[str, Type[Any]] = {}   # model name -> model class
_hf_architecture_2_ft: Dict[str, str] = {}  # HF architecture name -> model name
_hf_repo_2_ft: Dict[str, str] = {}          # HF repo id -> model name


def _ensure_no_conflict(registry: Dict[str, Any], name: str, model_type: Any) -> None:
    """Raise if *name* is already mapped to a different value in *registry*.

    Re-registering the same value under the same name is allowed.
    """
    if name in registry and registry[name] != model_type:
        raise Exception(f"try register model {name} with type {registry[name]} and {model_type}, confict!")


def register_model(name: str,
                   model_type: Any,
                   support_architectures: Optional[List[str]] = None,
                   support_hf_repos: Optional[List[str]] = None):
    """Register *model_type* under *name* and map its HF aliases to *name*.

    Args:
        name: Model type name used as the registry key.
        model_type: Object (normally a class) to store for *name*.
        support_architectures: HF architecture names that should resolve to
            *name*. ``None`` (the default) means no architectures.
        support_hf_repos: HF repo ids that should resolve to *name*.
            ``None`` (the default) means no repos.

    Raises:
        Exception: If *name*, an architecture, or a repo is already
            registered with a different value.
    """
    _ensure_no_conflict(_model_factory, name, model_type)
    _model_factory[name] = model_type
    # `or []` avoids a shared mutable default and tolerates an explicit None.
    for architecture in support_architectures or []:
        register_hf_architecture(architecture, name)
    for repo in support_hf_repos or []:
        register_hf_repo(repo, name)


def register_hf_architecture(name: str, model_type: str):
    """Map the HF architecture *name* to the model type name *model_type*.

    Raises:
        Exception: If *name* is already mapped to a different model type.
    """
    _ensure_no_conflict(_hf_architecture_2_ft, name, model_type)
    logging.debug("registerhf_architecture: %s -> %s", name, model_type)
    _hf_architecture_2_ft[name] = model_type


def register_hf_repo(name: str, model_type: str):
    """Map the HF repo id *name* to the model type name *model_type*.

    Raises:
        Exception: If *name* is already mapped to a different model type.
    """
    _ensure_no_conflict(_hf_repo_2_ft, name, model_type)
    logging.debug("register_hf_repo: %s -> %s", name, model_type)
    _hf_repo_2_ft[name] = model_type


def _resolve_chatglm_type(config: Dict[str, Any]) -> str:
    """Pick the ChatGLM variant for a config that uses a ChatGLM architecture.

    The ChatGLM generations share one architecture name, so the variant is
    derived from ``multi_query_attention`` and substrings of ``_name_or_path``.
    """
    # A missing or None path is treated as empty so the substring tests below
    # cannot raise TypeError.
    name_or_path = config.get('_name_or_path') or ''
    if not config.get('multi_query_attention', False) or 'chatglm-6b' in name_or_path:
        return 'chatglm'
    if 'chatglm3' in name_or_path:
        return 'chatglm3'
    if 'glm-4-' in name_or_path:
        return 'chatglm4'
    if 'glm-4v' in name_or_path:
        return 'chatglm4v'
    return 'chatglm2'


class ModelDict:
    """Lookup helpers over the module-level HF registries."""

    @staticmethod
    def get_ft_model_type_by_hf_repo(repo: str) -> Optional[str]:
        """Return the model type registered for HF repo *repo*, or ``None``."""
        model_type = _hf_repo_2_ft.get(repo)
        logging.debug("get hf_repo model type: %s, %s", repo, model_type)
        return model_type

    @staticmethod
    def get_ft_model_type_by_hf_architectures(architecture: str) -> Optional[str]:
        """Return the model type registered for *architecture*, or ``None``."""
        model_type = _hf_architecture_2_ft.get(architecture)
        logging.debug("get architectur model type: %s, %s", architecture, model_type)
        return model_type

    @staticmethod
    def get_ft_model_type_by_config(config: Dict[str, Any]) -> Optional[str]:
        """Infer the model type name from a config dictionary.

        Only the first entry of ``config['architectures']`` is considered.
        A few architectures are shared by several model types and are
        disambiguated from other config fields; everything else is resolved
        through the architecture registry.

        Returns:
            The model type name, or ``None`` when the config lists no
            architectures or the architecture is not registered.
        """
        architectures = config.get('architectures')
        if not architectures:
            logging.warning("config have no architectures: %s", config)
            return None

        architecture = architectures[0]

        # hack for ChatGLMModel: chatglm and chatglm2 use same architecture
        if architecture in ('ChatGLMModel', 'ChatGLMForConditionalGeneration'):
            return _resolve_chatglm_type(config)

        if architecture == 'QWenLMHeadModel':
            visual = config.get('visual')
            if visual:
                return 'qwen_vl' if visual.get('layers') else 'qwen_vl_1b8'
            # No truthy 'visual' entry: fall through to the registry lookup.
        elif architecture == 'BaichuanForCausalLM':
            # The two Baichuan types are told apart by vocabulary size.
            return 'baichuan2' if config.get('vocab_size') == 125696 else 'baichuan'
        elif architecture == 'GPTNeoXForCausalLM':
            # Same approach for the GPT-NeoX variants.
            return 'gpt_neox_13b' if config.get('vocab_size') == 250752 else 'gpt_neox'

        return ModelDict.get_ft_model_type_by_hf_architectures(architecture)