api_inference_community/hub.py (122 lines of code) (raw):

import json
import logging
import os
import pathlib
import re
from typing import List, Optional

from huggingface_hub import ModelCard, constants, hf_api, try_to_load_from_cache
from huggingface_hub.file_download import repo_folder_name


logger = logging.getLogger(__name__)


def _cached_repo_root_path(cache_dir: pathlib.Path, repo_id: str) -> pathlib.Path:
    """
    Return the directory under `cache_dir` that holds the cached model repo `repo_id`.

    The folder name is computed by `huggingface_hub.file_download.repo_folder_name`
    with `repo_type="model"`. The path is only computed, not checked for existence.
    """
    folder = pathlib.Path(repo_folder_name(repo_id=repo_id, repo_type="model"))
    # Coerce so that callers handing over a plain string path keep working
    return pathlib.Path(cache_dir) / folder


def cached_revision_path(
    cache_dir: pathlib.Path, repo_id: str, revision: Optional[str]
) -> pathlib.Path:
    """
    Return the local snapshot directory of `repo_id` for the given `revision`.

    Args:
        cache_dir: Root of the local hub cache.
        repo_id: Model repo id.
        revision: Branch/ref name or commit sha. `None` is treated as "main".

    Returns:
        `<repo cache>/snapshots/<resolved revision>`.

    Raises:
        Exception: if the repo cache directory or its `snapshots` directory is
            missing, or if no snapshot matches the (resolved) revision.
    """
    if revision is None:
        revision = "main"
    # Built after defaulting so the message never reports "revision None", and before
    # ref resolution so it names the revision the caller actually asked for.
    error_msg = f"No revision path found for {repo_id}, revision {revision}"

    repo_cache = _cached_repo_root_path(cache_dir, repo_id)
    if not repo_cache.is_dir():
        msg = f"Local repo {repo_cache} does not exist"
        logger.error(msg)
        raise Exception(msg)

    refs_dir = repo_cache / "refs"
    snapshots_dir = repo_cache / "snapshots"

    # Resolve refs (for instance to convert main to the associated commit sha)
    if refs_dir.is_dir():
        revision_file = refs_dir / revision
        # is_file() rather than exists(): a revision such as "refs/pr" may map to a
        # directory inside refs/, which cannot be opened for reading.
        if revision_file.is_file():
            with revision_file.open(encoding="utf-8") as f:
                # A trailing newline would make the snapshot lookup below fail
                revision = f.read().strip()

    # Check if revision folder exists
    if not snapshots_dir.exists():
        msg = f"No local revision path {snapshots_dir} found for {repo_id}, revision {revision}"
        logger.error(msg)
        raise Exception(msg)

    cached_shas = os.listdir(snapshots_dir)
    if revision not in cached_shas:
        # No cache for this revision and we won't try to return a random revision
        logger.error(error_msg)
        raise Exception(error_msg)

    return snapshots_dir / revision


def _build_offline_model_info(
    repo_id: str, cache_dir: pathlib.Path, revision: str
) -> hf_api.ModelInfo:
    """
    Rebuild a partial `ModelInfo` from the files found in the local cache.

    The model card (README.md) metadata is used both as `card_data` and as extra
    top-level fields, and the siblings are listed from the directory containing the
    cached README.

    Raises:
        Exception: if no cached README.md can be located for this repo/revision.
    """
    logger.info("Rebuilding offline model info for repo %s", repo_id)

    # Let's rebuild some partial model info from what we see in cache, info extracted should be enough
    # for most use cases
    card_path = try_to_load_from_cache(
        repo_id=repo_id,
        filename="README.md",
        cache_dir=cache_dir,
        revision=revision,
    )
    # Anything other than a string path is treated as "README not available"
    if not isinstance(card_path, str):
        raise Exception(
            "Unable to rebuild offline model info, no README could be found"
        )

    card_path = pathlib.Path(card_path)
    logger.debug("Loading model card from model readme %s", card_path)
    model_card = ModelCard.load(card_path)
    card_data = model_card.data.to_dict()

    repo = card_path.parent
    logger.debug("Repo path %s", repo)
    siblings = _build_offline_siblings(repo)

    explicit_fields = dict(
        private=False,
        downloads=0,
        likes=0,
        id=repo_id,
        card_data=card_data,
        siblings=siblings,
    )
    # Card metadata is also spread as top-level fields. A metadata key that is already
    # passed explicitly would raise "got multiple values for keyword argument", so the
    # explicit value wins and the colliding key stays available through card_data only.
    colliding_keys = [key for key in card_data if key in explicit_fields]
    if colliding_keys:
        logger.warning(
            "Ignoring model card keys %s colliding with offline model info fields",
            colliding_keys,
        )
    extra_fields = {
        key: value for key, value in card_data.items() if key not in explicit_fields
    }

    model_info = hf_api.ModelInfo(**explicit_fields, **extra_fields)
    logger.info("Offline model info for repo %s: %s", repo, model_info)
    return model_info


