maga_transformer/tools/api/hf_model_helper.py

import os
import json
import logging
from pathlib import Path
from typing import Optional, Tuple

from huggingface_hub import HfApi
from huggingface_hub.hf_api import ModelInfo

from maga_transformer.model_factory_register import ModelDict


class HfStyleModelInfo:
    """Collects the config, weight index, and size metadata of a HuggingFace-style model,
    whether it lives on the HuggingFace Hub or in a local directory."""

    HF_URI_PREFIX = "https://huggingface.co/"
    CONFIG_FILE = "config.json"
    MODEL_META_INFO_FILE_NAME = "model.safetensors.index.json"
    TORCH_BIN_INDEX_FILE_NAME = "pytorch_model.bin.index.json"

    api = HfApi()

    def __init__(self, repo_or_link: str, revision: Optional[str] = None):
        self._is_from_hf = self.is_from_hf(repo_or_link)
        self.model_info: Optional[ModelInfo] = None
        self.meta_info_file: Optional[str] = None
        self.repo_or_link = repo_or_link

        if self._is_from_hf:
            # Remote repo: query the Hub for model info, config.json and the weight index.
            self.model_info = self._get_model_info(repo_or_link, revision)
            self.model_config_file = self._get_model_config_file(repo_or_link, revision)
            self.meta_info_file = self._get_meta_info_file(repo_or_link, revision)
        else:
            # Local directory: look for the weight index and config.json on disk.
            if os.path.exists(os.path.join(repo_or_link, HfStyleModelInfo.MODEL_META_INFO_FILE_NAME)):
                self.meta_info_file = os.path.join(repo_or_link, HfStyleModelInfo.MODEL_META_INFO_FILE_NAME)
            elif os.path.exists(os.path.join(repo_or_link, HfStyleModelInfo.TORCH_BIN_INDEX_FILE_NAME)):
                self.meta_info_file = os.path.join(repo_or_link, HfStyleModelInfo.TORCH_BIN_INDEX_FILE_NAME)
            self.model_config_file = os.path.join(repo_or_link, HfStyleModelInfo.CONFIG_FILE)

        self.hf_local_dir = os.path.dirname(self.model_config_file) if self.model_config_file else None

        # Load model information
        self.model_config = self._load_model_config(self.model_config_file)
        self.param_count, self.total_size = self._calculate_model_parameters()

        # Load auto config if available
        self.auto_config_py = self._get_auto_config_py(repo_or_link, revision) if self._is_from_hf else None

    def _get_model_info(self, repo_or_link: str, revision: Optional[str]) -> ModelInfo:
        repo = self._get_repo_from_hf_link(repo_or_link)
        return self.api.model_info(repo, revision=revision, timeout=10)

    def _get_model_config_file(self, repo_or_link: str, revision: Optional[str]) -> Optional[str]:
        repo = self._get_repo_from_hf_link(repo_or_link)
        if self.api.file_exists(repo, self.CONFIG_FILE, revision=revision):
            return self.api.hf_hub_download(repo, self.CONFIG_FILE, revision=revision)
        return None

    def _get_meta_info_file(self, repo_or_link: str, revision: Optional[str]) -> Optional[str]:
        repo = self._get_repo_from_hf_link(repo_or_link)
        if self.api.file_exists(repo, self.MODEL_META_INFO_FILE_NAME, revision=revision):
            return self.api.hf_hub_download(repo, self.MODEL_META_INFO_FILE_NAME, revision=revision)
        elif self.api.file_exists(repo, self.TORCH_BIN_INDEX_FILE_NAME, revision=revision):
            return self.api.hf_hub_download(repo, self.TORCH_BIN_INDEX_FILE_NAME, revision=revision)
        return None

    def _load_model_config(self, config_file: Optional[str]) -> dict:
        logging.info(f'loading config from {config_file}')
        if config_file and os.path.exists(config_file):
            with open(config_file, 'r') as f:
                return json.load(f)
        return {}

    def _get_auto_config_py(self, repo_or_link: str, revision: Optional[str]) -> Optional[str]:
        # If config.json declares a custom AutoConfig via auto_map, download its .py module.
        config_file = self.model_config_file
        repo = self._get_repo_from_hf_link(repo_or_link)
        if config_file and os.path.exists(config_file):
            with open(config_file, 'r') as f:
                config_dict = json.load(f)
            config_py = config_dict.get('auto_map', {}).get('AutoConfig')
            if config_py:
                # auto_map entries look like "configuration_xxx.XxxConfig"; the module part is the file name.
                config_py_file = config_py.split('.')[0] + ".py"
                return self.api.hf_hub_download(repo, config_py_file, revision=revision)
        return None

    def _calculate_model_parameters(self) -> Tuple[Optional[int], Optional[int]]:
        param_count = None
        total_size = None
        if self.model_info and self.model_info.safetensors:
            # The Hub reports per-dtype parameter counts; approximate the on-disk size
            # by assuming two bytes per parameter for the listed dtypes.
            param_count = self.model_info.safetensors.total
            total_size = sum(count * 2 if weight_type in ['FP16', 'BF16', 'FP32', 'INT8', 'F16'] else count
                             for weight_type, count in self.model_info.safetensors.parameters.items())
        elif self.meta_info_file and os.path.exists(self.meta_info_file):
            logging.info(f'loading meta_info from {self.meta_info_file}')
            with open(self.meta_info_file, 'r') as f:
                meta_info = json.load(f)
            total_size = meta_info.get("metadata", {}).get("total_size", None)
        if total_size is None and not self._is_from_hf:
            # Fall back to summing the sizes of the standard HF weight files on disk.
            patterns = ["*.safetensors", "*.bin", "*.pth", "*.pt"]
            total_size = 0
            for pattern in patterns:
                for file in Path(self.repo_or_link).glob(pattern):
                    if os.path.isfile(file):
                        total_size += file.stat().st_size
            logging.info(f"falling back to file sizes on disk: {total_size}")
        logging.info(f'meta_info_file: {self.meta_info_file}, model_config_file: {self.model_config_file}, '
                     f'model_info: {self.model_info}, param_count: {param_count}, total_size: {total_size}')
        return param_count, total_size

    @property
    def ft_model_type(self) -> Optional[str]:
        if self.model_info:
            # Prefer an explicit mapping from the HF repo id when one is registered.
            ft_type = ModelDict.get_ft_model_type_by_hf_repo(self.model_info.modelId)
            if ft_type is not None:
                return ft_type
        # Otherwise infer the type from the loaded config.json contents.
        return ModelDict.get_ft_model_type_by_config(self.model_config)

    @staticmethod
    def is_from_hf(model_path: str) -> bool:
        # Treat anything that is neither an explicit HF link nor a known storage scheme
        # or absolute path as a Hub repo id.
        return model_path.startswith(HfStyleModelInfo.HF_URI_PREFIX) or \
            not model_path.startswith(("oss:", "http:", "https:", "dfs:", "hdfs:", "/", "nas://"))

    @staticmethod
    def _get_repo_from_hf_link(model_link: str) -> str:
        return model_link.replace(HfStyleModelInfo.HF_URI_PREFIX, "")


def get_hf_model_info(model_path_or_name: str, revision: Optional[str] = None) -> HfStyleModelInfo:
    return HfStyleModelInfo(model_path_or_name, revision)


def get_model_info_from_hf(model_path_or_name: str, revision: Optional[str] = None) -> Tuple[str, str]:
    info = get_hf_model_info(model_path_or_name, revision)
    from huggingface_hub import snapshot_download
    if info.ft_model_type is None:
        raise Exception(f"failed to detect a supported model type for {model_path_or_name}")
    if not os.path.exists(model_path_or_name):
        logging.info(f"trying to download {model_path_or_name} from the huggingface hub")
        local_path = snapshot_download(model_path_or_name, revision=revision)
    else:
        local_path = model_path_or_name
    logging.info(f"detected model type: {info.ft_model_type}, local_path: {local_path}")
    return local_path, info.ft_model_type
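
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# The repo id below is just an example; any Hub repo id or a local checkout
# directory containing config.json can be passed the same way. Running it
# requires network access and a model type that ModelDict recognizes.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # Inspect a repo without downloading the weights.
    info = get_hf_model_info("Qwen/Qwen-7B")  # example repo id
    print("type:", info.ft_model_type, "params:", info.param_count, "size:", info.total_size)

    # Resolve to a local snapshot (downloading if needed) plus the detected type.
    local_path, model_type = get_model_info_from_hf("Qwen/Qwen-7B")
    print("local_path:", local_path, "model_type:", model_type)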