optimum/neuron/utils/misc.py (475 lines of code) (raw):

# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities of various sorts."""

import copy
import functools
import inspect
import os
import re
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union

import torch
from transformers import PretrainedConfig
from transformers.modeling_utils import _add_variant
from transformers.utils import (
    FLAX_WEIGHTS_NAME,
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
    TF2_WEIGHTS_NAME,
    TF_WEIGHTS_NAME,
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
    cached_file,
    download_url,
    has_file,
    is_remote_url,
)
from transformers.utils.hub import get_checkpoint_shard_files

from ...utils import is_diffusers_available, logging
from .import_utils import is_torch_neuronx_available, is_torch_xla_available
from .require_utils import requires_safetensors, requires_torch_xla


if is_torch_neuronx_available():
    from torch_neuronx import DataParallel

if TYPE_CHECKING:
    from transformers.modeling_utils import PreTrainedModel

    if is_diffusers_available():
        from diffusers import ModelMixin

logger = logging.get_logger()


def is_precompilation() -> bool:
    """Returns `True` when the `NEURON_EXTRACT_GRAPHS_ONLY` environment variable is set to `"1"`."""
    return os.environ.get("NEURON_EXTRACT_GRAPHS_ONLY") == "1"


def is_main_worker(global_main: bool = True) -> bool:
    """
    Tells whether the current process is the main worker.

    Args:
        global_main (`bool`, defaults to `True`):
            If `True`, the check is done against the global ordinal (`xr.global_ordinal()`), otherwise against the
            local ordinal (`xm.get_local_ordinal()`).

    Returns:
        `bool`: `True` if the selected ordinal is 0, or if `torch.distributed` is not initialized or `torch_xla` is
        not available.
    """
    if torch.distributed.is_initialized() and is_torch_xla_available():
        import torch_xla.core.xla_model as xm
        import torch_xla.runtime as xr

        return xr.global_ordinal() == 0 if global_main else xm.get_local_ordinal() == 0
    return True


# From https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
def string_to_bool(v: Union[str, bool]) -> bool:
    """
    Converts a truthy / falsy string to a boolean.

    Args:
        v (`Union[str, bool]`):
            The value to convert. Booleans are returned as is. Strings are compared case-insensitively.

    Returns:
        `bool`: The parsed boolean.

    Raises:
        `TypeError`: If `v` is a string that is not one of yes/no, true/false, t/f, y/n, 1/0.
    """
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    elif v.lower() in ("no", "false", "f", "n", "0"):
        return False
    else:
        raise TypeError(
            f"Truthy value expected: got {v} but expected one of yes/no, true/false, t/f, y/n, 1/0 (case insensitive)."
        )


def args_and_kwargs_to_kwargs_only(
    f: Callable,
    args: Optional[Tuple[Any, ...]] = None,
    kwargs: Optional[Dict[str, Any]] = None,
    include_default_values: bool = False,
) -> Dict[str, Any]:
    """
    Takes a function `f`, the `args` and `kwargs` provided to the function call, and returns the same arguments in the
    keyword arguments format.

    Args:
        f (`Callable`):
            The function that is being called.
        args (`Optional[Tuple[Any, ...]]`, defaults to `None`):
            The args given to `f`.
        kwargs (`Optional[Dict[str, Any]]`, defaults to `None`):
            The kwargs given to `f`.
        include_default_values (`bool`, defaults to `False`):
            Whether or not the returned keyword arguments should contain parameters that were not in `args` and
            `kwargs` but which have default values.

    Returns:
        `Dict[str, Any]`: The same arguments all formatted as keyword arguments.
    """
    if args is None:
        args = ()
    if kwargs is None:
        kwargs = {}
    sig = inspect.signature(f)
    param_names = list(sig.parameters)
    # Positional arguments are matched to parameter names in declaration order.
    result = dict(zip(param_names, args))
    result.update(kwargs)
    if include_default_values:
        for param in sig.parameters.values():
            if param.name in result:
                continue
            # `inspect.Parameter.empty` is a sentinel: compare by identity so that defaults overloading `__ne__`
            # (tensors, arrays, ...) do not break the check.
            if param.default is not inspect.Parameter.empty:
                result[param.name] = param.default
    return result


