optimum/exporters/neuron/model_configs.py

# coding=utf-8
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model specific Neuron configurations."""

import copy
from functools import partial
from typing import TYPE_CHECKING, Dict, List

import torch

from optimum.exporters.tasks import TasksManager
from optimum.utils import (
    DummyInputGenerator,
    DummySeq2SeqDecoderTextInputGenerator,
    DummyTextInputGenerator,
    DummyTimestepInputGenerator,
    DummyVisionInputGenerator,
    NormalizedConfig,
    NormalizedConfigManager,
    NormalizedSeq2SeqConfig,
    NormalizedTextAndVisionConfig,
    NormalizedTextConfig,
    NormalizedVisionConfig,
    is_diffusers_available,
)

from ...neuron.utils import (
    ASTDummyAudioInputGenerator,
    DummyBeamValuesGenerator,
    DummyControNetInputGenerator,
    DummyIPAdapterInputGenerator,
    DummyMaskedPosGenerator,
    WhisperDummyTextInputGenerator,
    is_neuronx_distributed_available,
    saved_model_in_temporary_directory,
)
from .config import (
    AudioNeuronConfig,
    TextAndVisionNeuronConfig,
    TextEncoderNeuronConfig,
    TextSeq2SeqNeuronConfig,
    VisionNeuronConfig,
)
from .model_wrappers import (
    CLIPVisionWithProjectionNeuronWrapper,
    ControlNetNeuronWrapper,
    NoCacheModelWrapper,
    PixartTransformerNeuronWrapper,
    SentenceTransformersCLIPNeuronWrapper,
    SentenceTransformersTransformerNeuronWrapper,
    T5DecoderWrapper,
    T5EncoderForSeq2SeqLMWrapper,
    T5EncoderWrapper,
    UnetNeuronWrapper,
    WhisperDecoderWrapper,
    WhisperEncoderWrapper,
)


if is_neuronx_distributed_available():
    import neuronx_distributed

if TYPE_CHECKING:
    if is_diffusers_available():
        from diffusers.models.vae import Decoder as VaeDecoder


COMMON_TEXT_TASKS = [
    "feature-extraction",
    "fill-mask",
    "multiple-choice",
    "question-answering",
    "text-classification",
    "token-classification",
]
register_in_tasks_manager = TasksManager.create_register("neuron")


@register_in_tasks_manager("bert", *COMMON_TEXT_TASKS)
class BertNeuronConfig(TextEncoderNeuronConfig):
    NORMALIZED_CONFIG_CLASS = NormalizedConfigManager.get_normalized_config_class("bert")
    ATOL_FOR_VALIDATION = 1e-3

    @property
    def inputs(self) -> List[str]:
        return ["input_ids", "attention_mask", "token_type_ids"]


@register_in_tasks_manager("albert", *COMMON_TEXT_TASKS)
class AlbertNeuronConfig(BertNeuronConfig):
    pass


@register_in_tasks_manager("convbert", *COMMON_TEXT_TASKS)
class ConvBertNeuronConfig(BertNeuronConfig):
    ATOL_FOR_VALIDATION = 1e-1  # TODO: why accuracy more off than other arch

    @property
    def outputs(self) -> List[str]:
        if self.task == "feature-extraction":
            return ["last_hidden_state"]
        return self._TASK_TO_COMMON_OUTPUTS[self.task]


@register_in_tasks_manager("electra", *COMMON_TEXT_TASKS)
class ElectraNeuronConfig(BertNeuronConfig):
    @property
    def outputs(self) -> List[str]:
        if self.task == "feature-extraction":
            return ["last_hidden_state"]
        return self._TASK_TO_COMMON_OUTPUTS[self.task]


@register_in_tasks_manager("esm", *["feature-extraction", "fill-mask", "text-classification", "token-classification"])
class EsmNeuronConfig(TextEncoderNeuronConfig):
    NORMALIZED_CONFIG_CLASS = NormalizedConfigManager.get_normalized_config_class("bert")
    ATOL_FOR_VALIDATION = 1e-3

    @property
    def inputs(self) -> List[str]:
        return ["input_ids", "attention_mask"]


@register_in_tasks_manager("flaubert", *COMMON_TEXT_TASKS)
class FlaubertNeuronConfig(ElectraNeuronConfig):
    pass


@register_in_tasks_manager("mobilebert", *COMMON_TEXT_TASKS)
class MobileBertNeuronConfig(BertNeuronConfig):
    pass


@register_in_tasks_manager(
    "modernbert", *["feature-extraction", "fill-mask", "text-classification", "token-classification"]
)
class ModernBertNeuronConfig(BertNeuronConfig):
    @property
    def inputs(self) -> List[str]:
        return ["input_ids", "attention_mask"]

    @property
    def outputs(self) -> List[str]:
        if self.task == "feature-extraction":
            return ["last_hidden_state"]
        return self._TASK_TO_COMMON_OUTPUTS[self.task]


@register_in_tasks_manager("phi", *["feature-extraction", "text-classification", "token-classification"])
class PhiNeuronConfig(ElectraNeuronConfig):
    CUSTOM_MODEL_WRAPPER = NoCacheModelWrapper

    @property
    def inputs(self) -> List[str]:
        return ["input_ids", "attention_mask"]


@register_in_tasks_manager("roformer", *COMMON_TEXT_TASKS)
class RoFormerNeuronConfig(ElectraNeuronConfig):
    pass


@register_in_tasks_manager("xlm", *COMMON_TEXT_TASKS)
class XLMNeuronConfig(ElectraNeuronConfig):
    pass


@register_in_tasks_manager("distilbert", *COMMON_TEXT_TASKS)
class DistilBertNeuronConfig(BertNeuronConfig):
    ATOL_FOR_VALIDATION = 1e-3

    @property
    def inputs(self) -> List[str]:
        return ["input_ids", "attention_mask"]

    @property
    def outputs(self) -> List[str]:
        if self.task == "feature-extraction":
            return ["last_hidden_state"]
        return self._TASK_TO_COMMON_OUTPUTS[self.task]


@register_in_tasks_manager("camembert", *COMMON_TEXT_TASKS)
class CamembertNeuronConfig(BertNeuronConfig):
    ATOL_FOR_VALIDATION = 1e-3

    @property
    def inputs(self) -> List[str]:
        return ["input_ids", "attention_mask"]


@register_in_tasks_manager("mpnet", *COMMON_TEXT_TASKS)
class MPNetNeuronConfig(CamembertNeuronConfig):
    pass


@register_in_tasks_manager("roberta", *COMMON_TEXT_TASKS)
class RobertaNeuronConfig(CamembertNeuronConfig):
    pass


@register_in_tasks_manager("xlm-roberta", *COMMON_TEXT_TASKS)
class XLMRobertaNeuronConfig(CamembertNeuronConfig):
    pass


# https://github.com/aws-neuron/aws-neuron-sdk/issues/642
# Failed only for INF1: 'XSoftmax'
@register_in_tasks_manager("deberta", *([task for task in COMMON_TEXT_TASKS if task != "multiple-choice"]))
class DebertaNeuronConfig(ElectraNeuronConfig):
    @property
    def inputs(self) -> List[str]:
        common_inputs = super().inputs
        if self._config.type_vocab_size == 0:
            # We remove token type ids.
            common_inputs.pop(-1)
        return common_inputs


# https://github.com/aws-neuron/aws-neuron-sdk/issues/642
# Failed only for INF1: 'XSoftmax'
@register_in_tasks_manager("deberta-v2", *([task for task in COMMON_TEXT_TASKS if task != "multiple-choice"]))
class DebertaV2NeuronConfig(ElectraNeuronConfig):
    pass


@register_in_tasks_manager(
    "transformer", *["feature-extraction", "sentence-similarity"], library_name="sentence_transformers"
)
class SentenceTransformersTransformerNeuronConfig(TextEncoderNeuronConfig):
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
    CUSTOM_MODEL_WRAPPER = SentenceTransformersTransformerNeuronWrapper
    ATOL_FOR_VALIDATION = 1e-3

    @property
    def inputs(self) -> List[str]:
        return ["input_ids", "attention_mask"]

    @property
    def outputs(self) -> List[str]:
        return ["token_embeddings", "sentence_embedding"]


class CLIPNormalizedConfig(NormalizedTextAndVisionConfig):
    TEXT_CONFIG = "text_config"
    VISION_CONFIG = "vision_config"


@register_in_tasks_manager("clip-vision-with-projection", *["feature-extraction"], library_name="diffusers")
class CLIPVisionWithProjectionNeuronConfig(VisionNeuronConfig):
    MODEL_TYPE = "clip-vision-with-projection"
    NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig
    CUSTOM_MODEL_WRAPPER = CLIPVisionWithProjectionNeuronWrapper

    @property
    def inputs(self) -> List[str]:
        return ["pixel_values"]

    @property
    def outputs(self) -> List[str]:
        common_outputs = ["image_embeds", "last_hidden_state"]
        if self.output_hidden_states:
            common_outputs.append("hidden_states")
        return common_outputs


@register_in_tasks_manager("clip", *["feature-extraction", "zero-shot-image-classification", "image-classification"])
class CLIPNeuronConfig(TextAndVisionNeuronConfig):
    NORMALIZED_CONFIG_CLASS = CLIPNormalizedConfig
    INPUT_ARGS = ("text_batch_size", "image_batch_size", "sequence_length", "num_channels", "width", "height")

    @property
    def inputs(self) -> List[str]:
        if self.task == "image-classification":
            return ["pixel_values"]
        else:
            return ["input_ids", "pixel_values", "attention_mask"]

    @property
    def outputs(self) -> List[str]:
        if self.task == "image-classification":
            return ["logits"]
        else:
            return [
                "logits_per_image",
                "logits_per_text",
                "text_embeds",
                "image_embeds",
                "text_model_output",
                "vision_model_output",
            ]

    def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGenerator"]:
        for name, axis_dim in self._axes.items():
            self._axes[name] = kwargs.pop(name, axis_dim)

        self._validate_mandatory_axes()

        other_axes = copy.deepcopy(self._axes)
        text_batch_size = other_axes.pop("text_batch_size")
        images_batch_size = other_axes.pop("image_batch_size")

        return [
            DummyTextInputGenerator(self.task, self._normalized_config, batch_size=text_batch_size, **other_axes),
            DummyVisionInputGenerator(self.task, self._normalized_config, batch_size=images_batch_size, **other_axes),
        ]


@register_in_tasks_manager("clip-text-with-projection", *["feature-extraction"], library_name="diffusers")
class CLIPTextWithProjectionNeuronConfig(TextEncoderNeuronConfig):
    MODEL_TYPE = "clip-text-with-projection"
    ATOL_FOR_VALIDATION = 1e-3

    NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
        vocab_size="vocab_size",
        sequence_length="max_position_embeddings",
        num_layers="num_hidden_layers",
        allow_new=True,
    )

    @property
    def inputs(self) -> List[str]:
        return ["input_ids"]

    @property
    def outputs(self) -> List[str]:
        common_outputs = ["text_embeds", "last_hidden_state"]
        if self._normalized_config.output_hidden_states:
            common_outputs.append("hidden_states")
        return common_outputs
*["feature-extraction"], library_name="diffusers") class CLIPTextNeuronConfig(CLIPTextWithProjectionNeuronConfig): MODEL_TYPE = "clip-text-model" @property def outputs(self) -> List[str]: common_outputs = ["last_hidden_state", "pooler_output"] if self._normalized_config.output_hidden_states: common_outputs.append("hidden_states") return common_outputs # TODO: We should decouple clip text and vision, this would need fix on Optimum main. For the current workaround # users can pass dummy text inputs when encoding image, vice versa. @register_in_tasks_manager( "clip", *["feature-extraction", "sentence-similarity"], library_name="sentence_transformers" ) class SentenceTransformersCLIPNeuronConfig(CLIPNeuronConfig): CUSTOM_MODEL_WRAPPER = SentenceTransformersCLIPNeuronWrapper ATOL_FOR_VALIDATION = 1e-3 @property def outputs(self) -> List[str]: return ["text_embeds", "image_embeds"] @register_in_tasks_manager("vit", *["feature-extraction", "image-classification"]) class ViTNeuronConfig(VisionNeuronConfig): ATOL_FOR_VALIDATION = 1e-3 NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig DUMMY_INPUT_GENERATOR_CLASSES = (DummyVisionInputGenerator, DummyMaskedPosGenerator) INPUT_ARGS = ("batch_size",) # `num_channels` and `image_size` are inferred from the config @property def inputs(self) -> List[str]: common_inputs = ["pixel_values"] if self.task == "masked-im": common_inputs.append("bool_masked_pos") return common_inputs @register_in_tasks_manager("beit", *["feature-extraction", "image-classification"]) class BeitNeuronConfig(ViTNeuronConfig): pass @register_in_tasks_manager("convnext", *["feature-extraction", "image-classification"]) class ConvNextNeuronConfig(ViTNeuronConfig): pass @register_in_tasks_manager("convnextv2", *["feature-extraction", "image-classification"]) class ConvNextV2NeuronConfig(ViTNeuronConfig): pass @register_in_tasks_manager("cvt", *["feature-extraction", "image-classification"]) class CvTNeuronConfig(ViTNeuronConfig): MODEL_TYPE = "cvt" @property def outputs(self) -> List[str]: common_outputs = super().outputs if self.task == "feature-extraction": return ["last_hidden_state", "cls_token_value"] else: return common_outputs @register_in_tasks_manager("deit", *["feature-extraction", "image-classification"]) class DeiTNeuronConfig(ViTNeuronConfig): pass @register_in_tasks_manager("donut-swin", *["feature-extraction"]) class DonutSwinNeuronConfig(ViTNeuronConfig): pass @register_in_tasks_manager("dpt", *["feature-extraction"]) class DptNeuronConfig(ViTNeuronConfig): pass @register_in_tasks_manager("levit", *["feature-extraction", "image-classification"]) class LevitNeuronConfig(ViTNeuronConfig): MODEL_TYPE = "levit" pass @register_in_tasks_manager( "mobilenet-v2", *["feature-extraction", "image-classification", "semantic-segmentation", "image-segmentation"] ) class MobileNetV2NeuronConfig(ViTNeuronConfig): MODEL_TYPE = "mobilenet-v2" pass @register_in_tasks_manager( "mobilevit", *["feature-extraction", "image-classification", "semantic-segmentation", "image-segmentation"] ) class MobileViTNeuronConfig(ViTNeuronConfig): MODEL_TYPE = "mobilevit" pass @register_in_tasks_manager("swin", *["feature-extraction", "image-classification"]) class SwinNeuronConfig(ViTNeuronConfig): pass @register_in_tasks_manager("yolos", *["feature-extraction", "object-detection"]) class YolosTNeuronConfig(ViTNeuronConfig): @property def outputs(self) -> List[str]: common_outputs = super().outputs if self.task == "object-detection": common_outputs.append("last_hidden_state") return common_outputs 
@register_in_tasks_manager( "wav2vec2", *[ "feature-extraction", "automatic-speech-recognition", "audio-classification", "audio-frame-classification", "audio-xvector", ], ) class Wav2Vec2NeuronConfig(AudioNeuronConfig): NORMALIZED_CONFIG_CLASS = NormalizedConfig MODEL_TYPE = "wav2vec2" @property def inputs(self) -> List[str]: return ["input_values"] @property def outputs(self) -> List[str]: common_outputs = super().outputs if self.task == "feature-extraction": common_outputs = ["last_hidden_state", "extract_features"] if self.task == "audio-xvector": common_outputs.append("embeddings") return common_outputs @register_in_tasks_manager( "audio-spectrogram-transformer", *[ "feature-extraction", "audio-classification", ], ) class ASTNeuronConfig(AudioNeuronConfig): NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args( num_mel_bins="num_mel_bins", max_length="max_length", allow_new=True ) DUMMY_INPUT_GENERATOR_CLASSES = (ASTDummyAudioInputGenerator,) @property def inputs(self) -> List[str]: return ["input_values"] @register_in_tasks_manager( "hubert", *[ "feature-extraction", "automatic-speech-recognition", "audio-classification", ], ) class HubertNeuronConfig(Wav2Vec2NeuronConfig): MODEL_TYPE = "hubert" @property def outputs(self) -> List[str]: common_outputs = super().outputs if self.task == "feature-extraction": common_outputs = ["last_hidden_state"] return common_outputs # TODO: compilation failed due to a bug in xla: https://github.com/pytorch/xla/issues/6398. # @register_in_tasks_manager( # "sew", # *[ # "feature-extraction", # "automatic-speech-recognition", # "audio-classification", # ], # ) # class SEWNeuronConfig(Wav2Vec2NeuronConfig): # pass # TODO: compilation failed due to a bug in xla: https://github.com/pytorch/xla/issues/6398. # @register_in_tasks_manager( # "sew-d", # *[ # "feature-extraction", # "automatic-speech-recognition", # "audio-classification", # ], # ) # class SEWDNeuronConfig(Wav2Vec2NeuronConfig): # pass @register_in_tasks_manager( "unispeech", *[ "feature-extraction", "automatic-speech-recognition", "audio-classification", ], ) class UniSpeechNeuronConfig(Wav2Vec2NeuronConfig): MODEL_TYPE = "unispeech" pass @register_in_tasks_manager( "unispeech-sat", *[ "feature-extraction", "automatic-speech-recognition", "audio-classification", "audio-frame-classification", "audio-xvector", ], ) class UniSpeechSATNeuronConfig(Wav2Vec2NeuronConfig): MODEL_TYPE = "unispeech-sat" pass # TODO: compilation failed due to a bug in xla: https://github.com/pytorch/xla/issues/6398. # @register_in_tasks_manager( # "wav2vec2-bert", # *[ # "feature-extraction", # "automatic-speech-recognition", # "audio-classification", # "audio-frame-classification", # "audio-xvector", # ], # ) # class Wav2Vec2BertNeuronConfig(Wav2Vec2NeuronConfig): # pass # TODO: compilation failed due to a bug in xla: https://github.com/pytorch/xla/issues/6398. 
# @register_in_tasks_manager(
#     "wav2vec2-conformer",
#     *[
#         "feature-extraction",
#         "automatic-speech-recognition",
#         "audio-classification",
#         "audio-frame-classification",
#         "audio-xvector",
#     ],
# )
# class Wav2Vec2ConformerNeuronConfig(Wav2Vec2NeuronConfig):
#     pass


@register_in_tasks_manager(
    "wavlm",
    *[
        "feature-extraction",
        "automatic-speech-recognition",
        "audio-classification",
        "audio-frame-classification",
        "audio-xvector",
    ],
)
class WavLMNeuronConfig(Wav2Vec2NeuronConfig):
    MODEL_TYPE = "wavlm"
    pass


@register_in_tasks_manager("unet", *["semantic-segmentation"], library_name="diffusers")
class UNetNeuronConfig(VisionNeuronConfig):
    ATOL_FOR_VALIDATION = 1e-3
    INPUT_ARGS = ("batch_size", "sequence_length", "num_channels", "width", "height", "vae_scale_factor")
    MODEL_TYPE = "unet"
    CUSTOM_MODEL_WRAPPER = UnetNeuronWrapper

    NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
        height="height",
        width="width",
        num_channels="in_channels",
        hidden_size="cross_attention_dim",
        vocab_size="norm_num_groups",
        allow_new=True,
    )

    DUMMY_INPUT_GENERATOR_CLASSES = (
        DummyVisionInputGenerator,
        DummyTimestepInputGenerator,
        DummySeq2SeqDecoderTextInputGenerator,
        DummyControNetInputGenerator,
        DummyIPAdapterInputGenerator,
    )

    @property
    def inputs(self) -> List[str]:
        common_inputs = ["sample", "timestep", "encoder_hidden_states"]

        # TODO : add text_image, image and image_embeds
        if getattr(self._normalized_config, "addition_embed_type", None) == "text_time":
            common_inputs.append("text_embeds")
            common_inputs.append("time_ids")

        if getattr(self._normalized_config, "time_cond_proj_dim", None) is not None:
            common_inputs.append("timestep_cond")

        if self.with_controlnet:
            # outputs of controlnet
            common_inputs += ["down_block_additional_residuals", "mid_block_additional_residual"]

        if self.with_ip_adapter:
            # add output of image encoder
            if self.image_encoder_output_hidden_states:
                common_inputs += ["image_enc_hidden_states"]
            else:
                common_inputs += ["image_embeds"]

        return common_inputs

    @property
    def outputs(self) -> List[str]:
        return ["sample"]

    def generate_dummy_inputs(self, return_tuple: bool = False, **kwargs):
        dummy_inputs = super().generate_dummy_inputs(**kwargs)
        dummy_inputs["timestep"] = dummy_inputs["timestep"].float()
        dummy_inputs["encoder_hidden_states"] = dummy_inputs["encoder_hidden_states"][0]

        # break down down_block_additional_residuals
        num_down_block_outputs = len(self._normalized_config.down_block_types) * (
            self._normalized_config.layers_per_block + 1
        )
        down_block_additional_residuals = dummy_inputs.pop("down_block_additional_residuals", None)
        if down_block_additional_residuals:
            for idx in range(num_down_block_outputs):
                dummy_inputs[f"down_block_additional_residuals_{idx}"] = down_block_additional_residuals[idx]

        if getattr(self._normalized_config, "addition_embed_type", None) == "text_time":
            dummy_inputs["added_cond_kwargs"] = {
                "text_embeds": dummy_inputs.pop("text_embeds"),
                "time_ids": dummy_inputs.pop("time_ids"),
            }

        if return_tuple is True:
            return tuple(dummy_inputs.values())
        else:
            return dummy_inputs

    @property
    def is_sdxl(self) -> bool:
        return self._is_sdxl

    @is_sdxl.setter
    def is_sdxl(self, is_sdxl: bool):
        self._is_sdxl = is_sdxl

    @property
    def with_controlnet(self) -> bool:
        return self._with_controlnet

    @with_controlnet.setter
    def with_controlnet(self, with_controlnet: bool):
        self._with_controlnet = with_controlnet

    @property
    def with_ip_adapter(self) -> bool:
        return self._with_ip_adapter
    @with_ip_adapter.setter
    def with_ip_adapter(self, with_ip_adapter: bool):
        self._with_ip_adapter = with_ip_adapter
        if with_ip_adapter:
            self.mandatory_axes += ("image_encoder_shapes",)
            setattr(self, "image_encoder_shapes", self.input_shapes["image_encoder_shapes"])


@register_in_tasks_manager("pixart-transformer-2d", *["semantic-segmentation"], library_name="diffusers")
class PixartTransformerNeuronConfig(VisionNeuronConfig):
    ATOL_FOR_VALIDATION = 1e-3
    INPUT_ARGS = (
        "batch_size",
        "sequence_length",
        "num_channels",
        "width",
        "height",
        "vae_scale_factor",
        "encoder_hidden_size",
    )
    MODEL_TYPE = "pixart-transformer-2d"
    CUSTOM_MODEL_WRAPPER = PixartTransformerNeuronWrapper

    NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
        height="height",
        width="width",
        num_channels="in_channels",
        hidden_size="cross_attention_dim",
        vocab_size="norm_num_groups",
        allow_new=True,
    )

    DUMMY_INPUT_GENERATOR_CLASSES = (
        DummyVisionInputGenerator,
        DummyControNetInputGenerator,
        DummyTextInputGenerator,
        DummySeq2SeqDecoderTextInputGenerator,
    )

    @property
    def inputs(self) -> List[str]:
        common_inputs = ["sample", "encoder_hidden_states", "timestep", "encoder_attention_mask"]
        return common_inputs

    @property
    def outputs(self) -> List[str]:
        return ["out_hidden_states"]


@register_in_tasks_manager("controlnet", *["semantic-segmentation"], library_name="diffusers")
class ControlNetNeuronConfig(VisionNeuronConfig):
    ATOL_FOR_VALIDATION = 1e-3
    INPUT_ARGS = (
        "batch_size",
        "sequence_length",
        "num_channels",
        "height",
        "width",
        "vae_scale_factor",
        "encoder_hidden_size",
    )
    MODEL_TYPE = "controlnet"
    CUSTOM_MODEL_WRAPPER = ControlNetNeuronWrapper

    NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
        height="height",
        width="width",
        num_channels="in_channels",
        hidden_size="cross_attention_dim",
        vocab_size="norm_num_groups",
        allow_new=True,
    )

    DUMMY_INPUT_GENERATOR_CLASSES = (
        DummyVisionInputGenerator,
        DummyControNetInputGenerator,  # Instead of `encoder_hidden_states` generated by `DummySeq2SeqDecoderTextInputGenerator`
        DummyTimestepInputGenerator,
        DummySeq2SeqDecoderTextInputGenerator,
    )

    @property
    def inputs(self) -> List[str]:
        common_inputs = ["sample", "timestep", "encoder_hidden_states", "controlnet_cond", "conditioning_scale"]
        if getattr(self._normalized_config, "addition_embed_type", None) == "text_time":
            common_inputs.append("text_embeds")
            common_inputs.append("time_ids")
        return common_inputs

    @property
    def outputs(self) -> List[str]:
        return ["down_block_res_samples", "mid_block_res_sample"]


@register_in_tasks_manager("vae-encoder", *["semantic-segmentation"], library_name="diffusers")
class VaeEncoderNeuronConfig(VisionNeuronConfig):
    ATOL_FOR_VALIDATION = 1e-3
    MODEL_TYPE = "vae-encoder"

    NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
        num_channels="in_channels",
        allow_new=True,
    )

    @property
    def inputs(self) -> List[str]:
        return ["sample"]

    @property
    def outputs(self) -> List[str]:
        return ["latent_parameters"]

    def generate_dummy_inputs(self, return_tuple: bool = False, **kwargs):
        dummy_inputs = super().generate_dummy_inputs(**kwargs)

        if return_tuple is True:
            return tuple(dummy_inputs.values())
        else:
            return dummy_inputs


@register_in_tasks_manager("vae-decoder", *["semantic-segmentation"], library_name="diffusers")
class VaeDecoderNeuronConfig(VisionNeuronConfig):
    ATOL_FOR_VALIDATION = 1e-3
    MODEL_TYPE = "vae-decoder"

    NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
        num_channels="latent_channels",
        allow_new=True,
    )

    @property
    def inputs(self) -> List[str]:
        return ["latent_sample"]

    @property
    def outputs(self) -> List[str]:
        return ["sample"]
    def patch_model_for_export(
        self,
        model: "VaeDecoder",
        dummy_inputs: Dict[str, torch.Tensor],
        **kwargs,
    ):
        return super().patch_model_for_export(model=model, dummy_inputs=dummy_inputs, forward_with_tuple=True)


class T5EncoderBaseNeuronConfig(TextSeq2SeqNeuronConfig):
    ATOL_FOR_VALIDATION = 1e-3
    NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig.with_args(
        hidden_size="d_model",
        num_attention_heads="num_heads",
        encoder_num_layers="num_layers",
        decoder_num_layers="num_decoder_layers",
        key_value_dim="d_kv",
        allow_new=True,
    )

    @property
    def inputs(self) -> List[str]:
        return ["input_ids", "attention_mask"]


@register_in_tasks_manager("t5", *["feature-extraction"], library_name="diffusers")
class T5EncoderForDiffusersNeuronConfig(T5EncoderBaseNeuronConfig):
    CUSTOM_MODEL_WRAPPER = T5EncoderWrapper
    INPUT_ARGS = ("batch_size", "sequence_length")

    @property
    def outputs(self) -> List[str]:
        return ["last_hidden_state"]

    @property
    def is_encoder_decoder(self) -> bool:
        return True

    def patch_model_for_export(self, model_or_path, **input_shapes):
        return self.CUSTOM_MODEL_WRAPPER(model_or_path, **input_shapes)


@register_in_tasks_manager("t5-encoder", *["text2text-generation"])
class T5EncoderForTransformersNeuronConfig(T5EncoderBaseNeuronConfig):
    CUSTOM_MODEL_WRAPPER = T5EncoderForSeq2SeqLMWrapper
    INPUT_ARGS = ("batch_size", "sequence_length", "num_beams")
    MODEL_TYPE = "t5-encoder"

    @property
    def outputs(self) -> List[str]:
        common_outputs = (
            [f"present.{idx}.self.key" for idx in range(self._config.num_decoder_layers)]
            + [f"present.{idx}.self.value" for idx in range(self._config.num_decoder_layers)]
            + [f"present.{idx}.cross.key" for idx in range(self._config.num_decoder_layers)]
            + [f"present.{idx}.cross.value" for idx in range(self._config.num_decoder_layers)]
        )
        return common_outputs

    @property
    def is_encoder_decoder(self) -> bool:
        return True

    def patch_model_for_export(self, model_or_path, device="xla", **kwargs):
        num_beams = kwargs.pop("num_beams", 1)
        sequence_length = kwargs.pop("sequence_length", None)
        batch_size = kwargs.pop("batch_size", None)
        if self.tensor_parallel_size > 1:
            # `torch.nn.modules` objects not eligible for pickling, the model needs to be loaded within the func.
            return partial(
                self.get_parallel_callable,
                model_or_path,
                sequence_length,
                batch_size,
                num_beams,
                device,
                self.tensor_parallel_size,
            )
        else:
            return self.CUSTOM_MODEL_WRAPPER(
                model_or_path,
                sequence_length=sequence_length,
                batch_size=batch_size,
                num_beams=num_beams,
                device=device,
                tensor_parallel_size=self.tensor_parallel_size,
            )

    def get_parallel_callable(
        self, model_name_or_path, sequence_length, batch_size, num_beams, device, tensor_parallel_size
    ):
        """Unlike `torch_neuronx.trace`, `parallel_model_trace` requires a function returning a model object and a
        dictionary of states."""
        model = TasksManager.get_model_from_task(
            model_name_or_path=model_name_or_path,
            task=self.task,
            framework="pt",
            library_name="transformers",
        )  # TODO: add extra args, eg. revision, trust_remote_code, etc.
        model.config.use_cache = True
        with saved_model_in_temporary_directory(model) as ckpt_path:
            # Plug in parallel layers
            from optimum.neuron.models.inference.t5.modeling_t5 import parallelize

            parallel_model = parallelize(model)
            # Load the weights into the parallel layers
            neuronx_distributed.parallel_layers.load(ckpt_path, parallel_model, sharded=False)
            encoder = self.CUSTOM_MODEL_WRAPPER(
                parallel_model,
                sequence_length=sequence_length,
                batch_size=batch_size,
                num_beams=num_beams,
                device=device,
                tensor_parallel_size=tensor_parallel_size,
            )
            encoder.eval()
            aliases = self.generate_io_aliases(encoder)
            return encoder, aliases

    def generate_io_aliases(self, encoder=None):
        aliases = {}
        if self.tensor_parallel_size > 1:
            for i in range(len(encoder.past_key_values_sa)):
                aliases[encoder.past_key_values_sa[i]] = i
            for i in range(len(encoder.past_key_values_ca)):
                aliases[encoder.past_key_values_ca[i]] = len(encoder.past_key_values_sa) + i
        return aliases


@register_in_tasks_manager("t5-decoder", "text2text-generation")
class T5DecoderNeuronConfig(TextSeq2SeqNeuronConfig):
    ATOL_FOR_VALIDATION = 1e-3
    DUMMY_INPUT_GENERATOR_CLASSES = TextSeq2SeqNeuronConfig.DUMMY_INPUT_GENERATOR_CLASSES + (DummyBeamValuesGenerator,)
    INPUT_ARGS = ("batch_size", "sequence_length", "num_beams")
    MODEL_TYPE = "t5-decoder"
    CUSTOM_MODEL_WRAPPER = T5DecoderWrapper
    NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig

    @property
    def inputs(self) -> List[str]:
        common_inputs = [
            "decoder_input_ids",
            "decoder_attention_mask",
            "encoder_hidden_states",
            "attention_mask",  # TODO: replace with `encoder_attention_mask` after optimum 1.14 release
            "beam_idx",
            "beam_scores",
        ]
        return common_inputs

    @property
    def outputs(self) -> List[str]:
        beam_outputs = ["next_token_scores", "next_tokens", "next_indices"] if self.num_beams > 1 else ["next_tokens"]
        common_outputs = (
            beam_outputs
            + [f"past.{idx}.self.key" for idx in range(self._config.num_decoder_layers)]
            + [f"past.{idx}.self.value" for idx in range(self._config.num_decoder_layers)]
            + [f"past.{idx}.cross.key" for idx in range(self._config.num_decoder_layers)]
            + [f"past.{idx}.cross.value" for idx in range(self._config.num_decoder_layers)]
        )

        if self.output_hidden_states:
            # Flatten hidden states of all layers
            common_outputs += [
                f"decoder_hidden_state.{idx}" for idx in range(self._config.num_decoder_layers + 1)
            ]  # +1 for the embedding layer

        if self.output_attentions:
            # Flatten attentions tensors of all attention layers
            common_outputs += [f"decoder_attention.{idx}" for idx in range(self._config.num_decoder_layers)]
            if getattr(self._config, "is_encoder_decoder", False) is True:
                common_outputs += [f"cross_attention.{idx}" for idx in range(self._config.num_decoder_layers)]

        return common_outputs

    @property
    def is_encoder_decoder(self) -> bool:
        return True

    def generate_dummy_inputs(self, **kwargs):
        batch_size = kwargs.pop("batch_size") * kwargs.get("num_beams")
        dummy_inputs = super().generate_dummy_inputs(batch_size=batch_size, **kwargs)
        dummy_inputs["decoder_input_ids"] = dummy_inputs["decoder_input_ids"][:, :1]  # sequence_length = 1
        dummy_inputs["encoder_hidden_states"] = dummy_inputs["encoder_hidden_states"][0]

        return dummy_inputs

    def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGenerator"]:
        dummy_inputs_generators = super()._create_dummy_input_generator_classes(**kwargs)
        dummy_beam_values_generator = self.DUMMY_INPUT_GENERATOR_CLASSES[-1](
            self.task,
            self._normalized_config,
            num_beams=kwargs.pop("num_beams", 1),
            **kwargs,
        )
        dummy_inputs_generators.append(dummy_beam_values_generator)
        return dummy_inputs_generators

    def patch_model_for_export(self, model, device="xla", **kwargs):
        batch_size = kwargs.pop("batch_size", 1)
        sequence_length = kwargs.pop("sequence_length", 1)
        num_beams = kwargs.pop("num_beams", 1)

        trace_args = {
            "model": model,
            "batch_size": batch_size,
            "sequence_length": sequence_length,
            "num_beams": num_beams,
            "output_hidden_states": self.output_hidden_states,
            "output_attentions": self.output_attentions,
            "device": device,
            "tensor_parallel_size": self.tensor_parallel_size,
        }

        if self.tensor_parallel_size > 1:
            return partial(
                self.get_parallel_callable,
                model,
                batch_size,
                sequence_length,
                num_beams,
                self.output_hidden_states,
                self.output_attentions,
                device,
                self.tensor_parallel_size,
            )
        else:
            return self.CUSTOM_MODEL_WRAPPER(**trace_args)

    def get_parallel_callable(
        self,
        model_name_or_path,
        batch_size,
        sequence_length,
        num_beams,
        output_hidden_states,
        output_attentions,
        device,
        tensor_parallel_size,
    ):
        """Unlike `torch_neuronx.trace`, `parallel_model_trace` requires a function returning a model object and a
        dictionary of states."""
        model = TasksManager.get_model_from_task(
            model_name_or_path=model_name_or_path,
            task=self.task,
            framework="pt",
            library_name="transformers",
        )  # TODO: add extra args, eg. revision, trust_remote_code, etc.
        model.config.use_cache = True
        with saved_model_in_temporary_directory(model) as ckpt_path:
            # Plug in parallel layers
            from optimum.neuron.models.inference.t5.modeling_t5 import parallelize

            parallel_model = parallelize(model)
            # Load the weights into the parallel layers
            neuronx_distributed.parallel_layers.load(ckpt_path, parallel_model, sharded=False)
            decoder = self.CUSTOM_MODEL_WRAPPER(
                parallel_model,
                batch_size=batch_size,
                sequence_length=sequence_length,
                num_beams=num_beams,
                output_hidden_states=output_hidden_states,
                output_attentions=output_attentions,
                device=device,
                tensor_parallel_size=tensor_parallel_size,
            )
            decoder.eval()
            aliases = self.generate_io_aliases(decoder)
            return decoder, aliases

    def generate_io_aliases(self, decoder):
        num_outputs_from_trace = 3 if decoder.num_beams > 1 else 1
        aliases = {}
        for i in range(len(decoder.past_key_values_sa)):
            aliases[decoder.past_key_values_sa[i]] = i + num_outputs_from_trace
        for i in range(len(decoder.past_key_values_ca)):
            aliases[decoder.past_key_values_ca[i]] = len(decoder.past_key_values_sa) + i + num_outputs_from_trace

        return aliases


@register_in_tasks_manager("whisper-encoder", *["automatic-speech-recognition"])
class WhisperEncoderNeuronConfig(AudioNeuronConfig):
    ATOL_FOR_VALIDATION = 1e-3
    MODEL_TYPE = "whisper-encoder"
    CUSTOM_MODEL_WRAPPER = WhisperEncoderWrapper
    INPUT_ARGS = ("batch_size", "sequence_length")
    DUMMY_INPUT_GENERATOR_CLASSES = AudioNeuronConfig.DUMMY_INPUT_GENERATOR_CLASSES + (WhisperDummyTextInputGenerator,)
    NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig.with_args(
        encoder_num_layers="encoder_layers",
        decoder_num_layers="decoder_layers",
        feature_size="num_mel_bins",
        allow_new=True,
    )

    @property
    def inputs(self) -> List[str]:
        return ["input_features", "decoder_input_ids"]

    @property
    def outputs(self) -> List[str]:
        return ["lm_logits", "encoder_last_hidden_state"]

    @property
    def is_encoder_decoder(self) -> bool:
        return True

    def generate_dummy_inputs(self, return_tuple: bool = False, **kwargs):
        kwargs["sequence_length"] = 1  # only `decoder_start_token_id`
        return super().generate_dummy_inputs(return_tuple=return_tuple, **kwargs)

    def patch_model_for_export(self, model_or_path, **input_shapes):
        return self.CUSTOM_MODEL_WRAPPER(model_or_path, **input_shapes)
@register_in_tasks_manager("whisper-decoder", *["automatic-speech-recognition"]) class WhisperDecoderNeuronConfig(AudioNeuronConfig): ATOL_FOR_VALIDATION = 1e-3 MODEL_TYPE = "whisper-decoder" DUMMY_INPUT_GENERATOR_CLASSES = (WhisperDummyTextInputGenerator,) INPUT_ARGS = ("batch_size", "sequence_length") CUSTOM_MODEL_WRAPPER = WhisperDecoderWrapper NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig.with_args( encoder_num_layers="encoder_layers", decoder_num_layers="decoder_layers", feature_size="num_mel_bins", hidden_size="d_model", allow_new=True, ) @property def inputs(self) -> List[str]: return ["decoder_input_ids", "encoder_hidden_states"] @property def outputs(self) -> List[str]: return ["lm_logits"] @property def is_encoder_decoder(self) -> bool: return True def patch_model_for_export(self, model_or_path, **input_shapes): return self.CUSTOM_MODEL_WRAPPER(model_or_path, **input_shapes)