optimum/exporters/neuron/__main__.py (685 lines of code) (raw):

# coding=utf-8
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Entry point to the optimum.exporters.neuron command line."""

import argparse
import inspect
import os


os.environ["TORCHDYNAMO_DISABLE"] = "1"  # Always turn off torchdynamo as it's incompatible with neuron

from argparse import ArgumentParser
from dataclasses import fields
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import torch
from requests.exceptions import ConnectionError as RequestsConnectionError
from transformers import AutoConfig, AutoTokenizer, PretrainedConfig

from optimum.exporters.error_utils import AtolError, OutputMatchError, ShapeError
from optimum.exporters.tasks import TasksManager
from optimum.utils import is_diffusers_available, logging
from optimum.utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors

from ...neuron.models.auto_model import get_neuron_model_class, has_neuron_model_class
from ...neuron.utils import (
    DECODER_NAME,
    DIFFUSION_MODEL_CONTROLNET_NAME,
    DIFFUSION_MODEL_TEXT_ENCODER_2_NAME,
    DIFFUSION_MODEL_TEXT_ENCODER_NAME,
    DIFFUSION_MODEL_TRANSFORMER_NAME,
    DIFFUSION_MODEL_UNET_NAME,
    DIFFUSION_MODEL_VAE_DECODER_NAME,
    DIFFUSION_MODEL_VAE_ENCODER_NAME,
    ENCODER_NAME,
    NEURON_FILE_NAME,
    ImageEncoderArguments,
    InputShapesArguments,
    IPAdapterArguments,
    LoRAAdapterArguments,
    is_neuron_available,
    is_neuronx_available,
    map_torch_dtype,
)
from ...neuron.utils.version_utils import (
    check_compiler_compatibility_for_stable_diffusion,
)
from .base import NeuronExportConfig
from .convert import export_models, validate_models_outputs
from .model_configs import *  # noqa: F403
from .utils import (
    build_stable_diffusion_components_mandatory_shapes,
    check_mandatory_input_shapes,
    get_diffusion_models_for_export,
    get_encoder_decoder_models_for_export,
    replace_stable_diffusion_submodels,
)


if is_neuron_available():
    from ...commands.export.neuron import parse_args_neuron

    NEURON_COMPILER = "Neuron"

if is_neuronx_available():
    from ...commands.export.neuronx import parse_args_neuronx as parse_args_neuron  # noqa: F811

    NEURON_COMPILER = "Neuronx"

if is_diffusers_available():
    from diffusers import StableDiffusionXLPipeline


if TYPE_CHECKING:
    from transformers import PreTrainedModel

    if is_diffusers_available():
        from diffusers import DiffusionPipeline, ModelMixin, StableDiffusionPipeline


logger = logging.get_logger()
logger.setLevel(logging.INFO)


def infer_compiler_kwargs(args: argparse.Namespace) -> Dict[str, Any]:
    """
    Build the compiler keyword arguments from the parsed CLI arguments.

    `auto_cast` is mapped to `None` when the CLI value is the string `"none"`, in which case `auto_cast_type` is
    `None` as well. `disable_fast_relayout` and `disable_fallback` are forwarded only when the namespace defines them.
    """
    auto_cast = None if args.auto_cast == "none" else args.auto_cast
    auto_cast_type = None if auto_cast is None else args.auto_cast_type
    compiler_kwargs = {"auto_cast": auto_cast, "auto_cast_type": auto_cast_type}
    # These flags are not defined by every argument parser, so only forward the ones that exist.
    for optional_flag in ("disable_fast_relayout", "disable_fallback"):
        if hasattr(args, optional_flag):
            compiler_kwargs[optional_flag] = getattr(args, optional_flag)
    return compiler_kwargs


def infer_task(model_name_or_path: str) -> str:
    """
    Infer the task of a model through `TasksManager.infer_task_from_model`.

    Raises:
        KeyError: if the task could not be inferred; the message lists the tasks known to `TasksManager`.
        RequestsConnectionError: if the inference failed because of a connection error.
    """
    try:
        return TasksManager.infer_task_from_model(model_name_or_path)
    except KeyError as e:
        raise KeyError(
            "The task could not be automatically inferred. Please provide the argument --task with the task "
            f"from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}"
        ) from e
    except RequestsConnectionError as e:
        raise RequestsConnectionError(
            "The task could not be automatically inferred as this is available only for models hosted on the "
            "Hugging Face Hub. Please provide the argument --task with the relevant task from "
            f"{', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}"
        ) from e


# This function is not applicable for diffusers / sentence transformers models
def get_input_shapes(task: str, args: argparse.Namespace) -> Dict[str, int]:
    """
    Collect from `args` the input shapes required by the neuron export config of `args.model` for `task`.

    The names of the required inputs come from `get_input_args_for_task` of the config class wrapped by the
    constructor returned by `get_neuron_config_class`.
    """
    neuron_config_constructor = get_neuron_config_class(task, args.model)
    input_args = neuron_config_constructor.func.get_input_args_for_task(task)
    return {name: getattr(args, name) for name in input_args}


def get_neuron_config_class(task: str, model_id: str) -> NeuronExportConfig:
    """
    Return the neuron exporter config constructor registered in `TasksManager` for `model_id` and `task`.

    The model type is read from the `AutoConfig` of `model_id` with underscores replaced by hyphens; for
    encoder-decoder configs, the `-encoder` suffix is appended before the lookup.
    """
    config = AutoConfig.from_pretrained(model_id)

    model_type = config.model_type.replace("_", "-")
    if config.is_encoder_decoder:
        model_type = model_type + "-encoder"

    return TasksManager.get_exporter_config_constructor(
        model_type=model_type,
        exporter="neuron",
        task=task,
        library_name="transformers",
    )


def normalize_sentence_transformers_input_shapes(args: argparse.Namespace) -> Dict[str, int]:
    """
    Extract the mandatory input shapes of a sentence transformers model from `args`.

    Models whose name contains "clip" require text / image batch sizes, sequence length and image dimensions; all
    other models only require `batch_size` and `sequence_length`.

    Raises:
        AttributeError: if at least one mandatory axis is absent from `args`.
    """
    args = vars(args) if isinstance(args, argparse.Namespace) else args
    if "clip" in args.get("model", "").lower():
        mandatory_axes = {"text_batch_size", "image_batch_size", "sequence_length", "num_channels", "width", "height"}
    else:
        mandatory_axes = {"batch_size", "sequence_length"}

    if not mandatory_axes.issubset(args.keys()):
        raise AttributeError(
            f"Shape of {mandatory_axes} are mandatory for neuron compilation, while {mandatory_axes.difference(args.keys())} are not given."
        )
    return {name: args[name] for name in mandatory_axes}


def customize_optional_outputs(args: argparse.Namespace) -> Dict[str, bool]:
    """
    Customize optional outputs of the traced model, eg. if `output_attentions=True`, the attentions tensors will be
    traced. Options that are absent from `args` default to `False`.
    """
    possible_outputs = ["output_attentions", "output_hidden_states"]
    return {name: getattr(args, name, False) for name in possible_outputs}


def parse_optlevel(args: argparse.Namespace) -> Optional[str]:
    """
    (NEURONX ONLY) Parse the level of optimization the compiler should perform. If not specified apply `O2`(the best
    balance between model performance and compile time).

    Returns:
        `"1"`, `"2"` or `"3"` when neuronx is available, `None` otherwise.
    """
    if not is_neuronx_available():
        return None

    if args.O1:
        return "1"
    if args.O2:
        return "2"
    if args.O3:
        return "3"
    return "2"


def normalize_stable_diffusion_input_shapes(
    args: argparse.Namespace,
) -> Dict[str, Dict[str, int]]:
    """
    Build the per-component input shapes of a diffusion pipeline from `args`.

    The mandatory axes are the parameters of `build_stable_diffusion_components_mandatory_shapes` minus the ones
    which are optional or filled in later (sequence length, number of channels, number of images per prompt).

    Raises:
        AttributeError: if at least one mandatory axis is absent from `args`.
    """
    args = vars(args) if isinstance(args, argparse.Namespace) else args
    mandatory_axes = set(inspect.getfullargspec(build_stable_diffusion_components_mandatory_shapes).args)
    mandatory_axes = mandatory_axes - {
        "sequence_length",  # `sequence_length` is optional, diffusers will pad it to the max if not provided.
        # remove number of channels.
        "unet_or_transformer_num_channels",
        "vae_encoder_num_channels",
        "vae_decoder_num_channels",
        "num_images_per_prompt",  # default to 1
    }
    if not mandatory_axes.issubset(args.keys()):
        raise AttributeError(
            f"Shape of {mandatory_axes} are mandatory for neuron compilation, while {mandatory_axes.difference(args.keys())} are not given."
        )
    mandatory_shapes = {name: args[name] for name in mandatory_axes}
    # `or 1` also covers the case where the key is present but its value is None (or 0).
    mandatory_shapes["num_images_per_prompt"] = args.get("num_images_per_prompt", 1) or 1
    mandatory_shapes["sequence_length"] = args.get("sequence_length", None)
    return build_stable_diffusion_components_mandatory_shapes(**mandatory_shapes)


def infer_stable_diffusion_shapes_from_diffusers(
    input_shapes: Dict[str, Dict[str, int]],
    model: Union["StableDiffusionPipeline", "StableDiffusionXLPipeline"],
    has_controlnets: bool,
):
    """
    Complete `input_shapes` with the values that can be read from the components of the diffusers pipeline.

    `input_shapes` is updated in place: the `"unet_or_transformer"` entry is renamed to `"unet"` or `"transformer"`
    (depending on whether the pipeline has a `transformer` attribute), entries for the second text encoder,
    controlnet and image encoder are added when relevant, and every entry is finally wrapped in
    `InputShapesArguments`.

    Raises:
        AttributeError: if neither `model.tokenizer` nor `model.tokenizer_2` is set.
    """
    if model.tokenizer is not None:
        max_sequence_length = model.tokenizer.model_max_length
    elif hasattr(model, "tokenizer_2") and model.tokenizer_2 is not None:
        max_sequence_length = model.tokenizer_2.model_max_length
    else:
        raise AttributeError(
            f"Cannot infer max sequence_length from {type(model)} as there is no tokenizer as attribute."
        )
    vae_encoder_num_channels = model.vae.config.in_channels
    vae_decoder_num_channels = model.vae.config.latent_channels
    # NOTE: `**` binds tighter than `or` and a power of two is never falsy, so the `or 8` fallback is never used.
    vae_scale_factor = 2 ** (len(model.vae.config.block_out_channels) - 1) or 8
    height = input_shapes["unet_or_transformer"]["height"]
    scaled_height = height // vae_scale_factor
    width = input_shapes["unet_or_transformer"]["width"]
    scaled_width = width // vae_scale_factor

    # Text encoders
    if input_shapes["text_encoder"].get("sequence_length") is None:
        input_shapes["text_encoder"].update({"sequence_length": max_sequence_length})
    if hasattr(model, "text_encoder_2"):
        input_shapes["text_encoder_2"] = input_shapes["text_encoder"]

    # UNet or Transformer
    unet_or_transformer_name = "transformer" if hasattr(model, "transformer") else "unet"
    unet_or_transformer_num_channels = getattr(model, unet_or_transformer_name).config.in_channels
    input_shapes["unet_or_transformer"].update(
        {
            "num_channels": unet_or_transformer_num_channels,
            "height": scaled_height,
            "width": scaled_width,
        }
    )
    if input_shapes["unet_or_transformer"].get("sequence_length") is None:
        input_shapes["unet_or_transformer"]["sequence_length"] = max_sequence_length
    input_shapes["unet_or_transformer"]["vae_scale_factor"] = vae_scale_factor
    input_shapes[unet_or_transformer_name] = input_shapes.pop("unet_or_transformer")
    if unet_or_transformer_name == "transformer":
        input_shapes[unet_or_transformer_name]["encoder_hidden_size"] = model.text_encoder.config.hidden_size

    # VAE
    input_shapes["vae_encoder"].update({"num_channels": vae_encoder_num_channels, "height": height, "width": width})
    input_shapes["vae_decoder"].update(
        {"num_channels": vae_decoder_num_channels, "height": scaled_height, "width": scaled_width}
    )

    # ControlNet
    if has_controlnets:
        encoder_hidden_size = model.text_encoder.config.hidden_size
        if hasattr(model, "text_encoder_2"):
            encoder_hidden_size += model.text_encoder_2.config.hidden_size
        input_shapes["controlnet"] = {
            "batch_size": input_shapes[unet_or_transformer_name]["batch_size"],
            "sequence_length": input_shapes[unet_or_transformer_name]["sequence_length"],
            "num_channels": unet_or_transformer_num_channels,
            "height": scaled_height,
            "width": scaled_width,
            "vae_scale_factor": vae_scale_factor,
            "encoder_hidden_size": encoder_hidden_size,
        }

    # Image encoder
    if getattr(model, "image_encoder", None):
        input_shapes["image_encoder"] = {
            "batch_size": input_shapes[unet_or_transformer_name]["batch_size"],
            "num_channels": model.image_encoder.config.num_channels,
            "width": model.image_encoder.config.image_size,
            "height": model.image_encoder.config.image_size,
        }
        # IP-Adapter: add image_embeds as input for unet/transformer
        # unet has `ip_adapter_image_embeds` with shape [batch_size, 1, (self.image_encoder.config.image_size//patch_size)**2+1, self.image_encoder.config.hidden_size] as input
        # Pipelines built around a transformer may not expose a `unet` attribute at all, so look it up defensively
        # instead of failing with an AttributeError.
        unet = getattr(model, "unet", None)
        if unet is not None and getattr(unet.config, "encoder_hid_dim_type", None) == "ip_image_proj":
            position_embedding = model.image_encoder.vision_model.embeddings.position_embedding.weight
            input_shapes[unet_or_transformer_name]["image_encoder_shapes"] = ImageEncoderArguments(
                sequence_length=position_embedding.shape[0],
                hidden_size=position_embedding.shape[1],
                projection_dim=getattr(model.image_encoder.config, "projection_dim", None),
            )

    # Format with `InputShapesArguments`
    for sub_model_name in input_shapes.keys():
        input_shapes[sub_model_name] = InputShapesArguments(**input_shapes[sub_model_name])

    return input_shapes


def get_submodels_and_neuron_configs(
    model: Union["PreTrainedModel", "DiffusionPipeline"],
    input_shapes: Dict[str, int],
    task: str,
    output: Path,
    library_name: str,
    tensor_parallel_size: int = 1,
    subfolder: str = "",
    trust_remote_code: bool = False,
    dynamic_batch_size: bool = False,
    model_name_or_path: Optional[Union[str, Path]] = None,
    submodels: Optional[Dict[str, Union[Path, str]]] = None,
    output_attentions: bool = False,
    output_hidden_states: bool = False,
    controlnet_ids: Optional[Union[str, List[str]]] = None,
    lora_args: Optional[LoRAAdapterArguments] = None,
):
    """
    Prepare the (sub)models to export along with their neuron configs and output file names.

    Three cases are handled: diffusers pipelines, encoder-decoder transformers models, and any other single model
    (exported as `model.neuron`).

    Returns:
        A tuple `(models_and_neuron_configs, output_model_names)`.

    Raises:
        ValueError: if an optional output is requested for a case which does not support it yet.
    """
    is_encoder_decoder = (
        getattr(model.config, "is_encoder_decoder", False) if isinstance(model.config, PretrainedConfig) else False
    )

    if library_name == "diffusers":
        # TODO: Enable optional outputs for Stable Diffusion
        if output_attentions:
            raise ValueError(f"`output_attentions`is not supported by the {task} task yet.")
        models_and_neuron_configs, output_model_names = _get_submodels_and_neuron_configs_for_stable_diffusion(
            model=model,
            input_shapes=input_shapes,
            output=output,
            dynamic_batch_size=dynamic_batch_size,
            submodels=submodels,
            output_hidden_states=output_hidden_states,
            controlnet_ids=controlnet_ids,
            lora_args=lora_args,
        )
    elif is_encoder_decoder:
        optional_outputs = {"output_attentions": output_attentions, "output_hidden_states": output_hidden_states}
        preprocessors = maybe_load_preprocessors(
            src_name_or_path=model_name_or_path,
            subfolder=subfolder,
            trust_remote_code=trust_remote_code,
        )
        models_and_neuron_configs, output_model_names = _get_submodels_and_neuron_configs_for_encoder_decoder(
            model=model,
            input_shapes=input_shapes,
            tensor_parallel_size=tensor_parallel_size,
            task=task,
            output=output,
            dynamic_batch_size=dynamic_batch_size,
            model_name_or_path=model_name_or_path,
            preprocessors=preprocessors,
            **optional_outputs,
        )
    else:
        # TODO: Enable optional outputs for encoders
        if output_attentions or output_hidden_states:
            raise ValueError(
                f"`output_attentions` and `output_hidden_states` are not supported by the {task} task yet."
            )
        neuron_config_constructor = TasksManager.get_exporter_config_constructor(
            model=model,
            exporter="neuron",
            task=task,
            library_name=library_name,
        )
        input_shapes = check_mandatory_input_shapes(neuron_config_constructor, task, input_shapes)
        input_shapes = InputShapesArguments(**input_shapes)
        neuron_config = neuron_config_constructor(
            model.config, dynamic_batch_size=dynamic_batch_size, input_shapes=input_shapes
        )
        # Name the model after the last component of its path / hub id, falling back to the model type.
        model_name = getattr(model, "name_or_path", None) or model_name_or_path
        model_name = model_name.split("/")[-1] if model_name else model.config.model_type
        output_model_names = {model_name: "model.neuron"}
        models_and_neuron_configs = {model_name: (model, neuron_config)}
        maybe_save_preprocessors(model_name_or_path, output, src_subfolder=subfolder)
    return models_and_neuron_configs, output_model_names


def _get_submodels_and_neuron_configs_for_stable_diffusion(
    model: Union["PreTrainedModel", "DiffusionPipeline"],
    input_shapes: Dict[str, int],
    output: Path,
    dynamic_batch_size: bool = False,
    submodels: Optional[Dict[str, Union[Path, str]]] = None,
    output_hidden_states: bool = False,
    controlnet_ids: Optional[Union[str, List[str]]] = None,
    lora_args: Optional[LoRAAdapterArguments] = None,
):
    """
    Prepare the components of a diffusion pipeline for export.

    The scheduler, the tokenizers, the feature extractor (when present) and the pipeline config are saved under
    `output` as a side effect.

    Returns:
        A tuple `(models_and_neuron_configs, output_model_names)`.

    Raises:
        RuntimeError: if `is_neuron_available()` is true (neuron-cc / inf1 is not supported).
    """
    check_compiler_compatibility_for_stable_diffusion()
    model = replace_stable_diffusion_submodels(model, submodels)
    if is_neuron_available():
        raise RuntimeError(
            "Stable diffusion export is not supported by neuron-cc on inf1, please use neuronx-cc on either inf2/trn1 instead."
        )
    input_shapes = infer_stable_diffusion_shapes_from_diffusers(
        input_shapes=input_shapes,
        model=model,
        has_controlnets=controlnet_ids is not None,
    )

    # Saving the model config and preprocessor as this is needed sometimes.
    model.scheduler.save_pretrained(output.joinpath("scheduler"))
    for component_name in ("tokenizer", "tokenizer_2", "tokenizer_3", "feature_extractor"):
        component = getattr(model, component_name, None)
        if component is not None:
            component.save_pretrained(output.joinpath(component_name))
    model.save_config(output)

    models_and_neuron_configs = get_diffusion_models_for_export(
        pipeline=model,
        text_encoder_input_shapes=input_shapes["text_encoder"],
        unet_input_shapes=input_shapes.get("unet", None),
        transformer_input_shapes=input_shapes.get("transformer", None),
        vae_encoder_input_shapes=input_shapes["vae_encoder"],
        vae_decoder_input_shapes=input_shapes["vae_decoder"],
        lora_args=lora_args,
        dynamic_batch_size=dynamic_batch_size,
        output_hidden_states=output_hidden_states,
        controlnet_ids=controlnet_ids,
        controlnet_input_shapes=input_shapes.get("controlnet", None),
        image_encoder_input_shapes=input_shapes.get("image_encoder", None),
    )
    output_model_names = {
        DIFFUSION_MODEL_VAE_ENCODER_NAME: os.path.join(DIFFUSION_MODEL_VAE_ENCODER_NAME, NEURON_FILE_NAME),
        DIFFUSION_MODEL_VAE_DECODER_NAME: os.path.join(DIFFUSION_MODEL_VAE_DECODER_NAME, NEURON_FILE_NAME),
    }
    if getattr(model, "text_encoder", None) is not None:
        output_model_names[DIFFUSION_MODEL_TEXT_ENCODER_NAME] = os.path.join(
            DIFFUSION_MODEL_TEXT_ENCODER_NAME, NEURON_FILE_NAME
        )
    if getattr(model, "text_encoder_2", None) is not None:
        output_model_names[DIFFUSION_MODEL_TEXT_ENCODER_2_NAME] = os.path.join(
            DIFFUSION_MODEL_TEXT_ENCODER_2_NAME, NEURON_FILE_NAME
        )
    if getattr(model, "unet", None) is not None:
        output_model_names[DIFFUSION_MODEL_UNET_NAME] = os.path.join(DIFFUSION_MODEL_UNET_NAME, NEURON_FILE_NAME)
    if getattr(model, "transformer", None) is not None:
        output_model_names[DIFFUSION_MODEL_TRANSFORMER_NAME] = os.path.join(
            DIFFUSION_MODEL_TRANSFORMER_NAME, NEURON_FILE_NAME
        )
    if getattr(model, "image_encoder", None) is not None:
        output_model_names["image_encoder"] = os.path.join("image_encoder", NEURON_FILE_NAME)

    # ControlNet models
    if controlnet_ids:
        if isinstance(controlnet_ids, str):
            controlnet_ids = [controlnet_ids]
        # One output sub-directory per controlnet, suffixed with its position in `controlnet_ids`.
        for idx in range(len(controlnet_ids)):
            controlnet_name = DIFFUSION_MODEL_CONTROLNET_NAME + "_" + str(idx)
            output_model_names[controlnet_name] = os.path.join(controlnet_name, NEURON_FILE_NAME)

    del model

    return models_and_neuron_configs, output_model_names


def _get_submodels_and_neuron_configs_for_encoder_decoder(
    model: "PreTrainedModel",
    input_shapes: Dict[str, int],
    tensor_parallel_size: int,
    task: str,
    output: Path,
    preprocessors: Optional[List] = None,
    dynamic_batch_size: bool = False,
    model_name_or_path: Optional[Union[str, Path]] = None,
    output_attentions: bool = False,
    output_hidden_states: bool = False,
):
    """
    Prepare the encoder and the decoder of an encoder-decoder model for export.

    The model config, the generation config and the preprocessors are saved under `output` as a side effect.

    Returns:
        A tuple `(models_and_neuron_configs, output_model_names)`.

    Raises:
        RuntimeError: if `is_neuron_available()` is true (neuron-cc / inf1 is not supported).
    """
    if is_neuron_available():
        raise RuntimeError(
            "Encoder-decoder models export is not supported by neuron-cc on inf1, please use neuronx-cc on either inf2/trn1 instead."
        )

    models_and_neuron_configs = get_encoder_decoder_models_for_export(
        model=model,
        task=task,
        tensor_parallel_size=tensor_parallel_size,
        dynamic_batch_size=dynamic_batch_size,
        input_shapes=input_shapes,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        model_name_or_path=model_name_or_path,
        preprocessors=preprocessors,
    )
    output_model_names = {
        ENCODER_NAME: os.path.join(ENCODER_NAME, NEURON_FILE_NAME),
        DECODER_NAME: os.path.join(DECODER_NAME, NEURON_FILE_NAME),
    }
    model.config.save_pretrained(output)
    model.generation_config.save_pretrained(output)
    maybe_save_preprocessors(model_name_or_path, output)

    return models_and_neuron_configs, output_model_names


def load_models_and_neuron_configs(
    model_name_or_path: str,
    output: Path,
    model: Optional[Union["PreTrainedModel", "ModelMixin"]],
    task: str,
    dynamic_batch_size: bool,
    cache_dir: Optional[str],
    trust_remote_code: bool,
    subfolder: str,
    revision: str,
    library_name: str,
    force_download: bool,
    local_files_only: bool,
    token: Optional[Union[bool, str]],
    submodels: Optional[Dict[str, Union[Path, str]]],
    torch_dtype: Optional[Union[str, torch.dtype]] = None,
    tensor_parallel_size: int = 1,
    controlnet_ids: Optional[Union[str, List[str]]] = None,
    lora_args: Optional[LoRAAdapterArguments] = None,
    ip_adapter_args: Optional[IPAdapterArguments] = None,
    output_attentions: bool = False,
    output_hidden_states: bool = False,
    **input_shapes,
):
    """
    Load the model when it is not provided, then prepare its submodels and neuron configs for export.

    Returns:
        A tuple `(models_and_neuron_configs, output_model_names)` as produced by `get_submodels_and_neuron_configs`.
    """
    model_kwargs = {
        "task": task,
        "model_name_or_path": model_name_or_path,
        "subfolder": subfolder,
        "revision": revision,
        "cache_dir": cache_dir,
        "token": token,
        "local_files_only": local_files_only,
        "force_download": force_download,
        "trust_remote_code": trust_remote_code,
        "framework": "pt",
        "library_name": library_name,
        "torch_dtype": torch_dtype,
    }
    if model is None:
        model = TasksManager.get_model_from_task(**model_kwargs)
        # Load IP-Adapter if it exists
        # NOTE(review): the adapter is only attached to models loaded here, not to a model passed by the caller
        # -- confirm this matches the intended nesting.
        if ip_adapter_args is not None and not all(
            getattr(ip_adapter_args, field.name) is None for field in fields(ip_adapter_args)
        ):
            model.load_ip_adapter(
                ip_adapter_args.model_id, subfolder=ip_adapter_args.subfolder, weight_name=ip_adapter_args.weight_name
            )
            model.set_ip_adapter_scale(scale=ip_adapter_args.scale)

    models_and_neuron_configs, output_model_names = get_submodels_and_neuron_configs(
        model=model,
        input_shapes=input_shapes,
        tensor_parallel_size=tensor_parallel_size,
        task=task,
        library_name=library_name,
        output=output,
        subfolder=subfolder,
        trust_remote_code=trust_remote_code,
        dynamic_batch_size=dynamic_batch_size,
        model_name_or_path=model_name_or_path,
        submodels=submodels,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        controlnet_ids=controlnet_ids,
        lora_args=lora_args,
    )

    return models_and_neuron_configs, output_model_names


def main_export(
    model_name_or_path: str,
    output: Union[str, Path],
    compiler_kwargs: Dict[str, Any],
    torch_dtype: Optional[Union[str, torch.dtype]] = None,
    tensor_parallel_size: int = 1,
    model: Optional[Union["PreTrainedModel", "ModelMixin"]] = None,
    task: str = "auto",
    dynamic_batch_size: bool = False,
    atol: Optional[float] = None,
    cache_dir: Optional[str] = None,
    disable_neuron_cache: Optional[bool] = False,
    compiler_workdir: Optional[Union[str, Path]] = None,
    inline_weights_to_neff: bool = True,
    optlevel: str = "2",
    trust_remote_code: bool = False,
    subfolder: str = "",
    revision: str = "main",
    force_download: bool = False,
    local_files_only: bool = False,
    token: Optional[Union[bool, str]] = None,
    do_validation: bool = True,
    submodels: Optional[Dict[str, Union[Path, str]]] = None,
    output_attentions: bool = False,
    output_hidden_states: bool = False,
    library_name: Optional[str] = None,
    controlnet_ids: Optional[Union[str, List[str]]] = None,
    lora_args: Optional[LoRAAdapterArguments] = None,
    ip_adapter_args: Optional[IPAdapterArguments] = None,
    **input_shapes,
):
    """
    Export a model to neuron: load it, compile its (sub)models and optionally validate the compiled outputs.

    Validation is skipped (with a warning) when `tensor_parallel_size > 1`. During validation a `ShapeError` is
    re-raised, `AtolError` / `OutputMatchError` are reported as warnings and any other exception is logged as an
    error; in the last three cases the exported model is kept under `output`.
    """
    output = Path(output)
    torch_dtype = map_torch_dtype(torch_dtype)
    # `exist_ok=True` avoids failing when the directory gets created between a check and the creation.
    output.parent.mkdir(parents=True, exist_ok=True)

    task = TasksManager.map_from_synonym(task)
    if library_name is None:
        library_name = TasksManager.infer_library_from_model(
            model_name_or_path, revision=revision, cache_dir=cache_dir, token=token
        )

    models_and_neuron_configs, output_model_names = load_models_and_neuron_configs(
        model_name_or_path=model_name_or_path,
        output=output,
        model=model,
        torch_dtype=torch_dtype,
        tensor_parallel_size=tensor_parallel_size,
        task=task,
        dynamic_batch_size=dynamic_batch_size,
        cache_dir=cache_dir,
        trust_remote_code=trust_remote_code,
        subfolder=subfolder,
        revision=revision,
        library_name=library_name,
        force_download=force_download,
        local_files_only=local_files_only,
        token=token,
        submodels=submodels,
        lora_args=lora_args,
        ip_adapter_args=ip_adapter_args,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        controlnet_ids=controlnet_ids,
        **input_shapes,
    )

    _, neuron_outputs = export_models(
        models_and_neuron_configs=models_and_neuron_configs,
        task=task,
        output_dir=output,
        disable_neuron_cache=disable_neuron_cache,
        compiler_workdir=compiler_workdir,
        inline_weights_to_neff=inline_weights_to_neff,
        optlevel=optlevel,
        output_file_names=output_model_names,
        compiler_kwargs=compiler_kwargs,
        model_name_or_path=model_name_or_path,
    )

    # Validate compiled model
    if do_validation and tensor_parallel_size > 1:
        # TODO: support the validation of tp models.
        logger.warning(
            "The validation is not yet supported for tensor parallel model, the validation will be turned off."
        )
        do_validation = False

    if do_validation is True:
        try:
            validate_models_outputs(
                models_and_neuron_configs=models_and_neuron_configs,
                neuron_named_outputs=neuron_outputs,
                output_dir=output,
                atol=atol,
                neuron_files_subpaths=output_model_names,
            )

            logger.info(
                f"The {NEURON_COMPILER} export succeeded and the exported model was saved at: {output.as_posix()}"
            )
        except ShapeError:
            # A shape mismatch is the only validation failure which aborts the export.
            raise
        except (AtolError, OutputMatchError) as e:
            logger.warning(
                f"The {NEURON_COMPILER} export succeeded with the warning: {e}.\n The exported model was saved at: "
                f"{output.as_posix()}"
            )
        except Exception as e:
            # Best effort: the compiled artifacts already exist, so report the failure without raising.
            logger.error(
                f"An error occurred with the error message: {e}.\n The exported model was saved at: {output.as_posix()}"
            )


def maybe_export_from_neuron_model_class(
    model: str,
    output: Union[str, Path],
    task: str = "auto",
    cache_dir: Optional[str] = None,
    subfolder: str = "",
    trust_remote_code: bool = False,
    **kwargs,
):
    """
    Export the model from the neuron model class if it exists.

    Returns:
        `True` if a neuron model class was found for the model type and task (the exported model and, when one can
        be loaded, its tokenizer are then saved under `output`), `False` otherwise.
    """
    if task == "auto":
        task = infer_task(model)
    output = Path(output)
    # Remove None values from the kwargs
    kwargs = {key: value for key, value in kwargs.items() if value is not None}
    # Also remove some arguments that are not supported in this context
    unsupported_arguments = (
        "disable_neuron_cache",
        "inline_weights_neff",
        "O1",
        "O2",
        "O3",
        "disable_validation",
        "dynamic_batch_size",
        "output_hidden_states",
        "output_attentions",
        "tensor_parallel_size",
    )
    for argument_name in unsupported_arguments:
        kwargs.pop(argument_name, None)
    # Fetch the model config
    config = AutoConfig.from_pretrained(model)
    # Check if we have an auto-model class for the model_type and task
    if not has_neuron_model_class(model_type=config.model_type, task=task, mode="inference"):
        return False
    neuron_model_class = get_neuron_model_class(model_type=config.model_type, task=task, mode="inference")
    neuron_model = neuron_model_class.from_pretrained(
        model_id=model,
        export=True,
        cache_dir=cache_dir,
        subfolder=subfolder,
        config=config,
        trust_remote_code=trust_remote_code,
        load_weights=False,  # Reduce model size for nxd models
        **kwargs,
    )
    output.parent.mkdir(parents=True, exist_ok=True)
    neuron_model.save_pretrained(output)
    try:
        tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=trust_remote_code)
        tokenizer.save_pretrained(output)
    except Exception:
        # Saving the tokenizer is best effort: the export is still reported as successful without it.
        logger.info(f"No tokenizer found while exporting {model}.")
    return True


def main():
    """
    Command line entry point: parse the arguments, resolve the task, library and input shapes, then export.

    For models which are neither diffusers nor sentence transformers models, the export through a dedicated neuron
    model class is attempted first, and the legacy export is only used as a fallback.
    """
    parser = ArgumentParser(f"Hugging Face Optimum {NEURON_COMPILER} exporter")

    parse_args_neuron(parser)

    # Retrieve CLI arguments
    args = parser.parse_args()

    task = infer_task(args.model) if args.task == "auto" else args.task
    library_name = TasksManager.infer_library_from_model(args.model, cache_dir=args.cache_dir)

    if library_name == "diffusers":
        input_shapes = normalize_stable_diffusion_input_shapes(args)
        submodels = {"unet": args.unet}
    elif library_name == "sentence_transformers":
        input_shapes = normalize_sentence_transformers_input_shapes(args)
        submodels = None
    else:
        # New export mode using dedicated neuron model classes
        kwargs = vars(args).copy()
        # Forward the task resolved above so that an "auto" task is not inferred a second time.
        kwargs["task"] = task
        if maybe_export_from_neuron_model_class(**kwargs):
            return
        # Fallback to legacy export
        input_shapes = get_input_shapes(task, args)
        submodels = None

    disable_neuron_cache = args.disable_neuron_cache
    compiler_kwargs = infer_compiler_kwargs(args)
    optional_outputs = customize_optional_outputs(args)
    optlevel = parse_optlevel(args)
    lora_args = LoRAAdapterArguments(
        model_ids=getattr(args, "lora_model_ids", None),
        weight_names=getattr(args, "lora_weight_names", None),
        adapter_names=getattr(args, "lora_adapter_names", None),
        scales=getattr(args, "lora_scales", None),
    )
    ip_adapter_args = IPAdapterArguments(
        model_id=getattr(args, "ip_adapter_id", None),
        subfolder=getattr(args, "ip_adapter_subfolder", None),
        weight_name=getattr(args, "ip_adapter_weight_name", None),
        scale=getattr(args, "ip_adapter_scale", None),
    )

    main_export(
        model_name_or_path=args.model,
        output=args.output,
        compiler_kwargs=compiler_kwargs,
        torch_dtype=args.torch_dtype,
        tensor_parallel_size=args.tensor_parallel_size,
        task=task,
        dynamic_batch_size=args.dynamic_batch_size,
        atol=args.atol,
        cache_dir=args.cache_dir,
        disable_neuron_cache=disable_neuron_cache,
        compiler_workdir=args.compiler_workdir,
        inline_weights_to_neff=args.inline_weights_neff,
        optlevel=optlevel,
        trust_remote_code=args.trust_remote_code,
        subfolder=args.subfolder,
        do_validation=not args.disable_validation,
        submodels=submodels,
        library_name=library_name,
        controlnet_ids=getattr(args, "controlnet_ids", None),
        lora_args=lora_args,
        ip_adapter_args=ip_adapter_args,
        **optional_outputs,
        **input_shapes,
    )


if __name__ == "__main__":
    main()