def _original_filename_to_safetensors_filename(filename: str) -> str:
    """
    Transforms the filename for any kind of checkpoint to a safetensors equivalent.

    The optional shard suffix (`-XXXXX-of-YYYYY`) of `filename` is preserved in the returned name.

    Raises:
        `ValueError`: If `filename` does not match the expected `name[-i-of-n].extension` pattern.
    """
    _, extension = filename.rsplit(".", maxsplit=1)
    # The extension comes from the filename, it must be escaped before being inserted in the pattern.
    pattern = rf"\w+(-[0-9]*-of-[0-9]*)?\.{re.escape(extension)}"
    match_ = re.match(pattern, filename)
    if not match_:
        raise ValueError(f"Could not convert {filename} to a safetensor filename.")
    group_1 = match_.group(1)
    index_out_of_total_str = group_1 if group_1 is not None else ""
    safetensor_filename, safetensor_extension = SAFE_WEIGHTS_NAME.rsplit(".", maxsplit=1)
    return f"{safetensor_filename}{index_out_of_total_str}.{safetensor_extension}"


@requires_safetensors
def convert_checkpoint_to_safetensors(
    weight_file: Union[str, Path],
    output_dir: Optional[Union[str, Path]] = None,
    safetensors_weight_filename_prefix: Optional[str] = None,
    log: bool = False,
) -> Path:
    """
    Converts a PyTorch model checkpoint to a `safetensors` model checkpoint.

    Args:
        weight_file (`Union[str, Path]`):
            The path to the PyTorch model checkpoint.
        output_dir (`Optional[Union[str, Path]]`, defaults to `None`):
            The output directory where the `safetensors` checkpoint will be saved.
            If left unspecified, the parent directory of the PyTorch checkpoint will be used.
        safetensors_weight_filename_prefix (`Optional[str]`, defaults to `None`):
            If specified, the name of the converted file will be prefixed by "safetensors_weight_filename_prefix-".
        log (`bool`, defaults to `False`):
            Whether or not the function should log which file it is converting.

    Returns:
        `Path`: The path to the `safetensors` checkpoint.

    Raises:
        `ValueError`: If `weight_file` does not have the `.bin` suffix.
    """
    from safetensors.torch import save_file

    if not isinstance(weight_file, Path):
        weight_file = Path(weight_file)

    if output_dir is None:
        output_dir = weight_file.parent

    if not isinstance(output_dir, Path):
        output_dir = Path(output_dir)

    if weight_file.suffix != ".bin":
        raise ValueError("Can only convert PyTorch checkpoints to safetensors.")

    safetensors_filename = _original_filename_to_safetensors_filename(weight_file.name)
    if safetensors_weight_filename_prefix is not None:
        safetensors_filename = f"{safetensors_weight_filename_prefix}-{safetensors_filename}"
    safetensors_path = output_dir / safetensors_filename

    already_exists = safetensors_path.is_file()
    is_distributed = torch.distributed.is_initialized()
    is_main_process = is_distributed and torch.distributed.get_rank() == 0

    # Only one worker will load the checkpoint (potentially huge) and perform the conversion.
    if not already_exists and (not is_distributed or is_main_process):
        if log:
            logger.info(f"Converting {weight_file} to safetensors")
        # NOTE: `torch.load` unpickles the file, `weight_file` must come from a trusted source.
        checkpoint = torch.load(weight_file, map_location=torch.device("cpu"))
        data_pointers = set()
        for k, v in checkpoint.items():
            # A tensor whose storage pointer was already seen is cloned so that each saved entry has its own data.
            if v.data_ptr() in data_pointers:
                v = v.detach().clone()
            v = v.contiguous()
            checkpoint[k] = v
            data_pointers.add(v.data_ptr())
        save_file(checkpoint, safetensors_path)
        del checkpoint

    return safetensors_path


