# optimum/habana/checkpoint_utils.py
import json
import os
from pathlib import Path

import torch
from huggingface_hub import list_repo_files, snapshot_download
from transformers.utils import is_offline_mode

def get_repo_root(model_name_or_path, local_rank=-1, token=None):
    """
    Downloads the specified model checkpoint and returns the local directory where it was stored.
    """
if Path(model_name_or_path).is_dir():
# If it is a local model, no need to download anything
return model_name_or_path
else:
# Checks if online or not
if is_offline_mode():
if local_rank == 0:
print("Offline mode: forcing local_files_only=True")
# Only download PyTorch weights by default
        if any(
            ".safetensors" in filename for filename in list_repo_files(model_name_or_path, token=token)
        ):  # Some models like Falcon-180b are only available in safetensors format
            allow_patterns = ["*.safetensors"]
        elif any(".bin" in filename for filename in list_repo_files(model_name_or_path, token=token)):
            allow_patterns = ["*.bin"]
        else:
            raise TypeError("Only PyTorch models are supported")
# Download only on first process
if local_rank in [-1, 0]:
cache_dir = snapshot_download(
model_name_or_path,
local_files_only=is_offline_mode(),
cache_dir=os.getenv("TRANSFORMERS_CACHE", None),
allow_patterns=allow_patterns,
max_workers=16,
token=token,
)
if local_rank == -1:
# If there is only one process, then the method is finished
return cache_dir
# Make all processes wait so that other processes can get the checkpoint directly from cache
if torch.distributed.is_initialized():
torch.distributed.barrier()
return snapshot_download(
model_name_or_path,
local_files_only=is_offline_mode(),
cache_dir=os.getenv("TRANSFORMERS_CACHE", None),
allow_patterns=allow_patterns,
token=token,
)
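
# Usage sketch for get_repo_root (illustrative, not part of the module's API): resolve the
# local snapshot directory for a Hub repo. The repo id below and the LOCAL_RANK environment
# variable are assumptions; any launcher that sets LOCAL_RANK (e.g. DeepSpeed or torchrun)
# would work the same way.
#
#     local_rank = int(os.getenv("LOCAL_RANK", "-1"))
#     repo_root = get_repo_root("bigscience/bloom-7b1", local_rank=local_rank)
#     print(repo_root)  # a directory under the Hugging Face cache holding the weight files
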

def get_checkpoint_files(model_name_or_path, local_rank, token=None):
    """
    Returns the list of local paths to the checkpoint weight files (.bin, .safetensors or .pt).
    """
    cached_repo_dir = get_repo_root(model_name_or_path, local_rank=local_rank, token=token)

    # Extensions: .bin | .safetensors | .pt
    # Creates a list of paths from all downloaded files in cache dir
    from transformers import modeling_utils

if any(file.suffix == ".bin" for file in Path(cached_repo_dir).rglob("*")):
(name, ext) = os.path.splitext(modeling_utils.WEIGHTS_NAME)
elif any(file.suffix == ".safetensors" for file in Path(cached_repo_dir).rglob("*")):
(name, ext) = os.path.splitext(modeling_utils.SAFE_WEIGHTS_NAME)
else:
(name, ext) = ("*", ".pt")
file_list = [
str(entry)
for entry in Path(cached_repo_dir).rglob("*")
if (entry.is_file() and entry.name.startswith(name) and entry.name.endswith(ext))
]
return file_list
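
# Usage sketch for get_checkpoint_files (illustrative): list the weight shards of a
# downloaded checkpoint. The repo id is an assumption; the returned paths end in .bin,
# .safetensors or .pt depending on what the repository ships.
#
#     files = get_checkpoint_files("bigscience/bloom-7b1", local_rank=-1)
#     # e.g. ['.../pytorch_model-00001-of-00002.bin', '.../pytorch_model-00002-of-00002.bin']
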
def write_checkpoints_json(model_name_or_path, local_rank, f, token=None):
"""
Dumps metadata into a JSON file for DeepSpeed-inference.
"""
checkpoint_files = get_checkpoint_files(model_name_or_path, local_rank, token)
data = {"type": "ds_model", "checkpoints": checkpoint_files, "version": 1.0}
json.dump(data, f)
f.flush()
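
# Usage sketch for write_checkpoints_json (illustrative): DeepSpeed-inference can load
# sharded weights from a JSON manifest instead of a fully materialized state dict.
# The temporary-file handling below is an assumption, not something this module prescribes.
#
#     import tempfile
#
#     checkpoints_json = tempfile.NamedTemporaryFile(mode="w+", suffix=".json")
#     write_checkpoints_json("bigscience/bloom-7b1", local_rank, checkpoints_json)
#     # checkpoints_json.name can then be passed as the `checkpoint` argument of
#     # deepspeed.init_inference(...)
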

def model_on_meta(config):
    """
    Checks if the model should be loaded to the meta device.
    """
    return config.model_type in ["bloom", "llama", "falcon", "mixtral", "qwen2"]
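
# Usage sketch for model_on_meta (illustrative): for the model types above, the model
# skeleton can be created on the "meta" device and the real weights loaded later from the
# checkpoints JSON, so each process avoids materializing the full model in host memory.
# `deepspeed.OnDevice` and the bfloat16 dtype are assumptions about the caller.
#
#     if model_on_meta(config):
#         with deepspeed.OnDevice(dtype=torch.bfloat16, device="meta"):
#             model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16)
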

def get_optimized_model_name(config):
    """
    Returns config.model_type if it is one of the model types optimized with static shapes, else None.
    """
    from .transformers.generation import MODELS_OPTIMIZED_WITH_STATIC_SHAPES

    for model_type in MODELS_OPTIMIZED_WITH_STATIC_SHAPES:
        if model_type == config.model_type:
            return model_type

    return None

def model_is_optimized(config):
    """
    Checks if the given config corresponds to a model in optimum/habana/transformers/models, i.e. one
    that accepts the additional `token_idx` input.
    """
    return get_optimized_model_name(config) is not None

def get_ds_injection_policy(config):
    """
    Returns the DeepSpeed injection policy for the given config: a mapping from the model's
    decoder-block class to the names of the linear layers whose outputs must be all-reduced.
    """
    model_type = get_optimized_model_name(config)
    policy = {}
if model_type:
if model_type == "bloom":
from transformers.models.bloom.modeling_bloom import BloomBlock
policy = {BloomBlock: ("self_attention.dense", "mlp.dense_4h_to_h")}
if model_type == "opt":
from transformers.models.opt.modeling_opt import OPTDecoderLayer
policy = {OPTDecoderLayer: ("self_attn.out_proj", ".fc2")}
if model_type == "gpt2":
from transformers.models.gpt2.modeling_gpt2 import GPT2MLP
policy = {GPT2MLP: ("attn.c_proj", "mlp.c_proj")}
if model_type == "gptj":
from transformers.models.gptj.modeling_gptj import GPTJBlock
policy = {GPTJBlock: ("attn.out_proj", "mlp.fc_out")}
if model_type == "gpt_neox":
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXLayer
policy = {GPTNeoXLayer: ("attention.dense", "mlp.dense_4h_to_h")}
if model_type == "llama":
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
policy = {LlamaDecoderLayer: ("self_attn.o_proj", "mlp.down_proj")}
if model_type == "mistral":
from transformers.models.mistral.modeling_mistral import MistralDecoderLayer
policy = {MistralDecoderLayer: ("self_attn.o_proj", "mlp.down_proj")}
return policy
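
# Usage sketch for get_ds_injection_policy (illustrative): the policy tells DeepSpeed which
# projections produce each block's output, i.e. where to insert all-reduce ops when splitting
# the block across devices. Argument values below (world_size, dtype) and the
# `checkpoints_json` handle from the write_checkpoints_json sketch are assumptions about the
# caller, not defaults of this module.
#
#     model = deepspeed.init_inference(
#         model,
#         mp_size=world_size,
#         dtype=torch.bfloat16,
#         injection_policy=get_ds_injection_policy(model.config),
#         checkpoint=checkpoints_json.name,
#     )
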