# trending_deploy/models.py
import json
from typing import Iterator, List, Optional

from huggingface_hub import ModelInfo, hf_hub_download, list_models
from huggingface_hub import model_info as get_model_info
from huggingface_hub.utils import disable_progress_bars, enable_progress_bars
from tqdm import tqdm, trange

from trending_deploy.constants import MEMORY_USAGE_TO_INSTANCE, Instance, Model


def trending_models(tasks: list[str], max_models_per_task: int = 200) -> List[Model]:
    """
    Fetches the trending models for the specified tasks.

    Args:
        tasks (list[str]): A list of task names to fetch trending models for.
        max_models_per_task (int): The maximum number of models to fetch per task.

    Returns:
        List[Model]: A list of Model objects containing model information and a viable instance.
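
    Example (illustrative; requires Hub access, and live results vary over time):
        >>> models = trending_models(["text-classification"], max_models_per_task=5)
        >>> len(models) <= 5
        True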
"""
    models_to_consider: list[Model] = []
    tasks_iterator = tqdm(tasks, leave=False)
    for task in tasks_iterator:
        tasks_iterator.set_description(f"Loading trending models for {task}")
        models_to_consider.extend(trending_models_for_task(task, max_models_per_task))
    return models_to_consider


def trending_models_for_task(task: str, max_models_per_task: int = 200) -> List[Model]:
    """
    Fetches the trending models for a specific task.

    Args:
        task (str): The task for which to fetch trending models.
        max_models_per_task (int): The maximum number of models to fetch per task.

    Returns:
        List[Model]: A list of Model objects containing model information and a viable instance.
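
    Example (illustrative; the Hub may return fewer models than requested):
        >>> models = trending_models_for_task("feature-extraction", max_models_per_task=3)
        >>> all(model.viable_instance is not None for model in models)
        True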
"""
    models_to_consider: list[Model] = []
    trending_model_gen = trending_model_generator(task)
    try:
        # The generator may run out of models before max_models_per_task is reached.
        for _ in trange(max_models_per_task, desc="Loading models", leave=False):
            models_to_consider.append(next(trending_model_gen))
    except StopIteration:
        pass
    return models_to_consider


def trending_model_generator(task: str) -> Iterator[Model]:
    """Lazily yield deployable trending models for a task, skipping models that cannot be sized or hosted."""
    for model_info in list_models(
        pipeline_tag=task,
        # Only consider models tagged as compatible with Inference Endpoints.
        tags="endpoints_compatible",
        expand=[
            "createdAt",
            "trendingScore",
            "tags",
            "library_name",
            "likes",
            "downloads",
            "downloadsAllTime",
            "safetensors",
            "pipeline_tag",
        ],
    ):
        # Skip models that require custom (remote) code to run.
        if "custom_code" in model_info.tags:
            continue
        # Get the number of parameters to determine which instance type is viable.
        # This may fail (e.g. for non-authorized models), so we skip those models.
        try:
            num_parameters = get_num_parameters_from_model(model_info)
        except Exception:
            continue
        if num_parameters is None:
            continue

        viable_instance: Optional[Instance] = get_viable_instance_from_num_parameters(num_parameters)
        if viable_instance is None:
            continue

        yield Model(model_info=model_info, viable_instance=viable_instance)
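
# Minimal usage sketch (hypothetical task name): take the single top trending
# model for a task, or None when nothing qualifies.
#
#     top_model = next(trending_model_generator("text-classification"), None)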


def get_num_parameters_from_model(model: ModelInfo) -> Optional[int]:
    """Return the number of parameters of a model, or None if it cannot be determined."""
    # Prefer the exact parameter count from the safetensors metadata.
    safetensors = model.safetensors
    if safetensors:
        return safetensors.total

    # Fall back to weight file sizes, assuming fp32 weights (4 bytes per parameter).
    bytes_per_param = 4
    files = get_model_info(model.id, files_metadata=True).siblings
    for file in files:
        if file.rfilename == "pytorch_model.bin":
            return file.size // bytes_per_param

        if file.rfilename == "pytorch_model.bin.index.json":
            # Sharded checkpoint: the index file reports the total weight size.
            disable_progress_bars()
            index_path = hf_hub_download(model.id, filename="pytorch_model.bin.index.json")
            enable_progress_bars()
"""
{
"metadata": {
"total_size": 28272820224
},....
"""
            with open(index_path) as index_file:
                index = json.load(index_file)
            if ("metadata" in index) and ("total_size" in index["metadata"]):
                return index["metadata"]["total_size"] // bytes_per_param

    return None
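
# Usage sketch (model id is illustrative; requires network access to the Hub):
#
#     info = get_model_info("bert-base-uncased")
#     num_params = get_num_parameters_from_model(info)  # roughly 110M parameters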


def get_viable_instance_from_num_parameters(num_parameters: int) -> Optional[Instance]:
    """Return the first instance that can host the model, or None if none is large enough."""
    # Assume fp32 weights: 4 bytes per parameter.
    model_memory_usage_bytes = num_parameters * 4
    # Headroom factor for activations, buffers, and the runtime itself.
    memory_factor = 2.2
    viable_instance = None
    # MEMORY_USAGE_TO_INSTANCE is assumed to be ordered from smallest to
    # largest instance, so the first fit is also the cheapest viable one.
    for max_instance_memory_usage, instance in MEMORY_USAGE_TO_INSTANCE.items():
        if model_memory_usage_bytes * memory_factor < max_instance_memory_usage:
            viable_instance = instance
            break
    return viable_instance
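

# Worked example of the sizing rule above: a 1B parameter model occupies
# 1e9 * 4 bytes = 4 GB in fp32; with the 2.2x headroom factor it needs an
# instance offering more than 8.8 GB of memory.

# Minimal smoke test (hypothetical task name; requires network access to the Hub).
if __name__ == "__main__":
    for model in trending_models(["text-classification"], max_models_per_task=3):
        print(model.model_info.id, "->", model.viable_instance)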