# Wrapper around `cached_file`: the main worker calls `cached_file` first, all the workers then meet at a rendezvous,
# and only after that do the other workers call `cached_file` with the same arguments.
# `functools.wraps` copies the metadata (name, docstring) of `cached_file` onto the wrapper.
@requires_torch_xla
@functools.wraps(cached_file)
def distributed_friendly_cached_file(*args, **kwargs):
    import torch_xla.core.xla_model as xm

    if is_main_worker():
        output = cached_file(*args, **kwargs)
    xm.rendezvous("Cached file done")
    if not is_main_worker():
        output = cached_file(*args, **kwargs)
    return output


def download_checkpoints_in_cache(
    pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    local_files_only: bool = False,
    token: Optional[Union[str, bool]] = None,
    revision: str = "main",
    use_safetensors: Optional[bool] = None,
    use_safetensors_in_priority: Optional[bool] = None,
    convert_to_safetensors: bool = False,
    **kwargs,
):
    """
    Downloads checkpoint to the cache or returns the path to the already downloaded files.

    Note: This is a transformed version of `transformers.PreTrainedModel.from_pretrained` where only the part about
    downloading checkpoints has been kept. At the end of the function a custom part has been added handling the
    conversion to safetensors if needed.

    Args:
        pretrained_model_name_or_path (`Optional[Union[str, os.PathLike]]`):
            A local directory, a local file, a remote URL or a name passed as is to `cached_file`. If `None`, no file
            is resolved.
        cache_dir (`Optional[Union[str, os.PathLike]]`, defaults to `None`):
            Forwarded to `cached_file` and `get_checkpoint_shard_files`.
        force_download (`bool`, defaults to `False`):
            Forwarded to `cached_file` and `get_checkpoint_shard_files`.
        local_files_only (`bool`, defaults to `False`):
            Forwarded to `cached_file` and `get_checkpoint_shard_files`.
        token (`Optional[Union[str, bool]]`, defaults to `None`):
            Forwarded as `use_auth_token` to `cached_file`, `has_file` and `get_checkpoint_shard_files`.
        revision (`str`, defaults to `"main"`):
            Forwarded to `cached_file`, `has_file` and `get_checkpoint_shard_files`.
        use_safetensors (`Optional[bool]`, defaults to `None`):
            Unless `False`, safetensors files are looked for before PyTorch files. If `True` and no safetensors file
            can be found, an error is raised instead of falling back to PyTorch files.
        use_safetensors_in_priority (`Optional[bool]`, defaults to `None`):
            Checked after `use_safetensors`: unless `False`, safetensors files are looked for before PyTorch files.
        convert_to_safetensors (`bool`, defaults to `False`):
            Whether or not the resolved `.bin` files should be converted with `convert_checkpoint_to_safetensors`.
        kwargs:
            Remaining `from_pretrained`-style keyword arguments. `from_tf`, `from_flax`, `resume_download`, `proxies`,
            `_from_pipeline`, `_from_auto`, `subfolder`, `_commit_hash` and `variant` are used, the other known ones
            are popped and ignored.

    Returns:
        `Tuple`: `(resolved_archive_file, sharded_metadata)`. `resolved_archive_file` is a single path, a list of
        paths when the checkpoint is sharded, or `None` when `pretrained_model_name_or_path` is `None`.
        `sharded_metadata` is `None` unless the checkpoint is sharded.

    Raises:
        `EnvironmentError`: If no suitable checkpoint file can be found or loaded.
        `ValueError`: If a TensorFlow index file is found while `from_tf` is not set, or if a file that is neither a
        `.bin` nor a `.safetensors` file must be converted.
    """
    kwargs.pop("state_dict", None)
    from_tf = kwargs.pop("from_tf", False)
    from_flax = kwargs.pop("from_flax", False)
    resume_download = kwargs.pop("resume_download", None)
    proxies = kwargs.pop("proxies", None)
    kwargs.pop("output_loading_info", False)
    kwargs.pop("use_auth_token", None)
    kwargs.pop("trust_remote_code", None)
    _ = kwargs.pop("mirror", None)
    from_pipeline = kwargs.pop("_from_pipeline", None)
    from_auto_class = kwargs.pop("_from_auto", False)
    kwargs.pop("torch_dtype", None)
    kwargs.pop("low_cpu_mem_usage", None)
    kwargs.pop("device_map", None)
    kwargs.pop("max_memory", None)
    kwargs.pop("offload_folder", None)
    kwargs.pop("offload_state_dict", False)
    kwargs.pop("load_in_8bit", False)
    kwargs.pop("load_in_4bit", False)
    kwargs.pop("quantization_config", None)
    subfolder = kwargs.pop("subfolder", "")
    commit_hash = kwargs.pop("_commit_hash", None)
    variant = kwargs.pop("variant", None)

    # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
    # index of the files.
    is_sharded = False
    sharded_metadata = None

    # Load model
    user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
    if from_pipeline is not None:
        user_agent["using_pipeline"] = from_pipeline

    if pretrained_model_name_or_path is not None:
        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
        is_local = os.path.isdir(pretrained_model_name_or_path)
        if is_local:
            if from_tf and os.path.isfile(
                os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
            ):
                # Load from a TF 1.0 checkpoint in priority if from_tf
                archive_file = os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
            elif from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME)):
                # Load from a TF 2.0 checkpoint in priority if from_tf
                archive_file = os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME)
            elif from_flax and os.path.isfile(
                os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
            ):
                # Load from a Flax checkpoint in priority if from_flax
                archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
            elif use_safetensors is not False and os.path.isfile(
                os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant))
            ):
                # Load from a safetensors checkpoint
                archive_file = os.path.join(
                    pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant)
                )
            elif use_safetensors is not False and os.path.isfile(
                os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant))
            ):
                # Load from a sharded safetensors checkpoint
                archive_file = os.path.join(
                    pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
                )
                is_sharded = True
            elif use_safetensors_in_priority is not False and os.path.isfile(
                os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant))
            ):
                # Load from a safetensors checkpoint
                archive_file = os.path.join(
                    pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant)
                )
            elif use_safetensors_in_priority is not False and os.path.isfile(
                os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant))
            ):
                # Load from a sharded safetensors checkpoint
                archive_file = os.path.join(
                    pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
                )
                is_sharded = True
            elif os.path.isfile(
                os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant))
            ):
                # Load from a PyTorch checkpoint
                archive_file = os.path.join(
                    pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant)
                )
            elif os.path.isfile(
                os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant))
            ):
                # Load from a sharded PyTorch checkpoint
                archive_file = os.path.join(
                    pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)
                )
                is_sharded = True
            # At this stage we don't have a weight file so we will raise an error.
            elif os.path.isfile(
                os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
            ) or os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME)):
                raise EnvironmentError(
                    f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory"
                    f" {pretrained_model_name_or_path} but there is a file for TensorFlow weights. Use"
                    " `from_tf=True` to load this model from those weights."
                )
            elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)):
                raise EnvironmentError(
                    f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory"
                    f" {pretrained_model_name_or_path} but there is a file for Flax weights. Use `from_flax=True`"
                    " to load this model from those weights."
                )
            elif use_safetensors:
                raise EnvironmentError(
                    f"Error no file named {_add_variant(SAFE_WEIGHTS_NAME, variant)} found in directory"
                    f" {pretrained_model_name_or_path}."
                )
            else:
                raise EnvironmentError(
                    f"Error no file named {_add_variant(WEIGHTS_NAME, variant)}, {TF2_WEIGHTS_NAME},"
                    f" {TF_WEIGHTS_NAME + '.index'} or {FLAX_WEIGHTS_NAME} found in directory"
                    f" {pretrained_model_name_or_path}."
                )
        elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
            archive_file = pretrained_model_name_or_path
            is_local = True
        elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path + ".index")):
            if not from_tf:
                raise ValueError(
                    f"We found a TensorFlow checkpoint at {pretrained_model_name_or_path + '.index'}, please set "
                    "from_tf to True to load from this checkpoint."
                )
            archive_file = os.path.join(subfolder, pretrained_model_name_or_path + ".index")
            is_local = True
        elif is_remote_url(pretrained_model_name_or_path):
            filename = pretrained_model_name_or_path
            resolved_archive_file = download_url(pretrained_model_name_or_path)
        else:
            # set correct filename
            if from_tf:
                filename = TF2_WEIGHTS_NAME
            elif from_flax:
                filename = FLAX_WEIGHTS_NAME
            elif use_safetensors is not False:
                filename = _add_variant(SAFE_WEIGHTS_NAME, variant)
            elif use_safetensors_in_priority is not False:
                filename = _add_variant(SAFE_WEIGHTS_NAME, variant)
            else:
                filename = _add_variant(WEIGHTS_NAME, variant)

            try:
                # Load from URL or cache if already cached
                cached_file_kwargs = {
                    "cache_dir": cache_dir,
                    "force_download": force_download,
                    "proxies": proxies,
                    "resume_download": resume_download,
                    "local_files_only": local_files_only,
                    "use_auth_token": token,
                    "user_agent": user_agent,
                    "revision": revision,
                    "subfolder": subfolder,
                    "_raise_exceptions_for_missing_entries": False,
                    "_commit_hash": commit_hash,
                }
                resolved_archive_file = distributed_friendly_cached_file(
                    pretrained_model_name_or_path, filename, **cached_file_kwargs
                )

                # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
                # result when internet is up, the repo and revision exist, but the file does not.
                if resolved_archive_file is None and filename == _add_variant(SAFE_WEIGHTS_NAME, variant):
                    # Maybe the checkpoint is sharded, we try to grab the index name in this case.
                    resolved_archive_file = distributed_friendly_cached_file(
                        pretrained_model_name_or_path,
                        _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant),
                        **cached_file_kwargs,
                    )
                    if resolved_archive_file is not None:
                        is_sharded = True
                    elif use_safetensors:
                        raise EnvironmentError(
                            f"{pretrained_model_name_or_path} does not appear to have a file named"
                            f" {_add_variant(SAFE_WEIGHTS_NAME, variant)} or "
                            f"{_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)} and thus cannot be loaded with "
                            "`safetensors`. Please make sure that the model has been saved with "
                            "`safe_serialization=True` or do not set `use_safetensors=True`."
                        )
                    else:
                        # This repo has no safetensors file of any kind, we switch to PyTorch.
                        filename = _add_variant(WEIGHTS_NAME, variant)
                        resolved_archive_file = distributed_friendly_cached_file(
                            pretrained_model_name_or_path, filename, **cached_file_kwargs
                        )
                if resolved_archive_file is None and filename == _add_variant(WEIGHTS_NAME, variant):
                    # Maybe the checkpoint is sharded, we try to grab the index name in this case.
                    resolved_archive_file = distributed_friendly_cached_file(
                        pretrained_model_name_or_path,
                        _add_variant(WEIGHTS_INDEX_NAME, variant),
                        **cached_file_kwargs,
                    )
                    if resolved_archive_file is not None:
                        is_sharded = True
                if resolved_archive_file is None:
                    # Otherwise, maybe there is a TF or Flax model file.  We try those to give a helpful error
                    # message.
                    has_file_kwargs = {
                        "revision": revision,
                        "proxies": proxies,
                        "use_auth_token": token,
                    }
                    if has_file(pretrained_model_name_or_path, TF2_WEIGHTS_NAME, **has_file_kwargs):
                        raise EnvironmentError(
                            f"{pretrained_model_name_or_path} does not appear to have a file named"
                            f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file for TensorFlow weights."
                            " Use `from_tf=True` to load this model from those weights."
                        )
                    elif has_file(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME, **has_file_kwargs):
                        raise EnvironmentError(
                            f"{pretrained_model_name_or_path} does not appear to have a file named"
                            f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file for Flax weights. Use"
                            " `from_flax=True` to load this model from those weights."
                        )
                    elif variant is not None and has_file(
                        pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs
                    ):
                        raise EnvironmentError(
                            f"{pretrained_model_name_or_path} does not appear to have a file named"
                            f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file without the variant"
                            f" {variant}. Use `variant=None` to load this model from those weights."
                        )
                    else:
                        raise EnvironmentError(
                            f"{pretrained_model_name_or_path} does not appear to have a file named"
                            f" {_add_variant(WEIGHTS_NAME, variant)}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or"
                            f" {FLAX_WEIGHTS_NAME}."
                        )
            except EnvironmentError:
                # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted
                # to the original exception.
                raise
            except Exception:
                # For any other exception, we throw a generic error.
                raise EnvironmentError(
                    f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
                    " from 'https://huggingface.co/models', make sure you don't have a local directory with the"
                    f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
                    f" directory containing a file named {_add_variant(WEIGHTS_NAME, variant)},"
                    f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}."
                )

        if is_local:
            resolved_archive_file = archive_file
    else:
        resolved_archive_file = None

    # We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
    if is_sharded:
        # resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
        resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
            pretrained_model_name_or_path,
            resolved_archive_file,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            local_files_only=local_files_only,
            use_auth_token=token,
            user_agent=user_agent,
            revision=revision,
            subfolder=subfolder,
            _commit_hash=commit_hash,
        )

    # TODO: this whole bulk is not very optimized, improve it once the tests are written.
    if convert_to_safetensors:
        maybe_to_convert = resolved_archive_file
        if not isinstance(maybe_to_convert, list):
            maybe_to_convert = [maybe_to_convert]

        # Maps the name of each resolved file to the path of its safetensors counterpart.
        filenames_to_safetensors_filenames = {}
        for filename in maybe_to_convert:
            filename = Path(filename)
            if filename.suffix == ".safetensors":
                filenames_to_safetensors_filenames[filename.name] = filename
            elif filename.suffix == ".bin":
                output_path = convert_checkpoint_to_safetensors(
                    filename, safetensors_weight_filename_prefix="converted", log=True
                )
                filenames_to_safetensors_filenames[filename.name] = output_path
            else:
                raise ValueError("Only PyTorch and safetensors files are supported.")

        if sharded_metadata is not None:
            # The weight map must point to the safetensors files instead of the original ones.
            weight_map = sharded_metadata["weight_map"]
            for weight_name, torch_filename in weight_map.items():
                weight_map[weight_name] = filenames_to_safetensors_filenames[torch_filename]

        if isinstance(resolved_archive_file, list):
            resolved_archive_file = [
                filenames_to_safetensors_filenames[Path(filename).name] for filename in resolved_archive_file
            ]
        else:
            resolved_archive_file = filenames_to_safetensors_filenames[Path(resolved_archive_file).name]

    return resolved_archive_file, sharded_metadata


