scripts/collect_models.py (247 lines of code) (raw):

import os import argparse import re import time import json import csv import logging from requests.exceptions import RequestException from huggingface_hub.errors import HfHubHTTPError from pathlib import Path from typing import Set, Dict, Tuple, List, Any, Callable from huggingface_hub import HfApi, hf_hub_url, HfFileSystem import onnx from tqdm import tqdm import gc from parser import stream_parse_model_header # Set up logging logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s") ALLOWED_AUTHORS = [ "hf-internal-testing", "onnx-internal-testing", "Xenova", "onnx-community", "distil-whisper", "HuggingFaceTB", "microsoft", "mixedbread-ai", "Mozilla", "nomic-ai", "jinaai", "lightonai", "llava-hf", "Marqo", "Snowflake", "ds4sd", "sentence-transformers", "briaai", "nomic-ai", "Alibaba-NLP", "AdamCodd", "jonathandinu", "Supabase", "WhereIsAI", "llava-hf", "Oblix", "Intel", "teapotai", "ai4privacy", "BritishWerewolf", "OuteAI", "ylacombe", ] BANNED_REPOS = [ "briaai/RMBG-2.0", "AdamCodd/distilroberta-nsfw-prompt-stable-diffusion", "AdamCodd/vit-nsfw-stable-diffusion", ] CACHE_DIR = Path(__file__).parent.parent / "data" / "model-explorer" ALLOWED_QUANTIZATIONS = ["fp16", "uint8", "int8", "quantized", "q4", "q4f16", "bnb4"] DISALLOWED_FILE_PATTERNS = [ re.compile(r'decoder(_with_past)?_model(?!_merged)'), ] ALLOWED_QUANTIZATION_PATTERNS = re.compile( r'^(.+?)(?:_(' + "|".join(ALLOWED_QUANTIZATIONS) + r'))?\.onnx$' ) ALLOWED_REPO_PATTERN = re.compile(r"tiny-random-\w+(?:For\w+|Model)"); def get_operators(model: onnx.ModelProto) -> Set[str]: """ Recursively traverses the ONNX graph and returns a set of operator names. Args: model: Loaded ONNX model. Returns: Set of operator names. """ operators: Set[str] = set() def traverse_graph(graph: onnx.GraphProto): for node in graph.node: operators.add(node.op_type) for attr in node.attribute: if attr.type == onnx.AttributeProto.GRAPH: traverse_graph(attr.g) traverse_graph(model.graph) return operators def retry_operation(func: Callable, *args, max_retries: int = 10, initial_delay: int = 1, **kwargs) -> Any: """ Retry a given operation with exponential backoff. Args: func: Callable function to execute. *args: Positional arguments for the function. max_retries: Maximum retry attempts. initial_delay: Initial delay between retries. **kwargs: Keyword arguments for the function. Returns: The function result or False if a specific RuntimeError occurs. Raises: Exception if max retries are exceeded or on non-retryable errors. """ delay = initial_delay for attempt in range(max_retries): try: return func(*args, **kwargs) except RuntimeError: return False except (RequestException, HfHubHTTPError) as e: status = getattr(e, "response", None) if status is not None and (status.status_code == 429 or 500 <= status.status_code < 600): logging.warning( f"Attempt {attempt + 1} failed with status {status.status_code}. Retrying after {delay} seconds..." ) time.sleep(delay) delay *= 2 else: raise raise Exception("Max retries exceeded while trying to execute operation.") def collect_model_ops( model_limit: int = None, from_cache: bool = False, limit: int = None, include_all_models: bool = False, ) -> None: """ Collects the operators used by ONNX models from the Hugging Face Hub, downloads models as needed, and generates JavaScript files with the aggregated model metadata. Args: model_limit: Maximum number of models to query from the Hub. from_cache: Whether to only load models from the local cache folder. limit: Maximum number of unique models to process. include_all_models: If True, also includes models from transformers.js library. """ api = HfApi() fs = HfFileSystem() logging.info("Collecting models from the Hugging Face Hub...") onnx_models = list( api.list_models( library="onnx", search="tiny-random", limit=model_limit, sort="downloads", direction=-1, fetch_config=True ) ) if include_all_models: tfjs_models = list( api.list_models( library="transformers.js", limit=model_limit, sort="downloads", direction=-1, fetch_config=True ) ) onnx_models.extend(tfjs_models) logging.info("Found %d models.", len(onnx_models)) unique_models: Dict[str, Any] = {} model_types: Dict[str, str] = {} model_architectures: Dict[str, str] = {} for model in onnx_models: repo_id = model.modelId if model.private or model.gated: logging.info("Skipping private or gated model: %s", repo_id) continue author = repo_id.split('/')[0] if author not in ALLOWED_AUTHORS: logging.info("Skipping unauthorized author: %s", repo_id) continue if repo_id in BANNED_REPOS: logging.info("Skipping banned model: %s", repo_id) continue if not include_all_models and not ALLOWED_REPO_PATTERN.search(repo_id): logging.info("Skipping disallowed model: %s", repo_id) continue unique_models[repo_id] = model cfg = getattr(model, "config", {}) or {} model_types[repo_id] = cfg.get("model_type", "unknown") model_architectures[repo_id] = cfg.get("architectures", []) # Limit unique models based on downloads models_sorted = sorted(unique_models.items(), key=lambda x: x[1].downloads, reverse=True) if limit is not None: models_sorted = models_sorted[:limit] unique_models = dict(models_sorted) logging.info("Processing %d unique models.", len(unique_models)) model_type_ops: Dict[Tuple[str, str, str], Set[str]] = {} for repo_id, model in tqdm(unique_models.items(), desc="Processing Models"): if from_cache: model_cache_folder = CACHE_DIR / repo_id files = [f"{repo_id}/onnx/{f}" for f in os.listdir(model_cache_folder)] if model_cache_folder.exists() else [] else: pattern = f"{repo_id}/**/*.onnx" result = retry_operation(fs.glob, pattern, detail=True) files = list(result.keys()) if result else [] for file_path in files: relative_path = os.path.relpath(file_path, repo_id) if not relative_path.startswith("onnx/"): continue subfolder, file_name = os.path.split(relative_path) match = ALLOWED_QUANTIZATION_PATTERNS.match(file_name) if not match: continue if any(p.search(file_name) for p in DISALLOWED_FILE_PATTERNS): logging.info("Skipping disallowed file: %s/%s/%s", repo_id, subfolder, file_name) continue quantization = match.group(2) or "fp32" model_proto = None cache_folder = CACHE_DIR / repo_id cache_folder.mkdir(exist_ok=True, parents=True) cache_path = cache_folder / file_name if cache_path.exists(): model_proto = onnx.load(str(cache_path), load_external_data=False) elif not from_cache: logging.info('Downloading model "%s/%s/%s"', repo_id, subfolder, file_name) url = hf_hub_url(repo_id=repo_id, subfolder=subfolder, filename=file_name) model_proto = retry_operation(stream_parse_model_header, url) if model_proto: onnx.save(model_proto, str(cache_path)) if model_proto: ops_set = get_operators(model_proto) m_type = model_types.get(repo_id, "unknown") key = (m_type, repo_id, quantization) if key not in model_type_ops: model_type_ops[key] = set() model_type_ops[key].update(ops_set) del model_proto gc.collect() architecture_ops: Dict[str, List[Tuple[str, str, Set[str]]]] = {} for (m_type, model_id, q), ops_set in model_type_ops.items(): if m_type not in architecture_ops: architecture_ops[m_type] = [] else: # Avoid duplicating operations already covered if any(existing_ops == ops_set for _, _, existing_ops in architecture_ops[m_type]): continue architecture_ops[m_type].append((model_id, q, ops_set)) # Generate JS files core_dir = Path(__file__).parent.parent / "packages/core/src" arch_dir = core_dir / "architectures" arch_dir.mkdir(parents=True, exist_ok=True) for m_type, model_list in architecture_ops.items(): js_path = arch_dir / f"{m_type}.js" models_data = [ {"model_id": model_id, "dtype": quantization, "architectures": model_architectures[model_id], "ops": sorted(list(ops))} for model_id, quantization, ops in sorted(model_list, key=lambda x: x[0]) ] template = ( f"// NOTE: This file has been auto-generated. Do not edit directly.\n\n" f"export default {{ model_type: '{m_type}', models: {json.dumps(models_data)} }}\n" ) with js_path.open("w") as f: f.write(template) arch_index_path = core_dir / "architectures.js" with arch_index_path.open("w") as fp: fp.write("// NOTE: This file has been auto-generated. Do not edit directly.\n") for m_type in sorted(architecture_ops.keys()): safe_m_type = m_type.replace('-', '_') fp.write(f"export {{ default as {safe_m_type} }} from './architectures/{m_type}.js';\n") download_counts = {repo_id: model.downloads for repo_id, model in unique_models.items()} rows = [] for (m_type, repo_id, quantization), ops in model_type_ops.items(): rows.append([repo_id, download_counts.get(repo_id, 0), quantization, len(ops), ", ".join(sorted(ops))]) rows.sort(key=lambda row: (download_counts.get(row[0], 0), row[2]), reverse=True) with open("./data/model_ops.csv", "w", newline="") as csvfile: writer = csv.writer(csvfile) writer.writerow(["model_id", "downloads (past month)", "quantization", "num_ops", "ops"]) writer.writerows(rows) def main() -> None: """ Parses command line arguments and initiates the collection of model operators. """ parser = argparse.ArgumentParser( description="Collect operators used in ONNX models from the Hugging Face Hub." ) parser.add_argument("--model_limit", type=int, default=None, help="Maximum number of models to query from the Hub.") parser.add_argument("--limit", type=int, default=None, help="Maximum number of unique models to process.") parser.add_argument( "--from_cache", action="store_true", help="Only use local cache for loading models." ) parser.add_argument( "--all_models", action="store_true", help="Include models from the transformers.js library in addition to tiny-random models." ) args = parser.parse_args() collect_model_ops( model_limit=args.model_limit, from_cache=args.from_cache, limit=args.limit, include_all_models=args.all_models, ) if __name__ == "__main__": main()