# coding=utf-8
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Neuron compiled model check and export functions."""
import copy
import time
from collections import OrderedDict
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from transformers import PreTrainedModel
from ...exporters.error_utils import OutputMatchError, ShapeError
from ...neuron.cache.entries.multi_model import MultiModelCacheEntry
from ...neuron.cache.entries.single_model import SingleModelCacheEntry
from ...neuron.cache.traced import cache_traced_neuron_artifacts
from ...neuron.utils import (
DiffusersPretrainedConfig,
convert_neuronx_compiler_args_to_neuron,
is_neuron_available,
is_neuronx_available,
is_neuronx_distributed_available,
store_compilation_config,
)
from ...neuron.utils.cache_utils import get_model_name_or_path
from ...neuron.utils.version_utils import get_neuroncc_version, get_neuronxcc_version
from ...utils import (
is_diffusers_available,
is_sentence_transformers_available,
logging,
)
if TYPE_CHECKING:
from .base import NeuronDefaultConfig
if is_neuron_available():
import torch.neuron as neuron # noqa: F811
NEURON_COMPILER_TYPE = "neuron-cc"
NEURON_COMPILER_VERSION = get_neuroncc_version()
if is_neuronx_available():
import torch_neuronx as neuronx # noqa: F811
NEURON_COMPILER_TYPE = "neuronx-cc"
NEURON_COMPILER_VERSION = get_neuronxcc_version()
if is_diffusers_available():
from diffusers import ModelMixin
from diffusers.configuration_utils import FrozenDict
if is_sentence_transformers_available():
from sentence_transformers import SentenceTransformer
if is_neuronx_distributed_available():
import neuronx_distributed
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def validate_models_outputs(
models_and_neuron_configs: Dict[
str, Tuple[Union["PreTrainedModel", "ModelMixin", torch.nn.Module], "NeuronDefaultConfig"]
],
neuron_named_outputs: Dict[str, List[str]],
output_dir: Path,
atol: Optional[float] = None,
neuron_files_subpaths: Optional[Dict[str, str]] = None,
):
"""
    Validates the export of several models by checking that the outputs of the reference models and the exported Neuron models match.
    This method validates the Neuron models exported with the `export_models` method.
Args:
        models_and_neuron_configs (`Dict[str, Tuple[Union["PreTrainedModel", "ModelMixin", torch.nn.Module], "NeuronDefaultConfig"]]`):
            A dictionary containing the models to export and their corresponding neuron configs.
        neuron_named_outputs (`Dict[str, List[str]]`):
The names of the outputs to check.
output_dir (`Path`):
Output directory where the exported Neuron models are stored.
atol (`Optional[float]`, defaults to `None`):
The absolute tolerance in terms of outputs difference between the reference and the exported model.
        neuron_files_subpaths (`Optional[Dict[str, str]]`, defaults to `None`):
            The relative paths from `output_dir` to the Neuron files to validate, keyed by the submodel names used in
            `models_and_neuron_configs`. If `None`, the keys of `models_and_neuron_configs` are used as file names.
Raises:
ValueError: If the outputs shapes or values do not match between the reference and the exported model.
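    Example:
        A minimal usage sketch, assuming `models_and_neuron_configs` and `neuron_outputs` were returned by a
        previous call to `export_models` (both names are illustrative):

        ```python
        from pathlib import Path

        validate_models_outputs(
            models_and_neuron_configs=models_and_neuron_configs,
            neuron_named_outputs=neuron_outputs,
            output_dir=Path("compiled_models"),
            atol=1e-3,
        )
        ```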
"""
    if len(neuron_named_outputs) != len(models_and_neuron_configs):
        raise ValueError(
            f"Invalid number of Neuron named outputs. Required outputs for {list(models_and_neuron_configs.keys())}, provided outputs for {list(neuron_named_outputs.keys())}."
        )
    if neuron_files_subpaths is not None and len(neuron_files_subpaths) != len(models_and_neuron_configs):
        raise ValueError(
            f"Provided custom file names {neuron_files_subpaths} for the validation of {len(models_and_neuron_configs)} models. Please provide the same number of Neuron file names as models to export."
        )
exceptions = [] # run all validations before raising
neuron_paths = []
for i, model_name in enumerate(models_and_neuron_configs.keys()):
submodel, sub_neuron_config = models_and_neuron_configs[model_name]
ref_submodel = copy.deepcopy(submodel)
neuron_model_path = (
output_dir.joinpath(neuron_files_subpaths[model_name])
if neuron_files_subpaths is not None
else output_dir.joinpath(model_name + ".neuron")
)
neuron_paths.append(neuron_model_path)
try:
logger.info(f"Validating {model_name} model...")
validate_model_outputs(
config=sub_neuron_config,
reference_model=ref_submodel,
neuron_model_path=neuron_model_path,
neuron_named_outputs=neuron_named_outputs[model_name],
atol=atol,
)
except Exception as e:
exceptions.append(f"Validation of {model_name} fails: {e}")
if len(exceptions) != 0:
for i, exception in enumerate(exceptions[:-1]):
logger.error(f"Validation {i} for the model {neuron_paths[i].as_posix()} raised: {exception}")
raise Exception(exceptions[-1])
def validate_model_outputs(
config: "NeuronDefaultConfig",
reference_model: Union["PreTrainedModel", "SentenceTransformer", "ModelMixin"],
neuron_model_path: Path,
neuron_named_outputs: List[str],
atol: Optional[float] = None,
):
"""
Validates the export by checking that the outputs from both the reference and the exported model match.
Args:
        config ([`~optimum.neuron.exporter.NeuronDefaultConfig`]):
The configuration used to export the model.
reference_model ([`Union["PreTrainedModel", "SentenceTransformer", "ModelMixin"]`]):
The model used for the export.
neuron_model_path (`Path`):
The path to the exported model.
neuron_named_outputs (`List[str]`):
The names of the outputs to check.
atol (`Optional[float]`, defaults to `None`):
The absolute tolerance in terms of outputs difference between the reference and the exported model.
Raises:
ValueError: If the outputs shapes or values do not match between the reference and the exported model.
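    Example:
        A minimal sketch, assuming `model` is the reference PyTorch model and `neuron_config` is the
        `NeuronDefaultConfig` used for the export (both names are illustrative):

        ```python
        from pathlib import Path

        validate_model_outputs(
            config=neuron_config,
            reference_model=model,
            neuron_model_path=Path("compiled_models/model.neuron"),
            neuron_named_outputs=["logits"],
            atol=1e-3,
        )
        ```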
"""
if atol is None:
if isinstance(config.ATOL_FOR_VALIDATION, dict):
atol = config.ATOL_FOR_VALIDATION[config.task]
else:
atol = config.ATOL_FOR_VALIDATION
input_shapes = {}
for axis in config.mandatory_axes:
input_shapes[axis] = getattr(config, axis)
if config.dynamic_batch_size is True and "batch_size" in input_shapes:
input_shapes["batch_size"] *= 2
# Reference outputs
with torch.no_grad():
reference_model.eval()
inputs = config.generate_dummy_inputs(return_tuple=False, **input_shapes)
ref_inputs = config.unflatten_inputs(inputs)
if hasattr(reference_model, "config") and getattr(reference_model.config, "is_encoder_decoder", False):
reference_model = config.patch_model_for_export(reference_model, device="cpu", **input_shapes)
if "SentenceTransformer" in reference_model.__class__.__name__:
reference_model = config.patch_model_for_export(reference_model, ref_inputs)
ref_outputs = reference_model(**ref_inputs)
neuron_inputs = tuple(config.flatten_inputs(inputs).values())
elif "AutoencoderKL" in getattr(config._config, "_class_name", "") or getattr(
config._config, "is_encoder_decoder", False
):
# VAE components for stable diffusion or Encoder-Decoder models
ref_inputs = tuple(ref_inputs.values())
ref_outputs = reference_model(*ref_inputs)
neuron_inputs = tuple(inputs.values())
elif config.CUSTOM_MODEL_WRAPPER is not None:
ref_inputs = config.flatten_inputs(inputs)
reference_model = config.patch_model_for_export(reference_model, ref_inputs, device="cpu")
neuron_inputs = ref_inputs = tuple(ref_inputs.values())
ref_outputs = reference_model(*ref_inputs)
else:
ref_outputs = reference_model(**ref_inputs)
neuron_inputs = tuple(config.flatten_inputs(inputs).values())
# Neuron outputs
neuron_model = torch.jit.load(neuron_model_path)
neuron_outputs = neuron_model(*neuron_inputs)
if isinstance(neuron_outputs, dict):
neuron_outputs = tuple(neuron_outputs.values())
elif isinstance(neuron_outputs, torch.Tensor):
neuron_outputs = (neuron_outputs,)
    # Check that the Neuron output names are a subset of the reference model's output names
neuron_output_names_set = set(neuron_named_outputs)
neuron_output_names_list = sorted(neuron_output_names_set, key=neuron_named_outputs.index)
if isinstance(ref_outputs, dict):
ref_output_names_set = set(ref_outputs.keys())
if not neuron_output_names_set.issubset(ref_output_names_set):
raise OutputMatchError(
"Neuron model output names do not match reference model output names.\n"
f"Reference model output names: {ref_output_names_set}\n"
f"Neuron model output names: {neuron_output_names_set}\n"
f"Difference: {neuron_output_names_set.difference(ref_output_names_set)}"
)
else:
neuron_output_names = ", ".join(neuron_output_names_set)
logger.info(f"\t-[✓] Neuron model output names match reference model ({neuron_output_names})")
    # The following cases are for diffusers
elif isinstance(ref_outputs, torch.Tensor):
ref_outputs = {neuron_named_outputs[0]: ref_outputs}
elif isinstance(ref_outputs, tuple):
ref_outputs = dict(zip(neuron_named_outputs, ref_outputs))
# Check if the number of outputs matches the number of output names
if len(neuron_output_names_set) != len(neuron_outputs):
raise OutputMatchError(
f"The exported Neuron model has {len(neuron_outputs)} outputs while {len(neuron_output_names_set)} are expected."
)
# Check the shape and values match
shape_failures = []
value_failures = []
for i, (name, neuron_output) in enumerate(zip(neuron_output_names_list, neuron_outputs)):
        if isinstance(neuron_output, torch.Tensor):
            ref_output = ref_outputs[name] if isinstance(ref_outputs, dict) else ref_outputs[i]
        elif isinstance(neuron_output, tuple):  # e.g. `hidden_states` of `AutoencoderKL` is a tuple of tensors
            ref_output = torch.stack(ref_outputs[name])
            neuron_output = torch.stack(neuron_output)
        elif isinstance(neuron_output, list):
            ref_output = ref_outputs[name]
logger.info(f'\t- Validating Neuron Model output "{name}":')
# Shape
output_list = (
neuron_output if isinstance(neuron_output, list) else [neuron_output]
        )  # e.g. `down_block_res_samples` of `ControlNet` is a list of tensors.
ref_output_list = ref_output if isinstance(ref_output, list) else [ref_output]
for output, ref_output in zip(output_list, ref_output_list):
if not output.shape == ref_output.shape:
logger.error(f"\t\t-[x] shape {output.shape} doesn't match {ref_output.shape}")
shape_failures.append((name, ref_output.shape, output.shape))
else:
logger.info(f"\t\t-[✓] {output.shape} matches {ref_output.shape}")
# Values
if not torch.allclose(ref_output, output.to(ref_output.dtype), atol=atol):
max_diff = torch.max(torch.abs(ref_output - output))
logger.error(f"\t\t-[x] values not close enough, max diff: {max_diff} (atol: {atol})")
value_failures.append((name, max_diff))
else:
logger.info(f"\t\t-[✓] all values close (atol: {atol})")
if shape_failures:
msg = "\n".join(f"- {t[0]}: got {t[1]} (reference) and {t[2]} (neuron)" for t in shape_failures)
        raise ShapeError(f"Output shapes do not match between the reference model and the Neuron exported model:\n{msg}")
if value_failures:
msg = "\n".join(f"- {t[0]}: max diff = {t[1]}" for t in value_failures)
logger.warning(
"The maximum absolute difference between the output of the reference model and the Neuron "
f"exported model is not within the set tolerance {atol}:\n{msg}"
)
def export_models(
models_and_neuron_configs: Dict[
str, Tuple[Union["PreTrainedModel", "ModelMixin", torch.nn.Module], "NeuronDefaultConfig"]
],
task: str,
output_dir: Path,
disable_neuron_cache: Optional[bool] = False,
compiler_workdir: Optional[Path] = None,
inline_weights_to_neff: bool = True,
optlevel: str = "2",
output_file_names: Optional[Dict[str, str]] = None,
compiler_kwargs: Optional[Dict[str, Any]] = {},
model_name_or_path: Optional[str] = None,
) -> Tuple[Dict[str, List[str]], Dict[str, List[str]]]:
"""
Exports a Pytorch model with multiple component models to separate files.
Args:
        models_and_neuron_configs (`Dict[str, Tuple[Union["PreTrainedModel", "ModelMixin", torch.nn.Module], "NeuronDefaultConfig"]]`):
            A dictionary containing the models to export and their corresponding neuron configs.
task (`str`):
The task for which the models should be exported.
output_dir (`Path`):
Output directory to store the exported Neuron models.
disable_neuron_cache (`Optional[bool]`, defaults to `False`):
Whether to disable automatic caching of AOT compiled models (not applicable for JIT compilation).
compiler_workdir (`Optional[Path]`, defaults to `None`):
The directory to store intermediary outputs of the neuron compiler.
inline_weights_to_neff (`bool`, defaults to `True`):
Whether to inline the weights to the neff graph. If set to False, weights will be separated from the neff.
optlevel (`str`, defaults to `"2"`):
The level of optimization the compiler should perform. Can be `"1"`, `"2"` or `"3"`, defaults to "2".
1: enables the core performance optimizations in the compiler, while also minimizing compile time.
2: provides the best balance between model performance and compile time.
3: may provide additional model execution performance but may incur longer compile times and higher host memory usage during model compilation.
        output_file_names (`Optional[Dict[str, str]]`, defaults to `None`):
            The file names to use for the exported Neuron models, keyed by the submodel names used in `models_and_neuron_configs`.
            If `None`, the keys of `models_and_neuron_configs` are used as names.
        compiler_kwargs (`Optional[Dict[str, Any]]`, defaults to `{}`):
Arguments to pass to the Neuron(x) compiler for exporting Neuron models.
model_name_or_path (`Optional[str]`, defaults to `None`):
Path to pretrained model or model identifier from the Hugging Face Hub.
Returns:
        `Tuple[Dict[str, List[str]], Dict[str, List[str]]]`: A tuple of two dictionaries mapping each submodel name to its ordered list of input
        names and to its named outputs, as given by the Neuron configurations.
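    Example:
        A minimal sketch, assuming `models_and_neuron_configs` maps each submodel name to a
        `(model, NeuronDefaultConfig)` pair prepared upstream (the name is illustrative); the compiler kwargs
        shown are the ones accepted by `export`:

        ```python
        from pathlib import Path

        all_inputs, all_outputs = export_models(
            models_and_neuron_configs=models_and_neuron_configs,
            task="feature-extraction",
            output_dir=Path("compiled_models"),
            compiler_kwargs={"auto_cast": "matmul", "auto_cast_type": "bf16"},
        )
        ```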
"""
all_inputs = {}
all_outputs = {}
if compiler_workdir is not None:
compiler_workdir = Path(compiler_workdir)
if output_file_names is not None and len(output_file_names) != len(models_and_neuron_configs):
raise ValueError(
f"Provided {len(output_file_names)} custom names for the export of {len(models_and_neuron_configs)} models. Please provide the same number of names as models to export."
)
failed_models = []
total_compilation_time = 0
compile_configs = {}
for i, model_name in enumerate(models_and_neuron_configs.keys()):
logger.info(f"***** Compiling {model_name} *****")
submodel, sub_neuron_config = models_and_neuron_configs[model_name]
output_file_name = (
output_file_names[model_name] if output_file_names is not None else Path(model_name + ".neuron")
)
output_path = output_dir / output_file_name
output_path.parent.mkdir(parents=True, exist_ok=True)
# TODO: Remove after the weights/neff separation compilation of sdxl is patched by a neuron sdk release: https://github.com/aws-neuron/aws-neuron-sdk/issues/859
if not inline_weights_to_neff and getattr(sub_neuron_config, "is_sdxl", False):
logger.warning(
"The compilation of SDXL's unet with the weights/neff separation is broken since the Neuron SDK 2.18 release. `inline_weights_to_neff` will be set to True and the caching will be disabled. If you still want to separate the neff and weights, please downgrade your Neuron setup to the 2.17.1 release."
)
inline_weights_to_neff = True
start_time = time.time()
neuron_inputs, neuron_outputs = export(
model_or_path=submodel,
config=sub_neuron_config,
output=output_path,
compiler_workdir=compiler_workdir,
inline_weights_to_neff=inline_weights_to_neff,
optlevel=optlevel,
**compiler_kwargs,
)
compilation_time = time.time() - start_time
total_compilation_time += compilation_time
logger.info(f"[Compilation Time] {np.round(compilation_time, 2)} seconds.")
all_inputs[model_name] = neuron_inputs
all_outputs[model_name] = neuron_outputs
# Add neuron specific configs to model components' original config
model_config = sub_neuron_config._config
if is_diffusers_available() and isinstance(model_config, FrozenDict):
model_config = OrderedDict(model_config)
model_config = DiffusersPretrainedConfig.from_dict(model_config)
model_config = store_compilation_config(
config=model_config,
input_shapes=sub_neuron_config.input_shapes,
compiler_kwargs=compiler_kwargs,
input_names=neuron_inputs,
output_names=neuron_outputs,
dynamic_batch_size=sub_neuron_config.dynamic_batch_size,
tensor_parallel_size=sub_neuron_config.tensor_parallel_size,
compiler_type=NEURON_COMPILER_TYPE,
compiler_version=NEURON_COMPILER_VERSION,
inline_weights_to_neff=inline_weights_to_neff,
optlevel=optlevel,
model_type=getattr(sub_neuron_config, "MODEL_TYPE", None),
task=getattr(sub_neuron_config, "task", None),
output_attentions=getattr(sub_neuron_config, "output_attentions", False),
output_hidden_states=getattr(sub_neuron_config, "output_hidden_states", False),
)
model_config.save_pretrained(output_path.parent)
compile_configs[model_name] = model_config
logger.info(f"[Total compilation Time] {np.round(total_compilation_time, 2)} seconds.")
# cache neuronx model
if not disable_neuron_cache and is_neuronx_available():
model_id = get_model_name_or_path(model_config) if model_name_or_path is None else model_name_or_path
if len(compile_configs) == 1:
# FIXME: this is overly complicated just to pass the config
cache_config = list(compile_configs.values())[0]
cache_entry = SingleModelCacheEntry(model_id=model_id, task=task, config=cache_config)
else:
cache_entry = MultiModelCacheEntry(model_id=model_id, configs=compile_configs)
cache_traced_neuron_artifacts(neuron_dir=output_dir, cache_entry=cache_entry)
    # remove models that failed to export
for i, model_name in failed_models:
output_file_names.pop(model_name)
models_and_neuron_configs.pop(model_name)
return all_inputs, all_outputs
def export(
model_or_path: Union["PreTrainedModel", str, Path],
config: "NeuronDefaultConfig",
output: Path,
compiler_workdir: Optional[Path] = None,
inline_weights_to_neff: bool = True,
optlevel: str = "2",
auto_cast: Optional[str] = None,
auto_cast_type: str = "bf16",
disable_fast_relayout: bool = False,
disable_fallback: bool = False,
) -> Tuple[List[str], List[str]]:
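    """
    Exports a PyTorch model to a serialized and Neuron-compiled TorchScript module, dispatching to
    `export_neuron` when the neuron-cc compiler (`torch.neuron`) is available and to `export_neuronx`
    when the neuronx-cc compiler (`torch_neuronx`) is available.

    Returns:
        `Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named
        outputs from the Neuron configuration.
    """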
if is_neuron_available():
return export_neuron(
model=model_or_path,
config=config,
output=output,
compiler_workdir=compiler_workdir,
inline_weights_to_neff=inline_weights_to_neff,
auto_cast=auto_cast,
auto_cast_type=auto_cast_type,
disable_fast_relayout=disable_fast_relayout,
disable_fallback=disable_fallback,
)
elif is_neuronx_available():
return export_neuronx(
model_or_path=model_or_path,
config=config,
output=output,
compiler_workdir=compiler_workdir,
inline_weights_to_neff=inline_weights_to_neff,
optlevel=optlevel,
auto_cast=auto_cast,
auto_cast_type=auto_cast_type,
)
else:
raise RuntimeError(
"Cannot export the model because the neuron(x) compiler is not installed. See https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/torch/torch-setup.html."
)
def export_neuronx(
model_or_path: Union["PreTrainedModel", str, Path],
config: "NeuronDefaultConfig",
output: Path,
compiler_workdir: Optional[Path] = None,
inline_weights_to_neff: bool = True,
optlevel: str = "2",
auto_cast: Optional[str] = None,
auto_cast_type: str = "bf16",
) -> Tuple[List[str], List[str]]:
"""
Exports a PyTorch model to a serialized TorchScript module compiled by neuronx-cc compiler.
Args:
model_or_path (Union["PreTrainedModel", str, Path]):
            The model to export, or its location (needed when applying tensor parallelism, since the model then has to be loaded within the tracing process).
config ([`~exporter.NeuronDefaultConfig`]):
The Neuron configuration associated with the exported model.
output (`Path`):
Directory to store the exported Neuron model.
compiler_workdir (`Optional[Path]`, defaults to `None`):
The directory used by neuronx-cc, where you can find intermediary outputs (neff, weight, hlo...).
inline_weights_to_neff (`bool`, defaults to `True`):
Whether to inline the weights to the neff graph. If set to False, weights will be separated from the neff.
optlevel (`str`, defaults to `"2"`):
The level of optimization the compiler should perform. Can be `"1"`, `"2"` or `"3"`, defaults to "2".
1: enables the core performance optimizations in the compiler, while also minimizing compile time.
2: provides the best balance between model performance and compile time.
3: may provide additional model execution performance but may incur longer compile times and higher host memory usage during model compilation.
auto_cast (`Optional[str]`, defaults to `None`):
            Whether to cast operations from FP32 to lower precision to speed up inference. Can be `None`, `"matmul"` or `"all"`: use `None` to disable auto-casting, `"matmul"` to cast only FP32 matrix multiplication operations, and `"all"` to cast all FP32 operations.
auto_cast_type (`str`, defaults to `"bf16"`):
The data type to cast FP32 operations to when auto-cast mode is enabled. Can be `"bf16"`, `"fp16"` or `"tf32"`.
Returns:
        `Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named outputs from
        the Neuron configuration.
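    Example:
        A minimal sketch, assuming `model` and `neuron_config` have already been prepared (both names are
        illustrative):

        ```python
        from pathlib import Path

        input_names, output_names = export_neuronx(
            model_or_path=model,
            config=neuron_config,
            output=Path("compiled_models/model.neuron"),
            auto_cast="matmul",
            auto_cast_type="bf16",
        )
        ```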
"""
output.parent.mkdir(parents=True, exist_ok=True)
if isinstance(compiler_workdir, Path):
compiler_workdir = compiler_workdir.as_posix()
if hasattr(model_or_path, "config"):
model_or_path.config.return_dict = True
model_or_path.config.torchscript = True
if isinstance(model_or_path, PreTrainedModel):
model_or_path.eval()
    # Check if we need to override certain configuration items
if config.values_override is not None:
logger.info(f"Overriding {len(config.values_override)} configuration item(s)")
for override_config_key, override_config_value in config.values_override.items():
logger.info(f"\t- {override_config_key} -> {override_config_value}")
if isinstance(model_or_path, PreTrainedModel):
setattr(model_or_path.config, override_config_key, override_config_value)
# Prepare dummy inputs for tracing
input_shapes = {}
for axis in config.mandatory_axes:
input_shapes[axis] = getattr(config, axis)
dummy_inputs = config.generate_dummy_inputs(**input_shapes)
dummy_inputs = config.flatten_inputs(dummy_inputs)
dummy_inputs_tuple = tuple(dummy_inputs.values())
    # Prepare the model / function (in the tensor parallel case) to trace
aliases = {}
tensor_parallel_size = config.tensor_parallel_size
if getattr(config, "is_encoder_decoder", False):
checked_model = config.patch_model_for_export(model_or_path, **input_shapes)
if tensor_parallel_size == 1 and hasattr(config, "generate_io_aliases"):
aliases = config.generate_io_aliases(checked_model)
else:
checked_model = config.patch_model_for_export(model_or_path, dummy_inputs)
# Construct compiler configurations
if auto_cast is not None:
logger.info(f"Using Neuron: --auto-cast {auto_cast}")
auto_cast = "matmult" if auto_cast == "matmul" else auto_cast
compiler_args = ["--auto-cast", auto_cast]
logger.info(f"Using Neuron: --auto-cast-type {auto_cast_type}")
compiler_args.extend(["--auto-cast-type", auto_cast_type])
else:
compiler_args = ["--auto-cast", "none"]
compiler_args.extend(["--optlevel", optlevel])
logger.info(f"Using Neuron: --optlevel {optlevel}")
    # It is unclear to which range of models the `transformer` model type can be applied; these are the exceptions observed so far.
excluded_models = {
"unet",
"vae-encoder",
"vae-decoder",
"hubert",
"levit",
"mobilenet-v2",
"mobilevit",
"unispeech",
"unispeech-sat",
"wav2vec2",
"wavlm",
}
if config.MODEL_TYPE not in excluded_models:
compiler_args.extend(["--model-type", "transformer"])
compiler_args = add_stable_diffusion_compiler_args(config, compiler_args) # diffusers specific
if config.dynamic_batch_size and not inline_weights_to_neff:
logger.warning(
"Dynamic batching is not yet compatible with the weights/neff non-inlined model. `inline_weights_to_neff` is set to True. If you still want to separate the neff and weights, please set `dynamic_batch_size=False`."
)
inline_weights_to_neff = True
# Start trace
if tensor_parallel_size > 1:
# 1. use NxD to trace for parallel
neuron_model = neuronx_distributed.trace.parallel_model_trace(
checked_model,
dummy_inputs_tuple,
compiler_args=compiler_args,
inline_weights_to_neff=inline_weights_to_neff,
compiler_workdir=compiler_workdir,
tp_degree=tensor_parallel_size,
)
neuronx_distributed.trace.parallel_model_save(neuron_model, output)
else:
# 2. use `torch_neuronx.trace`
neuron_model = neuronx.trace(
checked_model,
dummy_inputs_tuple,
compiler_args=compiler_args,
input_output_aliases=aliases,
inline_weights_to_neff=inline_weights_to_neff,
compiler_workdir=compiler_workdir,
)
if config.dynamic_batch_size is True:
neuron_model = neuronx.dynamic_batch(neuron_model)
# diffusers specific
improve_stable_diffusion_loading(config, neuron_model)
torch.jit.save(neuron_model, output)
del model_or_path
del checked_model
del dummy_inputs
del neuron_model
return config.inputs, config.outputs
def add_stable_diffusion_compiler_args(config, compiler_args):
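    """
    Appends diffusion-specific flags to `compiler_args`: fast loading of Neuron binaries for the pipeline
    submodels, and the `unet-inference` model type for unet and controlnet components.
    """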
    # Combine the model name and its path to identify which subcomponent of the Stable Diffusion pipeline it is
identifier = getattr(config._config, "_name_or_path", "") + " " + getattr(config._config, "_class_name", "")
identifier = identifier.lower()
sd_components = ["text_encoder", "vae", "vae_encoder", "vae_decoder", "controlnet"]
if any(component in identifier for component in sd_components):
compiler_args.append("--enable-fast-loading-neuron-binaries")
# unet or transformer or controlnet
if any(model_type in identifier for model_type in ["unet", "transformer", "controlnet"]):
        # The SDXL unet doesn't support fast loading of neuron binaries (SDK 2.19.1)
if not getattr(config, "is_sdxl", False):
compiler_args.append("--enable-fast-loading-neuron-binaries")
if "unet" in identifier or "controlnet" in identifier:
compiler_args.append("--model-type=unet-inference")
return compiler_args
def improve_stable_diffusion_loading(config, neuron_model):
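    """
    Applies `torch_neuronx` loading optimizations to the traced diffusion submodels: asynchronous loading for
    all pipeline components, and lazy loading for unet, transformer and controlnet components.
    """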
    # Combine the model name and its path to identify which subcomponent of the Diffusion pipeline it is
identifier = getattr(config._config, "_name_or_path", "") + " " + getattr(config._config, "_class_name", "")
identifier = identifier.lower()
sd_components = ["text_encoder", "unet", "transformer", "vae", "vae_encoder", "vae_decoder", "controlnet"]
if any(component in identifier for component in sd_components):
neuronx.async_load(neuron_model)
    # unet, transformer or controlnet
if any(model_type in identifier for model_type in ["unet", "transformer", "controlnet"]):
neuronx.lazy_load(neuron_model)
def export_neuron(
model: "PreTrainedModel",
config: "NeuronDefaultConfig",
output: Path,
compiler_workdir: Optional[Path] = None,
inline_weights_to_neff: bool = True,
auto_cast: Optional[str] = None,
auto_cast_type: str = "bf16",
disable_fast_relayout: bool = False,
disable_fallback: bool = False,
) -> Tuple[List[str], List[str]]:
"""
Exports a PyTorch model to a serialized TorchScript module compiled by neuron-cc compiler.
Args:
model ([`PreTrainedModel`]):
The model to export.
config ([`~exporter.NeuronDefaultConfig`]):
The Neuron configuration associated with the exported model.
output (`Path`):
Directory to store the exported Neuron model.
compiler_workdir (`Optional[Path]`, defaults to `None`):
The directory used by neuron-cc, where you can find intermediary outputs (neff, weight, hlo...).
inline_weights_to_neff (`bool`, defaults to `True`):
Whether to inline the weights to the neff graph. If set to False, weights will be separated from the neff.
auto_cast (`Optional[str]`, defaults to `None`):
            Whether to cast operations from FP32 to lower precision to speed up inference. Can be `None`, `"matmul"` or `"all"`: use `None` to disable auto-casting, `"matmul"` to cast only FP32 matrix multiplication operations, and `"all"` to cast all FP32 operations.
auto_cast_type (`str`, defaults to `"bf16"`):
            The data type to cast FP32 operations to when auto-cast mode is enabled. Can be `"bf16"`, `"fp16"`, `"mixed"` or `"tf32"`. `"mixed"` is only available when `auto_cast` is `"matmul"`; it will cast operators that use the Neuron matmult engine to bf16 while using fp16 for matmult-based transpose.
disable_fast_relayout (`bool`, defaults to `False`):
Whether to disable fast relayout optimization which improves performance by using the matrix multiplier for tensor transpose.
disable_fallback (`bool`, defaults to `False`):
Whether to disable CPU partitioning to force operations to Neuron. Defaults to `False`, as without fallback, there could be some compilation failures or performance problems.
Returns:
        `Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named outputs from
        the Neuron configuration.
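    Example:
        A minimal sketch, assuming `model` and `neuron_config` have already been prepared (both names are
        illustrative):

        ```python
        from pathlib import Path

        input_names, output_names = export_neuron(
            model=model,
            config=neuron_config,
            output=Path("compiled_models/model.neuron"),
            auto_cast="matmul",
            auto_cast_type="bf16",
        )
        ```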
"""
output.parent.mkdir(parents=True, exist_ok=True)
if isinstance(compiler_workdir, Path):
compiler_workdir = compiler_workdir.as_posix()
if hasattr(model, "config"):
model.config.return_dict = True
model.config.torchscript = True
model.eval()
    # Check if we need to override certain configuration items
if config.values_override is not None:
logger.info(f"Overriding {len(config.values_override)} configuration item(s)")
for override_config_key, override_config_value in config.values_override.items():
logger.info(f"\t- {override_config_key} -> {override_config_value}")
setattr(model.config, override_config_key, override_config_value)
input_shapes = {}
for axis in config.mandatory_axes:
input_shapes[axis] = getattr(config, axis)
dummy_inputs = config.generate_dummy_inputs(**input_shapes)
dummy_inputs_tuple = tuple(dummy_inputs.values())
checked_model = config.patch_model_for_export(model, dummy_inputs)
compiler_args = convert_neuronx_compiler_args_to_neuron(auto_cast, auto_cast_type, disable_fast_relayout)
if config.dynamic_batch_size is True and not inline_weights_to_neff:
logger.warning(
"Dynamic batching is not yet compatible with the weights/neff non-inlined model. `inline_weights_to_neff` is set to True. If you still want to separate the neff and weights, please set `dynamic_batch_size=False`."
)
inline_weights_to_neff = True
neuron_model = neuron.trace(
checked_model,
dummy_inputs_tuple,
dynamic_batch_size=config.dynamic_batch_size,
compiler_args=compiler_args,
compiler_workdir=compiler_workdir,
separate_weights=not inline_weights_to_neff,
fallback=not disable_fallback,
)
torch.jit.save(neuron_model, output)
del model
del checked_model
del dummy_inputs
del neuron_model
return config.inputs, config.outputs