trending_deploy/models.py
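
"""Select trending Hugging Face Hub models for a set of tasks and pair each one
with the smallest viable instance type (from MEMORY_USAGE_TO_INSTANCE), based on
the model's estimated memory footprint."""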

from typing import Iterator, List
from huggingface_hub import list_models, ModelInfo, hf_hub_download, model_info as get_model_info
from huggingface_hub.utils import disable_progress_bars, enable_progress_bars
import json
from tqdm import tqdm, trange

from trending_deploy.constants import Instance, Model, MEMORY_USAGE_TO_INSTANCE


def trending_models(tasks: list[str], max_models_per_task: int = 200) -> List[Model]:
    """
    Fetches the trending models for the specified tasks.

    Args:
        tasks (list[str]): A list of task names (pipeline tags) to fetch trending models for.
        max_models_per_task (int): The maximum number of models to fetch per task.

    Returns:
        List[Model]: A list of Model objects containing model information and viable instance.
    """
    models_to_consider: list[Model] = []
    tasks_iterator = tqdm(tasks, leave=False)
    for task in tasks_iterator:
        tasks_iterator.set_description(f"Loading trending models for {task}")
        models_to_consider.extend(trending_models_for_task(task, max_models_per_task))
    return models_to_consider


def trending_models_for_task(task: str, max_models_per_task: int = 200) -> List[Model]:
    """
    Fetches the trending models for a specific task.

    Args:
        task (str): The task for which to fetch trending models.
        max_models_per_task (int): The maximum number of models to fetch per task.

    Returns:
        List[Model]: A list of Model objects containing model information and viable instance.
    """
    models_to_consider: list[Model] = []
    trending_model_gen = trending_model_generator(task)
    try:
        for _ in trange(max_models_per_task, desc="Loading models", leave=False):
            models_to_consider.append(next(trending_model_gen))
    except StopIteration:
        pass
    return models_to_consider


def trending_model_generator(task: str) -> Iterator[Model]:
    """Yield trending, endpoints-compatible models for a task, each paired with a viable instance."""
    for model_info in list_models(
        pipeline_tag=task,
        tags="endpoints_compatible",
        expand=[
            "createdAt",
            "trendingScore",
            "tags",
            "library_name",
            "likes",
            "downloads",
            "downloadsAllTime",
            "safetensors",
            "pipeline_tag",
        ],
    ):
        # Models that require custom code cannot be deployed as-is.
        if "custom_code" in model_info.tags:
            continue

        # Get the number of parameters to determine which instance type is viable.
        # Sometimes it may fail (e.g. non-authorized models), so we just skip those models.
        try:
            num_parameters = get_num_parameters_from_model(model_info)
        except Exception:
            continue
        if num_parameters is None:
            continue

        viable_instance: Instance = get_viable_instance_from_num_parameters(num_parameters)
        if viable_instance is None:
            continue

        yield Model(model_info=model_info, viable_instance=viable_instance)


def get_num_parameters_from_model(model: ModelInfo):
    # Prefer the safetensors metadata, which reports the parameter count directly.
    safetensors = model.safetensors
    if safetensors:
        return safetensors.total

    # Otherwise, estimate the parameter count from the PyTorch checkpoint size,
    # assuming fp32 weights (4 bytes per parameter).
    bytes_per_param = 4
    files = get_model_info(model.id, files_metadata=True).siblings
    for file in files:
        if file.rfilename == "pytorch_model.bin":
            return file.size // bytes_per_param

        if file.rfilename == "pytorch_model.bin.index.json":
            # Sharded checkpoint: the index file reports the total size, e.g.
            # {"metadata": {"total_size": 28272820224}, ...}
            disable_progress_bars()
            index_path = hf_hub_download(model.id, filename="pytorch_model.bin.index.json")
            enable_progress_bars()
            with open(index_path) as index_file:
                index = json.load(index_file)
            if ("metadata" in index) and ("total_size" in index["metadata"]):
                return index["metadata"]["total_size"] // bytes_per_param

    return None


def get_viable_instance_from_num_parameters(num_parameters: int):
    # Estimated fp32 memory footprint (4 bytes per parameter), with a 2.2x safety
    # factor for overhead. E.g. a 1B-parameter model has ~4 GB of weights, so it
    # needs an instance with more than ~8.8 GB of memory.
    model_memory_usage_bytes = num_parameters * 4
    memory_factor = 2.2
    viable_instance = None
    for max_instance_memory_usage, instance in MEMORY_USAGE_TO_INSTANCE.items():
        if model_memory_usage_bytes * memory_factor < max_instance_memory_usage:
            viable_instance = instance
            break
    return viable_instance