def replace_weights(
    model: Union[torch.jit._script.RecursiveScriptModule, "DataParallel"],
    weights: Union[Dict[str, torch.Tensor], torch.nn.Module],
    prefix: str = "model",
):
    """
    Replaces the weights in a Neuron Model with weights from another model, the original neuron model should have
    separated weights (by setting `inline_weights_to_neff=False` during the tracing).

    Args:
        model (`Union[torch.jit._script.RecursiveScriptModule, DataParallel]`):
            The Neuron model whose weights are replaced in place. When it is a `DataParallel` instance, the weights of
            `model.module` are the ones replaced.
        weights (`Union[Dict[str, torch.Tensor], torch.nn.Module]`):
            The new weights. If a `torch.nn.Module` is given, its `state_dict()` is used.
        prefix (`str`, defaults to `"model"`):
            The leading `"{prefix}->"` part that is stripped from each module path before looking it up in `weights`.
    """
    if isinstance(weights, torch.nn.Module):
        weights = weights.state_dict()

    # extract module paths from the weights c module
    if is_torch_neuronx_available() and isinstance(model, DataParallel):
        model_weights = model.module.weights
    else:
        model_weights = model.weights
    code = model_weights._c.code
    start_str = "__parameters__ = ["
    end_str = "]\n"
    # The paths are read from the text located between `start_str` and `end_str` in the module code.
    module_paths = code.split(start_str)[1].split(end_str)[0].strip()[:-1:].replace('"', "").split(", ")
    module_paths = [module_path for module_path in module_paths if module_path != ""]

    for module_path in module_paths:
        # Paths containing a word character followed by a digit are skipped.
        if re.search(r"\w\d+", module_path) is not None:
            continue
        # "prefix->a->b" in the module code corresponds to the "a.b" key in `weights`.
        model_weights._c.setattr(module_path, weights[module_path.replace(prefix + "->", "", 1).replace("->", ".")])


