optimum/habana/checkpoint_utils.py:

import json
import os
from pathlib import Path

import torch
from huggingface_hub import list_repo_files, snapshot_download
from transformers.utils import is_offline_mode


def get_repo_root(model_name_or_path, local_rank=-1, token=None):
    """
    Downloads the specified model checkpoint and returns the repository where it was downloaded.
    """
    if Path(model_name_or_path).is_dir():
        # If it is a local model, no need to download anything
        return model_name_or_path
    else:
        # Checks if online or not
        if is_offline_mode():
            if local_rank == 0:
                print("Offline mode: forcing local_files_only=True")

        # Only download PyTorch weights by default
        if any(
            ".safetensors" in filename for filename in list_repo_files(model_name_or_path, token=token)
        ):
            # Some models like Falcon-180b are only available in safetensors format
            allow_patterns = ["*.safetensors"]
        elif any(".bin" in filename for filename in list_repo_files(model_name_or_path, token=token)):
            allow_patterns = ["*.bin"]
        else:
            raise TypeError("Only PyTorch models are supported")

        # Download only on the first process
        if local_rank in [-1, 0]:
            cache_dir = snapshot_download(
                model_name_or_path,
                local_files_only=is_offline_mode(),
                cache_dir=os.getenv("TRANSFORMERS_CACHE", None),
                allow_patterns=allow_patterns,
                max_workers=16,
                token=token,
            )
            if local_rank == -1:
                # If there is only one process, then the method is finished
                return cache_dir

        # Make all processes wait so that other processes can get the checkpoint directly from the cache
        if torch.distributed.is_initialized():
            torch.distributed.barrier()

        return snapshot_download(
            model_name_or_path,
            local_files_only=is_offline_mode(),
            cache_dir=os.getenv("TRANSFORMERS_CACHE", None),
            allow_patterns=allow_patterns,
            token=token,
        )


def get_checkpoint_files(model_name_or_path, local_rank, token=None):
    """
    Returns the list of checkpoint file paths (.bin, .safetensors or .pt) in the local cache of the model.
    """
    cached_repo_dir = get_repo_root(model_name_or_path, local_rank=local_rank, token=token)

    # Extensions: .bin | .safetensors | .pt
    # Creates a list of paths from all downloaded files in the cache dir
    from transformers import modeling_utils

    if any(file.suffix == ".bin" for file in Path(cached_repo_dir).rglob("*")):
        (name, ext) = os.path.splitext(modeling_utils.WEIGHTS_NAME)
    elif any(file.suffix == ".safetensors" for file in Path(cached_repo_dir).rglob("*")):
        (name, ext) = os.path.splitext(modeling_utils.SAFE_WEIGHTS_NAME)
    else:
        (name, ext) = ("*", ".pt")

    file_list = [
        str(entry)
        for entry in Path(cached_repo_dir).rglob("*")
        if (entry.is_file() and entry.name.startswith(name) and entry.name.endswith(ext))
    ]
    return file_list


def write_checkpoints_json(model_name_or_path, local_rank, f, token=None):
    """
    Dumps checkpoint metadata into a JSON file for DeepSpeed-inference.
    """
    checkpoint_files = get_checkpoint_files(model_name_or_path, local_rank, token)
    data = {"type": "ds_model", "checkpoints": checkpoint_files, "version": 1.0}
    json.dump(data, f)
    f.flush()


def model_on_meta(config):
    """
    Checks if the model should be loaded on the meta device.
    """
    return config.model_type in ["bloom", "llama", "falcon", "mixtral", "qwen2"]


def get_optimized_model_name(config):
    """
    Returns the model type if it is in the list of models optimized with static shapes, None otherwise.
    """
    from .transformers.generation import MODELS_OPTIMIZED_WITH_STATIC_SHAPES

    for model_type in MODELS_OPTIMIZED_WITH_STATIC_SHAPES:
        if model_type == config.model_type:
            return model_type

    return None


def model_is_optimized(config):
    """
    Checks if the given config belongs to a model in optimum/habana/transformers/models, which has a
    new input token_idx.
    """
    return get_optimized_model_name(config) is not None


def get_ds_injection_policy(config):
    """
    Returns the DeepSpeed injection policy for the given model type, mapping the decoder block class
    to the names of its output linear layers.
    """
    model_type = get_optimized_model_name(config)
    policy = {}
    if model_type:
        if model_type == "bloom":
            from transformers.models.bloom.modeling_bloom import BloomBlock

            policy = {BloomBlock: ("self_attention.dense", "mlp.dense_4h_to_h")}

        if model_type == "opt":
            from transformers.models.opt.modeling_opt import OPTDecoderLayer

            policy = {OPTDecoderLayer: ("self_attn.out_proj", ".fc2")}

        if model_type == "gpt2":
            from transformers.models.gpt2.modeling_gpt2 import GPT2MLP

            policy = {GPT2MLP: ("attn.c_proj", "mlp.c_proj")}

        if model_type == "gptj":
            from transformers.models.gptj.modeling_gptj import GPTJBlock

            policy = {GPTJBlock: ("attn.out_proj", "mlp.fc_out")}

        if model_type == "gpt_neox":
            from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXLayer

            policy = {GPTNeoXLayer: ("attention.dense", "mlp.dense_4h_to_h")}

        if model_type == "llama":
            from transformers.models.llama.modeling_llama import LlamaDecoderLayer

            policy = {LlamaDecoderLayer: ("self_attn.o_proj", "mlp.down_proj")}

        if model_type == "mistral":
            from transformers.models.mistral.modeling_mistral import MistralDecoderLayer

            policy = {MistralDecoderLayer: ("self_attn.o_proj", "mlp.down_proj")}

    return policy