optimum/tpu/model.py (43 lines of code) (raw):

import os
import time
from typing import Optional

from huggingface_hub import snapshot_download
from loguru import logger
from transformers import AutoConfig
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME


def get_export_kwargs_from_env():
    """Collect text-generation export parameters from the environment."""
    batch_size = os.environ.get("MAX_BATCH_SIZE", None)
    if batch_size is not None:
        batch_size = int(batch_size)
    sequence_length = os.environ.get("HF_SEQUENCE_LENGTH", None)
    if sequence_length is not None:
        sequence_length = int(sequence_length)
    return {
        "task": "text-generation",
        "batch_size": batch_size,
        "sequence_length": sequence_length,
    }


def fetch_model(
    model_id: str,
    revision: Optional[str] = None,
) -> str:
    """Fetch a model to the local cache.

    Args:
        model_id (`str`):
            The *model_id* of a model on the Hugging Face Hub, or the path to a local model.
        revision (`Optional[str]`, defaults to `None`):
            The revision of the model on the Hugging Face Hub.

    Returns:
        Model ID or path of the model available in the cache.
    """
    if os.path.isdir(model_id):
        if revision is not None:
            logger.warning(f"Revision {revision} ignored for local model at {model_id}")
        return model_id
    # Download the model from the Hub (HUGGING_FACE_HUB_TOKEN must be set for a
    # private or gated model). Note that the model may already be present in the cache.
    start = time.time()
    local_path = snapshot_download(
        repo_id=model_id,
        revision=revision,
        allow_patterns=["*.json", "model*.safetensors", SAFE_WEIGHTS_INDEX_NAME, "tokenizer.*"],
    )
    end = time.time()
    logger.info(f"Model successfully fetched in {end - start:.2f} s.")
    # Update the cached config with export-specific settings such as batch_size
    # and sequence_length, so downstream consumers can pick them up.
    start = time.time()
    export_kwargs = get_export_kwargs_from_env()
    config = AutoConfig.from_pretrained(local_path)
    config.update(export_kwargs)
    config.save_pretrained(local_path)
    end = time.time()
    logger.info(f"Model config updated in {end - start:.2f} s.")
    return local_path
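
A minimal usage sketch follows, not part of the original module. It assumes this file is importable as optimum.tpu.model and that "gpt2" is a valid, ungated model id on the Hugging Face Hub. It exercises the two moving parts above: the MAX_BATCH_SIZE and HF_SEQUENCE_LENGTH environment variables read by get_export_kwargs_from_env, and the export kwargs that fetch_model persists into the cached config.

import os

from transformers import AutoConfig

from optimum.tpu.model import fetch_model

# Export parameters are passed implicitly through the environment.
os.environ["MAX_BATCH_SIZE"] = "4"
os.environ["HF_SEQUENCE_LENGTH"] = "1024"

# Downloads config, safetensors weights and tokenizer files to the local
# Hugging Face cache, then writes the export kwargs into the cached config.
local_path = fetch_model("gpt2")

# The persisted config now carries the extra keys written by fetch_model.
config = AutoConfig.from_pretrained(local_path)
print(config.batch_size, config.sequence_length)  # 4 1024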