def check_if_weights_replacable(
    config: Union["PretrainedConfig", Dict[str, "PretrainedConfig"]],
    weights: Optional[Union[Dict[str, torch.Tensor], torch.nn.Module]],
):
    """
    Checks that the weights of a Neuron model can be replaced, which requires the weights to be separated from the neff.

    Args:
        config (`Union[PretrainedConfig, Dict[str, PretrainedConfig]]`):
            The config of the model, or a dictionary of configs. In the latter case every config must have separated
            weights.
        weights (`Optional[Union[Dict[str, torch.Tensor], torch.nn.Module]]`):
            The weights to inject. When `None`, nothing is checked.

    Raises:
        `RuntimeError`: If `weights` is specified while the weights and the neff are not (or cannot be determined to
        be) separated.
    """

    def _is_weights_neff_separated(config):
        return not config.neuron.get("inline_weights_to_neff", True) if hasattr(config, "neuron") else False

    if isinstance(config, PretrainedConfig):
        is_weights_neff_separated = _is_weights_neff_separated(config)
    elif isinstance(config, dict):
        is_weights_neff_separated = all(
            [_is_weights_neff_separated(config_value) for config_value in config.values()]
        )
    else:
        # Unsupported config type: the separation cannot be established, so replacing weights is refused below
        # instead of failing with an `UnboundLocalError`.
        is_weights_neff_separated = False

    if weights is not None and not is_weights_neff_separated:
        raise RuntimeError(
            "Unable to replace weights of the neuron model since its weights and neff are not separated, please set `inline_weights_to_neff=False` when converting the model to Neuron format."
        )


