optimum/exporters/neuron/__main__.py
# coding=utf-8
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Entry point to the optimum.exporters.neuron command line."""
import argparse
import inspect
import os
os.environ["TORCHDYNAMO_DISABLE"] = "1" # Always turn off torchdynamo as it's incompatible with neuron
from argparse import ArgumentParser
from dataclasses import fields
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import torch
from requests.exceptions import ConnectionError as RequestsConnectionError
from transformers import AutoConfig, AutoTokenizer, PretrainedConfig
from optimum.exporters.error_utils import AtolError, OutputMatchError, ShapeError
from optimum.exporters.tasks import TasksManager
from optimum.utils import is_diffusers_available, logging
from optimum.utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors
from ...neuron.models.auto_model import get_neuron_model_class, has_neuron_model_class
from ...neuron.utils import (
DECODER_NAME,
DIFFUSION_MODEL_CONTROLNET_NAME,
DIFFUSION_MODEL_TEXT_ENCODER_2_NAME,
DIFFUSION_MODEL_TEXT_ENCODER_NAME,
DIFFUSION_MODEL_TRANSFORMER_NAME,
DIFFUSION_MODEL_UNET_NAME,
DIFFUSION_MODEL_VAE_DECODER_NAME,
DIFFUSION_MODEL_VAE_ENCODER_NAME,
ENCODER_NAME,
NEURON_FILE_NAME,
ImageEncoderArguments,
InputShapesArguments,
IPAdapterArguments,
LoRAAdapterArguments,
is_neuron_available,
is_neuronx_available,
map_torch_dtype,
)
from ...neuron.utils.version_utils import (
check_compiler_compatibility_for_stable_diffusion,
)
from .base import NeuronExportConfig
from .convert import export_models, validate_models_outputs
from .model_configs import * # noqa: F403
from .utils import (
build_stable_diffusion_components_mandatory_shapes,
check_mandatory_input_shapes,
get_diffusion_models_for_export,
get_encoder_decoder_models_for_export,
replace_stable_diffusion_submodels,
)
if is_neuron_available():
from ...commands.export.neuron import parse_args_neuron
NEURON_COMPILER = "Neuron"
if is_neuronx_available():
from ...commands.export.neuronx import parse_args_neuronx as parse_args_neuron # noqa: F811
NEURON_COMPILER = "Neuronx"
if is_diffusers_available():
from diffusers import StableDiffusionXLPipeline
if TYPE_CHECKING:
from transformers import PreTrainedModel
if is_diffusers_available():
from diffusers import DiffusionPipeline, ModelMixin, StableDiffusionPipeline
logger = logging.get_logger()
logger.setLevel(logging.INFO)
def infer_compiler_kwargs(args: argparse.Namespace) -> Dict[str, Any]:
# infer compiler kwargs
auto_cast = None if args.auto_cast == "none" else args.auto_cast
auto_cast_type = None if auto_cast is None else args.auto_cast_type
compiler_kwargs = {"auto_cast": auto_cast, "auto_cast_type": auto_cast_type}
if hasattr(args, "disable_fast_relayout"):
compiler_kwargs["disable_fast_relayout"] = getattr(args, "disable_fast_relayout")
if hasattr(args, "disable_fallback"):
compiler_kwargs["disable_fallback"] = getattr(args, "disable_fallback")
return compiler_kwargs
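# A minimal sketch of the result, assuming hypothetical CLI namespace values
# ("matmul"/"bf16" stand in for whatever the parser accepted):
#
#     ns = argparse.Namespace(auto_cast="matmul", auto_cast_type="bf16")
#     infer_compiler_kwargs(ns)
#     # -> {"auto_cast": "matmul", "auto_cast_type": "bf16"}
#
# `disable_fast_relayout` / `disable_fallback` are forwarded only when the
# (inf1-specific) parser actually defined them.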
def infer_task(model_name_or_path: str) -> str:
try:
return TasksManager.infer_task_from_model(model_name_or_path)
except KeyError as e:
raise KeyError(
"The task could not be automatically inferred. Please provide the argument --task with the task "
f"from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}"
)
except RequestsConnectionError as e:
raise RequestsConnectionError(
f"The task could not be automatically inferred as this is available only for models hosted on the Hugging Face Hub. Please provide the argument --task with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}"
)
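# For instance, infer_task("distilbert-base-uncased-finetuned-sst-2-english")
# would resolve to "text-classification" from the Hub metadata (hence the
# ConnectionError path above when running offline).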
# This function is not applicable for diffusers / sentence transformers models
def get_input_shapes(task: str, args: argparse.Namespace) -> Dict[str, int]:
neuron_config_constructor = get_neuron_config_class(task, args.model)
input_args = neuron_config_constructor.func.get_input_args_for_task(task)
return {name: getattr(args, name) for name in input_args}
def get_neuron_config_class(task: str, model_id: str) -> NeuronExportConfig:
config = AutoConfig.from_pretrained(model_id)
model_type = config.model_type.replace("_", "-")
if config.is_encoder_decoder:
model_type = model_type + "-encoder"
neuron_config_constructor = TasksManager.get_exporter_config_constructor(
model_type=model_type,
exporter="neuron",
task=task,
library_name="transformers",
)
return neuron_config_constructor
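# Example: for a T5 checkpoint, `config.is_encoder_decoder` is True, so the
# lookup key becomes "t5-encoder" and the encoder-specific Neuron config
# constructor is returned, e.g.:
#
#     constructor = get_neuron_config_class("text2text-generation", "google-t5/t5-small")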
def normalize_sentence_transformers_input_shapes(args: argparse.Namespace) -> Dict[str, int]:
args = vars(args) if isinstance(args, argparse.Namespace) else args
if "clip" in args.get("model", "").lower():
mandatory_axes = {"text_batch_size", "image_batch_size", "sequence_length", "num_channels", "width", "height"}
else:
mandatory_axes = {"batch_size", "sequence_length"}
if not mandatory_axes.issubset(set(args.keys())):
raise AttributeError(
f"Shape of {mandatory_axes} are mandatory for neuron compilation, while {mandatory_axes.difference(args.keys())} are not given."
)
mandatory_shapes = {name: args[name] for name in mandatory_axes}
return mandatory_shapes
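# Sketch with a plain (non-CLIP) sentence-transformers model, which only needs
# `batch_size` and `sequence_length`:
#
#     ns = argparse.Namespace(model="sentence-transformers/all-MiniLM-L6-v2",
#                             batch_size=1, sequence_length=128)
#     normalize_sentence_transformers_input_shapes(ns)
#     # -> {"batch_size": 1, "sequence_length": 128}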
def customize_optional_outputs(args: argparse.Namespace) -> Dict[str, bool]:
"""
    Customize optional outputs of the traced model, e.g. if `output_attentions=True`, the attention tensors will be traced.
"""
possible_outputs = ["output_attentions", "output_hidden_states"]
customized_outputs = {}
for name in possible_outputs:
customized_outputs[name] = getattr(args, name, False)
return customized_outputs
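# E.g. customize_optional_outputs(argparse.Namespace(output_attentions=True))
# returns {"output_attentions": True, "output_hidden_states": False}.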
def parse_optlevel(args: argparse.Namespace) -> Optional[str]:
    """
    (NEURONX ONLY) Parse the optimization level that the compiler should apply. If none is specified, apply `O2` (the best balance between model performance and compile time). Returns None when targeting neuron-cc (inf1).
"""
if is_neuronx_available():
if args.O1:
optlevel = "1"
elif args.O2:
optlevel = "2"
elif args.O3:
optlevel = "3"
else:
optlevel = "2"
else:
optlevel = None
return optlevel
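# The CLI exposes the levels as -O1/-O2/-O3 flags, so e.g. passing `-O3`
# yields "3", and omitting all of them falls back to the default "2".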
def normalize_stable_diffusion_input_shapes(
args: argparse.Namespace,
) -> Dict[str, Dict[str, int]]:
args = vars(args) if isinstance(args, argparse.Namespace) else args
    mandatory_axes = set(inspect.getfullargspec(build_stable_diffusion_components_mandatory_shapes).args)
    mandatory_axes = mandatory_axes - {
        "sequence_length",  # `sequence_length` is optional; diffusers pads it to the max if not provided.
        # Channel counts are inferred from the model config later on.
        "unet_or_transformer_num_channels",
        "vae_encoder_num_channels",
        "vae_decoder_num_channels",
        "num_images_per_prompt",  # defaults to 1
    }
if not mandatory_axes.issubset(set(args.keys())):
raise AttributeError(
f"Shape of {mandatory_axes} are mandatory for neuron compilation, while {mandatory_axes.difference(args.keys())} are not given."
)
mandatory_shapes = {name: args[name] for name in mandatory_axes}
mandatory_shapes["num_images_per_prompt"] = args.get("num_images_per_prompt", 1) or 1
mandatory_shapes["sequence_length"] = args.get("sequence_length", None)
input_shapes = build_stable_diffusion_components_mandatory_shapes(**mandatory_shapes)
return input_shapes
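# Rough shape of the result (the exact keys come from
# build_stable_diffusion_components_mandatory_shapes; values sketched here):
#
#     {
#         "text_encoder": {"batch_size": 1, "sequence_length": None, ...},
#         "unet_or_transformer": {"batch_size": 1, "height": 512, "width": 512, ...},
#         "vae_encoder": {...},
#         "vae_decoder": {...},
#     }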
def infer_stable_diffusion_shapes_from_diffusers(
input_shapes: Dict[str, Dict[str, int]],
model: Union["StableDiffusionPipeline", "StableDiffusionXLPipeline"],
has_controlnets: bool,
):
if model.tokenizer is not None:
max_sequence_length = model.tokenizer.model_max_length
elif hasattr(model, "tokenizer_2") and model.tokenizer_2 is not None:
max_sequence_length = model.tokenizer_2.model_max_length
else:
raise AttributeError(
f"Cannot infer max sequence_length from {type(model)} as there is no tokenizer as attribute."
)
vae_encoder_num_channels = model.vae.config.in_channels
vae_decoder_num_channels = model.vae.config.latent_channels
vae_scale_factor = 2 ** (len(model.vae.config.block_out_channels) - 1) or 8
height = input_shapes["unet_or_transformer"]["height"]
scaled_height = height // vae_scale_factor
width = input_shapes["unet_or_transformer"]["width"]
scaled_width = width // vae_scale_factor
# Text encoders
if input_shapes["text_encoder"].get("sequence_length") is None:
input_shapes["text_encoder"].update({"sequence_length": max_sequence_length})
if hasattr(model, "text_encoder_2"):
input_shapes["text_encoder_2"] = input_shapes["text_encoder"]
# UNet or Transformer
unet_or_transformer_name = "transformer" if hasattr(model, "transformer") else "unet"
unet_or_transformer_num_channels = getattr(model, unet_or_transformer_name).config.in_channels
input_shapes["unet_or_transformer"].update(
{
"num_channels": unet_or_transformer_num_channels,
"height": scaled_height,
"width": scaled_width,
}
)
if input_shapes["unet_or_transformer"].get("sequence_length") is None:
input_shapes["unet_or_transformer"]["sequence_length"] = max_sequence_length
input_shapes["unet_or_transformer"]["vae_scale_factor"] = vae_scale_factor
input_shapes[unet_or_transformer_name] = input_shapes.pop("unet_or_transformer")
if unet_or_transformer_name == "transformer":
input_shapes[unet_or_transformer_name]["encoder_hidden_size"] = model.text_encoder.config.hidden_size
# VAE
input_shapes["vae_encoder"].update({"num_channels": vae_encoder_num_channels, "height": height, "width": width})
input_shapes["vae_decoder"].update(
{"num_channels": vae_decoder_num_channels, "height": scaled_height, "width": scaled_width}
)
# ControlNet
if has_controlnets:
encoder_hidden_size = model.text_encoder.config.hidden_size
if hasattr(model, "text_encoder_2"):
encoder_hidden_size += model.text_encoder_2.config.hidden_size
input_shapes["controlnet"] = {
"batch_size": input_shapes[unet_or_transformer_name]["batch_size"],
"sequence_length": input_shapes[unet_or_transformer_name]["sequence_length"],
"num_channels": unet_or_transformer_num_channels,
"height": scaled_height,
"width": scaled_width,
"vae_scale_factor": vae_scale_factor,
"encoder_hidden_size": encoder_hidden_size,
}
# Image encoder
if getattr(model, "image_encoder", None):
input_shapes["image_encoder"] = {
"batch_size": input_shapes[unet_or_transformer_name]["batch_size"],
"num_channels": model.image_encoder.config.num_channels,
"width": model.image_encoder.config.image_size,
"height": model.image_encoder.config.image_size,
}
# IP-Adapter: add image_embeds as input for unet/transformer
# unet has `ip_adapter_image_embeds` with shape [batch_size, 1, (self.image_encoder.config.image_size//patch_size)**2+1, self.image_encoder.config.hidden_size] as input
    if getattr(model, "unet", None) is not None and getattr(model.unet.config, "encoder_hid_dim_type", None) == "ip_image_proj":
input_shapes[unet_or_transformer_name]["image_encoder_shapes"] = ImageEncoderArguments(
sequence_length=model.image_encoder.vision_model.embeddings.position_embedding.weight.shape[0],
hidden_size=model.image_encoder.vision_model.embeddings.position_embedding.weight.shape[1],
projection_dim=getattr(model.image_encoder.config, "projection_dim", None),
)
# Format with `InputShapesArguments`
for sub_model_name in input_shapes.keys():
input_shapes[sub_model_name] = InputShapesArguments(**input_shapes[sub_model_name])
return input_shapes
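# Concrete numbers for Stable Diffusion 1.5: the VAE config has 4
# `block_out_channels`, so vae_scale_factor = 2 ** 3 = 8 and a 512x512 request
# gives the UNet 64x64 latent height/width, while the VAE encoder keeps the
# 512x512 pixel sizes.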
def get_submodels_and_neuron_configs(
model: Union["PreTrainedModel", "DiffusionPipeline"],
input_shapes: Dict[str, int],
task: str,
output: Path,
library_name: str,
tensor_parallel_size: int = 1,
subfolder: str = "",
trust_remote_code: bool = False,
dynamic_batch_size: bool = False,
model_name_or_path: Optional[Union[str, Path]] = None,
submodels: Optional[Dict[str, Union[Path, str]]] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
controlnet_ids: Optional[Union[str, List[str]]] = None,
lora_args: Optional[LoRAAdapterArguments] = None,
):
is_encoder_decoder = (
getattr(model.config, "is_encoder_decoder", False) if isinstance(model.config, PretrainedConfig) else False
)
if library_name == "diffusers":
# TODO: Enable optional outputs for Stable Diffusion
if output_attentions:
raise ValueError(f"`output_attentions`is not supported by the {task} task yet.")
models_and_neuron_configs, output_model_names = _get_submodels_and_neuron_configs_for_stable_diffusion(
model=model,
input_shapes=input_shapes,
output=output,
dynamic_batch_size=dynamic_batch_size,
submodels=submodels,
output_hidden_states=output_hidden_states,
controlnet_ids=controlnet_ids,
lora_args=lora_args,
)
elif is_encoder_decoder:
optional_outputs = {"output_attentions": output_attentions, "output_hidden_states": output_hidden_states}
preprocessors = maybe_load_preprocessors(
src_name_or_path=model_name_or_path,
subfolder=subfolder,
trust_remote_code=trust_remote_code,
)
models_and_neuron_configs, output_model_names = _get_submodels_and_neuron_configs_for_encoder_decoder(
model=model,
input_shapes=input_shapes,
tensor_parallel_size=tensor_parallel_size,
task=task,
output=output,
dynamic_batch_size=dynamic_batch_size,
model_name_or_path=model_name_or_path,
preprocessors=preprocessors,
**optional_outputs,
)
else:
# TODO: Enable optional outputs for encoders
if output_attentions or output_hidden_states:
raise ValueError(
f"`output_attentions` and `output_hidden_states` are not supported by the {task} task yet."
)
neuron_config_constructor = TasksManager.get_exporter_config_constructor(
model=model,
exporter="neuron",
task=task,
library_name=library_name,
)
input_shapes = check_mandatory_input_shapes(neuron_config_constructor, task, input_shapes)
input_shapes = InputShapesArguments(**input_shapes)
neuron_config = neuron_config_constructor(
model.config, dynamic_batch_size=dynamic_batch_size, input_shapes=input_shapes
)
model_name = getattr(model, "name_or_path", None) or model_name_or_path
model_name = model_name.split("/")[-1] if model_name else model.config.model_type
output_model_names = {model_name: "model.neuron"}
models_and_neuron_configs = {model_name: (model, neuron_config)}
maybe_save_preprocessors(model_name_or_path, output, src_subfolder=subfolder)
return models_and_neuron_configs, output_model_names
def _get_submodels_and_neuron_configs_for_stable_diffusion(
model: Union["PreTrainedModel", "DiffusionPipeline"],
input_shapes: Dict[str, int],
output: Path,
dynamic_batch_size: bool = False,
submodels: Optional[Dict[str, Union[Path, str]]] = None,
output_hidden_states: bool = False,
controlnet_ids: Optional[Union[str, List[str]]] = None,
lora_args: Optional[LoRAAdapterArguments] = None,
):
check_compiler_compatibility_for_stable_diffusion()
model = replace_stable_diffusion_submodels(model, submodels)
if is_neuron_available():
raise RuntimeError(
"Stable diffusion export is not supported by neuron-cc on inf1, please use neuronx-cc on either inf2/trn1 instead."
)
input_shapes = infer_stable_diffusion_shapes_from_diffusers(
input_shapes=input_shapes,
model=model,
has_controlnets=controlnet_ids is not None,
)
    # Save the model config and preprocessors, as they are needed to reload the pipeline.
model.scheduler.save_pretrained(output.joinpath("scheduler"))
if getattr(model, "tokenizer", None) is not None:
model.tokenizer.save_pretrained(output.joinpath("tokenizer"))
if getattr(model, "tokenizer_2", None) is not None:
model.tokenizer_2.save_pretrained(output.joinpath("tokenizer_2"))
if getattr(model, "tokenizer_3", None) is not None:
model.tokenizer_3.save_pretrained(output.joinpath("tokenizer_3"))
if getattr(model, "feature_extractor", None) is not None:
model.feature_extractor.save_pretrained(output.joinpath("feature_extractor"))
model.save_config(output)
models_and_neuron_configs = get_diffusion_models_for_export(
pipeline=model,
text_encoder_input_shapes=input_shapes["text_encoder"],
unet_input_shapes=input_shapes.get("unet", None),
transformer_input_shapes=input_shapes.get("transformer", None),
vae_encoder_input_shapes=input_shapes["vae_encoder"],
vae_decoder_input_shapes=input_shapes["vae_decoder"],
lora_args=lora_args,
dynamic_batch_size=dynamic_batch_size,
output_hidden_states=output_hidden_states,
controlnet_ids=controlnet_ids,
controlnet_input_shapes=input_shapes.get("controlnet", None),
image_encoder_input_shapes=input_shapes.get("image_encoder", None),
)
output_model_names = {
DIFFUSION_MODEL_VAE_ENCODER_NAME: os.path.join(DIFFUSION_MODEL_VAE_ENCODER_NAME, NEURON_FILE_NAME),
DIFFUSION_MODEL_VAE_DECODER_NAME: os.path.join(DIFFUSION_MODEL_VAE_DECODER_NAME, NEURON_FILE_NAME),
}
if getattr(model, "text_encoder", None) is not None:
output_model_names[DIFFUSION_MODEL_TEXT_ENCODER_NAME] = os.path.join(
DIFFUSION_MODEL_TEXT_ENCODER_NAME, NEURON_FILE_NAME
)
if getattr(model, "text_encoder_2", None) is not None:
output_model_names[DIFFUSION_MODEL_TEXT_ENCODER_2_NAME] = os.path.join(
DIFFUSION_MODEL_TEXT_ENCODER_2_NAME, NEURON_FILE_NAME
)
if getattr(model, "unet", None) is not None:
output_model_names[DIFFUSION_MODEL_UNET_NAME] = os.path.join(DIFFUSION_MODEL_UNET_NAME, NEURON_FILE_NAME)
if getattr(model, "transformer", None) is not None:
output_model_names[DIFFUSION_MODEL_TRANSFORMER_NAME] = os.path.join(
DIFFUSION_MODEL_TRANSFORMER_NAME, NEURON_FILE_NAME
)
if getattr(model, "image_encoder", None) is not None:
output_model_names["image_encoder"] = os.path.join("image_encoder", NEURON_FILE_NAME)
# ControlNet models
if controlnet_ids:
if isinstance(controlnet_ids, str):
controlnet_ids = [controlnet_ids]
for idx in range(len(controlnet_ids)):
controlnet_name = DIFFUSION_MODEL_CONTROLNET_NAME + "_" + str(idx)
output_model_names[controlnet_name] = os.path.join(controlnet_name, NEURON_FILE_NAME)
del model
return models_and_neuron_configs, output_model_names
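# The resulting layout under `output` is one compiled artifact per submodel,
# e.g. for an SDXL pipeline with a single ControlNet (assuming the usual
# constant values, i.e. NEURON_FILE_NAME == "model.neuron"):
#
#     text_encoder/model.neuron      text_encoder_2/model.neuron
#     unet/model.neuron              vae_encoder/model.neuron
#     vae_decoder/model.neuron       controlnet_0/model.neuron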
def _get_submodels_and_neuron_configs_for_encoder_decoder(
model: "PreTrainedModel",
input_shapes: Dict[str, int],
tensor_parallel_size: int,
task: str,
output: Path,
preprocessors: Optional[List] = None,
dynamic_batch_size: bool = False,
model_name_or_path: Optional[Union[str, Path]] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
):
if is_neuron_available():
raise RuntimeError(
"Encoder-decoder models export is not supported by neuron-cc on inf1, please use neuronx-cc on either inf2/trn1 instead."
)
models_and_neuron_configs = get_encoder_decoder_models_for_export(
model=model,
task=task,
tensor_parallel_size=tensor_parallel_size,
dynamic_batch_size=dynamic_batch_size,
input_shapes=input_shapes,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
model_name_or_path=model_name_or_path,
preprocessors=preprocessors,
)
output_model_names = {
ENCODER_NAME: os.path.join(ENCODER_NAME, NEURON_FILE_NAME),
DECODER_NAME: os.path.join(DECODER_NAME, NEURON_FILE_NAME),
}
model.config.save_pretrained(output)
model.generation_config.save_pretrained(output)
maybe_save_preprocessors(model_name_or_path, output)
return models_and_neuron_configs, output_model_names
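# Here the layout is simply ENCODER_NAME/ and DECODER_NAME/ subfolders, each
# holding one compiled NEURON_FILE_NAME, next to the saved model config,
# generation config, and preprocessors.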
def load_models_and_neuron_configs(
model_name_or_path: str,
output: Path,
model: Optional[Union["PreTrainedModel", "ModelMixin"]],
task: str,
dynamic_batch_size: bool,
cache_dir: Optional[str],
trust_remote_code: bool,
subfolder: str,
revision: str,
library_name: str,
force_download: bool,
local_files_only: bool,
token: Optional[Union[bool, str]],
submodels: Optional[Dict[str, Union[Path, str]]],
torch_dtype: Optional[Union[str, torch.dtype]] = None,
tensor_parallel_size: int = 1,
controlnet_ids: Optional[Union[str, List[str]]] = None,
lora_args: Optional[LoRAAdapterArguments] = None,
ip_adapter_args: Optional[IPAdapterArguments] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
**input_shapes,
):
model_kwargs = {
"task": task,
"model_name_or_path": model_name_or_path,
"subfolder": subfolder,
"revision": revision,
"cache_dir": cache_dir,
"token": token,
"local_files_only": local_files_only,
"force_download": force_download,
"trust_remote_code": trust_remote_code,
"framework": "pt",
"library_name": library_name,
"torch_dtype": torch_dtype,
}
if model is None:
model = TasksManager.get_model_from_task(**model_kwargs)
# Load IP-Adapter if it exists
if ip_adapter_args is not None and not all(
getattr(ip_adapter_args, field.name) is None for field in fields(ip_adapter_args)
):
model.load_ip_adapter(
ip_adapter_args.model_id, subfolder=ip_adapter_args.subfolder, weight_name=ip_adapter_args.weight_name
)
model.set_ip_adapter_scale(scale=ip_adapter_args.scale)
models_and_neuron_configs, output_model_names = get_submodels_and_neuron_configs(
model=model,
input_shapes=input_shapes,
tensor_parallel_size=tensor_parallel_size,
task=task,
library_name=library_name,
output=output,
subfolder=subfolder,
trust_remote_code=trust_remote_code,
dynamic_batch_size=dynamic_batch_size,
model_name_or_path=model_name_or_path,
submodels=submodels,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
controlnet_ids=controlnet_ids,
lora_args=lora_args,
)
return models_and_neuron_configs, output_model_names
def main_export(
model_name_or_path: str,
output: Union[str, Path],
compiler_kwargs: Dict[str, Any],
torch_dtype: Optional[Union[str, torch.dtype]] = None,
tensor_parallel_size: int = 1,
model: Optional[Union["PreTrainedModel", "ModelMixin"]] = None,
task: str = "auto",
dynamic_batch_size: bool = False,
atol: Optional[float] = None,
cache_dir: Optional[str] = None,
disable_neuron_cache: Optional[bool] = False,
compiler_workdir: Optional[Union[str, Path]] = None,
inline_weights_to_neff: bool = True,
optlevel: str = "2",
trust_remote_code: bool = False,
subfolder: str = "",
revision: str = "main",
force_download: bool = False,
local_files_only: bool = False,
token: Optional[Union[bool, str]] = None,
do_validation: bool = True,
submodels: Optional[Dict[str, Union[Path, str]]] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
library_name: Optional[str] = None,
controlnet_ids: Optional[Union[str, List[str]]] = None,
lora_args: Optional[LoRAAdapterArguments] = None,
ip_adapter_args: Optional[IPAdapterArguments] = None,
**input_shapes,
):
output = Path(output)
torch_dtype = map_torch_dtype(torch_dtype)
if not output.parent.exists():
output.parent.mkdir(parents=True)
task = TasksManager.map_from_synonym(task)
if library_name is None:
library_name = TasksManager.infer_library_from_model(
model_name_or_path, revision=revision, cache_dir=cache_dir, token=token
)
models_and_neuron_configs, output_model_names = load_models_and_neuron_configs(
model_name_or_path=model_name_or_path,
output=output,
model=model,
torch_dtype=torch_dtype,
tensor_parallel_size=tensor_parallel_size,
task=task,
dynamic_batch_size=dynamic_batch_size,
cache_dir=cache_dir,
trust_remote_code=trust_remote_code,
subfolder=subfolder,
revision=revision,
library_name=library_name,
force_download=force_download,
local_files_only=local_files_only,
token=token,
submodels=submodels,
lora_args=lora_args,
ip_adapter_args=ip_adapter_args,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
controlnet_ids=controlnet_ids,
**input_shapes,
)
_, neuron_outputs = export_models(
models_and_neuron_configs=models_and_neuron_configs,
task=task,
output_dir=output,
disable_neuron_cache=disable_neuron_cache,
compiler_workdir=compiler_workdir,
inline_weights_to_neff=inline_weights_to_neff,
optlevel=optlevel,
output_file_names=output_model_names,
compiler_kwargs=compiler_kwargs,
model_name_or_path=model_name_or_path,
)
# Validate compiled model
if do_validation and tensor_parallel_size > 1:
# TODO: support the validation of tp models.
        logger.warning(
            "Validation is not yet supported for tensor-parallel models, so it will be skipped."
        )
do_validation = False
    if do_validation:
try:
validate_models_outputs(
models_and_neuron_configs=models_and_neuron_configs,
neuron_named_outputs=neuron_outputs,
output_dir=output,
atol=atol,
neuron_files_subpaths=output_model_names,
)
logger.info(
f"The {NEURON_COMPILER} export succeeded and the exported model was saved at: {output.as_posix()}"
)
        except ShapeError:
            raise
        except (AtolError, OutputMatchError) as e:
            logger.warning(
                f"The {NEURON_COMPILER} export succeeded with the warning: {e}.\n The exported model was saved at: "
                f"{output.as_posix()}"
            )
        except Exception as e:
            logger.error(
                f"An error occurred during validation: {e}.\n The exported model was saved at: {output.as_posix()}"
            )
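# Minimal programmatic usage sketch (hypothetical model and shapes; this
# mirrors what `main()` below assembles from the CLI):
#
#     main_export(
#         model_name_or_path="distilbert-base-uncased-finetuned-sst-2-english",
#         output="sst2_neuron/",
#         compiler_kwargs={"auto_cast": None, "auto_cast_type": None},
#         task="text-classification",
#         batch_size=1,
#         sequence_length=128,
#     )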
def maybe_export_from_neuron_model_class(
model: str,
output: Union[str, Path],
task: str = "auto",
cache_dir: Optional[str] = None,
subfolder: str = "",
trust_remote_code: bool = False,
**kwargs,
):
"""Export the model from the neuron model class if it exists."""
if task == "auto":
task = infer_task(model)
output = Path(output)
# Remove None values from the kwargs
kwargs = {key: value for key, value in kwargs.items() if value is not None}
# Also remove some arguments that are not supported in this context
kwargs.pop("disable_neuron_cache", None)
kwargs.pop("inline_weights_neff", None)
kwargs.pop("O1", None)
kwargs.pop("O2", None)
kwargs.pop("O3", None)
kwargs.pop("disable_validation", None)
kwargs.pop("dynamic_batch_size", None)
kwargs.pop("output_hidden_states", None)
kwargs.pop("output_attentions", None)
kwargs.pop("tensor_parallel_size", None)
# Fetch the model config
config = AutoConfig.from_pretrained(model)
# Check if we have an auto-model class for the model_type and task
if not has_neuron_model_class(model_type=config.model_type, task=task, mode="inference"):
return False
neuron_model_class = get_neuron_model_class(model_type=config.model_type, task=task, mode="inference")
neuron_model = neuron_model_class.from_pretrained(
model_id=model,
export=True,
cache_dir=cache_dir,
subfolder=subfolder,
config=config,
trust_remote_code=trust_remote_code,
load_weights=False, # Reduce model size for nxd models
**kwargs,
)
if not output.parent.exists():
output.parent.mkdir(parents=True)
neuron_model.save_pretrained(output)
try:
tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=trust_remote_code)
tokenizer.save_pretrained(output)
except Exception:
logger.info(f"No tokenizer found while exporting {model}.")
return True
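# E.g. a llama checkpoint with task "text-generation" is typically covered by a
# dedicated neuron model class and handled here, whereas an unsupported
# (model_type, task) pair returns False and falls through to the legacy path.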
def main():
parser = ArgumentParser(f"Hugging Face Optimum {NEURON_COMPILER} exporter")
parse_args_neuron(parser)
# Retrieve CLI arguments
args = parser.parse_args()
task = infer_task(args.model) if args.task == "auto" else args.task
library_name = TasksManager.infer_library_from_model(args.model, cache_dir=args.cache_dir)
if library_name == "diffusers":
input_shapes = normalize_stable_diffusion_input_shapes(args)
submodels = {"unet": args.unet}
elif library_name == "sentence_transformers":
input_shapes = normalize_sentence_transformers_input_shapes(args)
submodels = None
else:
# New export mode using dedicated neuron model classes
kwargs = vars(args).copy()
if maybe_export_from_neuron_model_class(**kwargs):
return
# Fallback to legacy export
input_shapes = get_input_shapes(task, args)
submodels = None
disable_neuron_cache = args.disable_neuron_cache
compiler_kwargs = infer_compiler_kwargs(args)
optional_outputs = customize_optional_outputs(args)
optlevel = parse_optlevel(args)
lora_args = LoRAAdapterArguments(
model_ids=getattr(args, "lora_model_ids", None),
weight_names=getattr(args, "lora_weight_names", None),
adapter_names=getattr(args, "lora_adapter_names", None),
scales=getattr(args, "lora_scales", None),
)
ip_adapter_args = IPAdapterArguments(
model_id=getattr(args, "ip_adapter_id", None),
subfolder=getattr(args, "ip_adapter_subfolder", None),
weight_name=getattr(args, "ip_adapter_weight_name", None),
scale=getattr(args, "ip_adapter_scale", None),
)
main_export(
model_name_or_path=args.model,
output=args.output,
compiler_kwargs=compiler_kwargs,
torch_dtype=args.torch_dtype,
tensor_parallel_size=args.tensor_parallel_size,
task=task,
dynamic_batch_size=args.dynamic_batch_size,
atol=args.atol,
cache_dir=args.cache_dir,
disable_neuron_cache=disable_neuron_cache,
compiler_workdir=args.compiler_workdir,
inline_weights_to_neff=args.inline_weights_neff,
optlevel=optlevel,
trust_remote_code=args.trust_remote_code,
subfolder=args.subfolder,
do_validation=not args.disable_validation,
submodels=submodels,
library_name=library_name,
controlnet_ids=getattr(args, "controlnet_ids", None),
lora_args=lora_args,
ip_adapter_args=ip_adapter_args,
**optional_outputs,
**input_shapes,
)
if __name__ == "__main__":
main()
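# Example CLI invocation (hypothetical model and shapes; `optimum-cli export
# neuron` routes here):
#
#     optimum-cli export neuron \
#         --model distilbert-base-uncased-finetuned-sst-2-english \
#         --task text-classification --batch_size 1 --sequence_length 128 \
#         sst2_neuron/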