def _build_offline_siblings(repo: pathlib.Path) -> List[dict]:
    """
    List every file under `repo` as a sibling entry.

    Returns:
        A list of dicts with `rfilename` (file path with the `repo` prefix and the
        surrounding path separators removed) and `size` (as reported by `os.stat`).

    Raises:
        Exception: if a walked file path does not start with the `repo` path.
    """
    siblings = []
    prefix_pattern = re.compile(r"^" + re.escape(str(repo)) + r"(.*)$")
    for root, _dirs, files in os.walk(repo):
        for file in files:
            filepath = os.path.join(root, file)
            size = os.stat(filepath).st_size
            m = prefix_pattern.match(filepath)
            if not m:
                msg = (
                    f"File {filepath} does not match expected pattern {prefix_pattern}"
                )
                logger.error(msg)
                raise Exception(msg)
            # Keep only the part after the repo prefix, without leading separator
            filepath = m.group(1)
            filepath = filepath.strip(os.sep)
            sibling = dict(rfilename=filepath, size=size)
            siblings.append(sibling)
    return siblings


def _cached_model_info(
    repo_id: str, revision: str, cache_dir: pathlib.Path
) -> hf_api.ModelInfo:
    """
    Looks for a json file containing prefetched model info in the revision path.
    If none found we just rebuild model info with the local directory files.

    Note that this file is not automatically created by hub_download/snapshot_download.
    It is just a convenience we add here, just in case the offline info we rebuild from
    the local directories would not cover all use cases.

    Raises:
        Exception: if the revision cannot be found in the local cache, or if the
            model info has to be rebuilt and no cached README is available.
    """
    revision_path = cached_revision_path(cache_dir, repo_id, revision)

    model_info_basename = "hub_model_info.json"
    model_info_path = revision_path / model_info_basename
    logger.info("Checking if there are some cached model info at %s", model_info_path)
    if os.path.exists(model_info_path):
        # Explicit encoding: do not depend on the platform locale to read JSON
        with open(model_info_path, "r", encoding="utf-8") as f:
            o = json.load(f)
        r = hf_api.ModelInfo(**o)
        logger.debug("Cached model info from file: %s", r)
    else:
        logger.debug(
            "No cached model info file %s found, "
            "rebuilding partial model info from cached model files",
            model_info_path,
        )
        # Let's rebuild some partial model info from what we see in cache, info extracted should be enough
        # for most use cases
        r = _build_offline_model_info(repo_id, cache_dir, revision)

    return r


def hub_model_info(
    repo_id: str,
    revision: Optional[str] = None,
    cache_dir: Optional[pathlib.Path] = None,
    **kwargs,
) -> hf_api.ModelInfo:
    """
    Get Hub model info with offline support

    Args:
        repo_id: Model repo id.
        revision: Revision to look up, defaults to "main".
        cache_dir: Local hub cache root, only used in offline mode. Defaults to
            `constants.HF_HUB_CACHE`.
        **kwargs: Forwarded to `hf_api.model_info` in online mode; not used in
            offline mode.
    """
    if revision is None:
        revision = "main"

    if not constants.HF_HUB_OFFLINE:
        return hf_api.model_info(repo_id=repo_id, revision=revision, **kwargs)

    logger.info("Model info for offline mode")

    if cache_dir is None:
        cache_dir = pathlib.Path(constants.HF_HUB_CACHE)

    return _cached_model_info(repo_id, revision, cache_dir)