class DiffusersPretrainedConfig(PretrainedConfig):
    """override to update `model_type`."""

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary.

        Returns:
            :obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        # Plain deep copy of the instance attributes, nothing is added to or removed from the output.
        output = copy.deepcopy(self.__dict__)
        return output


def get_stable_diffusion_configs(
    models_for_export: Dict[str, Union["PreTrainedModel", "ModelMixin"]],
):
    """
    Collects the configs of the stable diffusion sub-models present in `models_for_export`.

    Args:
        models_for_export (`Dict[str, Union[PreTrainedModel, ModelMixin]]`):
            The sub-models, indexed by name.

    Returns:
        `Dict[str, Any]`: The `config` attribute of each of the "text_encoder", "text_encoder_2", "unet" and "vae"
        entries found in `models_for_export`, indexed by the same name.
    """
    subfolders = ["text_encoder", "text_encoder_2", "unet", "vae"]
    configs = {}
    for name in subfolders:
        if name in models_for_export:
            configs[name] = models_for_export[name].config
    return configs


def map_torch_dtype(dtype: Union[str, torch.dtype]):
    """
    Maps a dtype name to the matching `torch.dtype`.

    Args:
        dtype (`Union[str, torch.dtype]`):
            The dtype to map, e.g. `"bfloat16"`, `"bf16"`, `"float32"` or `"int64"`.

    Returns:
        The matching `torch.dtype` if `dtype` is a known name, otherwise `dtype` unchanged.
    """
    dtype_mapping = {
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "float32": torch.float32,
        "float64": torch.float64,
        "int32": torch.int32,
        "int64": torch.int64,
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
        "fp32": torch.float32,
    }

    if isinstance(dtype, str) and dtype in dtype_mapping:
        dtype = dtype_mapping.get(dtype)

    return dtype