# src/model/weights.py
import os
import shutil

import torch
from huggingface_hub.constants import HF_HOME, HF_HUB_CACHE
from huggingface_hub.file_download import repo_folder_name
from loguru import logger
from safetensors.torch import save_file
from transformers import PretrainedConfig
# Directory holding stub ("no weights") model snapshots, alongside the HF cache.
NO_WEIGHTS_CACHE_DIR = os.path.join(HF_HOME, "no_weights_models")

# Opt-in wipe of the stub cache at import time: set CLEAN_CACHE_DIR=1.
if os.environ.get("CLEAN_CACHE_DIR", "0") == "1":
    if os.path.exists(NO_WEIGHTS_CACHE_DIR):
        shutil.rmtree(NO_WEIGHTS_CACHE_DIR)
def download_no_weights_model(model_id: str) -> None:
    """Materialize a local "no weights" snapshot for *model_id*.

    Fetches only the model config from the Hub and pairs it with a tiny
    dummy safetensors file, so code that expects a model directory can run
    without downloading the real weights. If the real model is already in
    the hub cache, a symlink to it is created instead.

    Args:
        model_id: Hugging Face Hub repo id (e.g. ``"org/name"``).

    Raises:
        ValueError: if the config cannot be loaded for ``model_id``.
    """
    # Create base no_weights directory.
    os.makedirs(NO_WEIGHTS_CACHE_DIR, exist_ok=True)

    model_folder_name = repo_folder_name(repo_id=model_id, repo_type="model")
    # Real hub downloads live under HF_HUB_CACHE (HF_HOME/hub by default),
    # not directly under HF_HOME — joining HF_HOME here would never match
    # an actually-downloaded model.
    model_folder_path = os.path.join(HF_HUB_CACHE, model_folder_name)
    no_weights_model_path = os.path.join(NO_WEIGHTS_CACHE_DIR, model_folder_name)

    # Stub (or symlink) already exists: nothing to do.
    if os.path.exists(no_weights_model_path):
        return

    # The real model is already downloaded: link to it instead of stubbing.
    if os.path.exists(model_folder_path):
        # target_is_directory keeps this working on Windows.
        os.symlink(model_folder_path, no_weights_model_path, target_is_directory=True)
        return

    try:
        pretrained_config = PretrainedConfig.from_pretrained(model_id)
    except Exception as e:
        # Chain the original exception so the root cause stays visible.
        raise ValueError(f"Failed to load config from {model_id}: {str(e)}") from e

    # Create and save dummy state dict, this is our dummy state to replace
    # the weights download.
    state_dict = torch.nn.Linear(1, 1).state_dict()

    # Create the model directory before saving files.
    os.makedirs(no_weights_model_path, exist_ok=True)

    # Save safetensors file.
    safetensors_path = os.path.join(no_weights_model_path, "model.safetensors")
    save_file(tensors=state_dict, filename=safetensors_path, metadata={"format": "pt"})

    # Save config.
    logger.info("Saving model config")
    pretrained_config.save_pretrained(save_directory=no_weights_model_path)