# optimum/neuron/models/whisper/model.py

# coding=utf-8
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Whisper model on Neuron devices."""

import logging
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import torch
from transformers import GenerationConfig, WhisperForConditionalGeneration
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
from transformers.utils import ModelOutput

from ....exporters.neuron import (
    NeuronDefaultConfig,
)
from ...modeling_seq2seq import NeuronModelForConditionalGeneration, _NeuronSeq2SeqModelPart
from ...modeling_traced import NeuronTracedModel
from ...utils import (
    NEURON_FILE_NAME,
    is_neuronx_available,
    is_neuronx_distributed_available,
)
from ...utils.doc import (
    NEURON_AUDIO_SEQ2SEQ_INPUTS_DOCSTRING,
    NEURON_SEQ2SEQ_MODEL_START_DOCSTRING,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
)


if TYPE_CHECKING:
    from transformers import PretrainedConfig

# NOTE(review): the original file contained dead `if is_neuronx_available(): pass`
# and `if is_neuronx_distributed_available(): pass` guards (refactoring leftovers).
# They were removed; the helpers remain imported in case downstream code imports
# them from this module.

logger = logging.getLogger(__name__)


class DummyLayer:
    """Identity stand-in layer that records arbitrary attributes.

    The traced Neuron encoder has no real conv modules, yet parts of the
    Whisper generation path read `encoder.conv1.stride` / `encoder.conv2.stride`
    to derive sequence lengths. This dummy exposes those attributes while
    behaving as a no-op when called.
    """

    def __init__(self, **kwargs):
        # Attach every keyword argument as an attribute (e.g. stride=[2]).
        for key, value in kwargs.items():
            setattr(self, key, value)

    def __call__(self, x):
        # Identity: the real computation lives inside the traced graph.
        return x


class NeuronWhisperEncoder(_NeuronSeq2SeqModelPart):
    """
    Encoder and the 1st forward of decoder+language head.

    The traced Neuron graph fuses the encoder with the first decoder step, so a
    single call produces both the language-model logits and the encoder hidden
    states.
    """

    main_input_name = "input_features"

    def __init__(
        self,
        model: torch.jit._script.ScriptModule,
        parent_model: NeuronTracedModel,
        config: Optional["PretrainedConfig"] = None,
        neuron_config: Optional[Dict[str, str]] = None,
    ):
        super().__init__(model, parent_model, config, neuron_config, "encoder")
        # Expose conv strides expected by Whisper's generation utilities; fall
        # back to the standard Whisper strides (1, 2) when absent from config.
        stride = getattr(self.config, "stride", [1, 2])
        self.conv1 = DummyLayer(stride=[stride[0]])
        self.conv2 = DummyLayer(stride=[stride[1]])

    def forward(
        self,
        input_features: torch.FloatTensor,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        """Run the fused encoder (+ first decoder step) traced graph.

        When called without `decoder_input_ids` (i.e. from
        `_prepare_encoder_decoder_kwargs_for_generation`), a batch of
        `decoder_start_token_id` tokens is synthesized and only the encoder
        hidden states are returned, wrapped in a `BaseModelOutput`.
        Otherwise the raw traced-model outputs are returned
        (presumably `(lm_logits, encoder_last_hidden_state)` — matches the
        unpacking done by the parent model's `forward`).
        """
        called_for_encoder_kwargs = decoder_input_ids is None
        if called_for_encoder_kwargs:
            decoder_input_ids = torch.full(
                (self.neuron_config.batch_size, 1), self.config.decoder_start_token_id, dtype=torch.long
            )
        outputs = self.model(input_features, decoder_input_ids)
        if called_for_encoder_kwargs:
            # outputs[1] holds the encoder hidden states of the traced graph.
            return BaseModelOutput(last_hidden_state=outputs[1])
        return outputs


class NeuronWhisperDecoder(_NeuronSeq2SeqModelPart):
    """
    Decoder with output embedding of the whisper model for Neuron inference.
    """

    def __init__(
        self,
        model: torch.jit._script.ScriptModule,
        parent_model: NeuronTracedModel,
        config: Optional["PretrainedConfig"] = None,
        neuron_config: Optional[Dict[str, str]] = None,
    ):
        super().__init__(model, parent_model, config, neuron_config, "decoder")

    def forward(
        self,
        decoder_input_ids: Optional[torch.LongTensor],
        encoder_hidden_states: Optional[torch.FloatTensor],
        **kwargs,
    ):
        """Run the traced decoder graph.

        Returns a 2-tuple `(decoder_outputs, encoder_hidden_states)`; the
        encoder states are passed through unchanged so callers can keep
        threading them through subsequent decoding steps.
        """
        outputs = self.model(decoder_input_ids, encoder_hidden_states)
        return (outputs, encoder_hidden_states)


class NeuronWhisperModel:
    """Minimal encoder/decoder container mirroring `WhisperModel`'s layout
    (`model.encoder` / `model.decoder`) for compatibility with transformers'
    generation utilities."""

    def __init__(self, encoder: NeuronWhisperEncoder, decoder: NeuronWhisperDecoder):
        self.encoder = encoder
        self.decoder = decoder


@add_start_docstrings(
    """
    Whisper Neuron model with a language modeling head that can be used for automatic speech recognition.
    """,
    NEURON_SEQ2SEQ_MODEL_START_DOCSTRING,
)
class NeuronWhisperForConditionalGeneration(NeuronModelForConditionalGeneration, WhisperForConditionalGeneration):
    auto_model_class = WhisperForConditionalGeneration
    main_input_name = "input_features"
    encoder_class = NeuronWhisperEncoder
    decoder_class = NeuronWhisperDecoder

    def __init__(
        self,
        encoder: torch.jit._script.ScriptModule,
        decoder: torch.jit._script.ScriptModule,
        config: "PretrainedConfig",
        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
        encoder_file_name: Optional[str] = NEURON_FILE_NAME,
        decoder_file_name: Optional[str] = NEURON_FILE_NAME,
        preprocessors: Optional[List] = None,
        neuron_configs: Optional[Dict[str, "NeuronDefaultConfig"]] = None,
        configs: Optional[Dict[str, "PretrainedConfig"]] = None,
        generation_config: Optional[GenerationConfig] = None,
        **kwargs,
    ):
        super().__init__(
            encoder,
            decoder,
            config,
            model_save_dir,
            encoder_file_name,
            decoder_file_name,
            preprocessors,
            neuron_configs,
            configs,
            generation_config,
            **kwargs,
        )
        # Mimic `WhisperForConditionalGeneration.model` so generic generation
        # code that walks `self.model.encoder` / `self.model.decoder` works.
        self.model = NeuronWhisperModel(self.encoder, self.decoder)

    @property
    def device(self):
        # Inputs/outputs are exchanged with the Neuron runtime via CPU tensors.
        return torch.device("cpu")

    def get_encoder(self) -> "NeuronWhisperEncoder":
        return self.encoder

    def _update_model_kwargs_for_generation(
        self,
        outputs: ModelOutput,
        model_kwargs: Dict[str, Any],
        is_encoder_decoder: bool = False,
        num_new_tokens: int = 1,
    ) -> Dict[str, Any]:
        # Override "use_cache" to False, since whisper with cache is not yet supported for neuron.
        model_kwargs["use_cache"] = False
        return super()._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder, num_new_tokens
        )

    @add_start_docstrings_to_model_forward(NEURON_AUDIO_SEQ2SEQ_INPUTS_DOCSTRING)
    def forward(
        self,
        input_features: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
        """Forward pass dispatching to the fused encoder or the decoder graph.

        When `encoder_outputs` is None, the fused encoder graph computes both
        the logits and the encoder hidden states in one call. Otherwise the
        decoder graph is run on `decoder_input_ids` padded up to the compiled
        sequence length, and the logits are truncated back afterwards.

        Raises:
            ValueError: if `decoder_input_ids` is longer than the sequence
                length the decoder was compiled for (the original code would
                silently truncate via negative padding).
        """
        if encoder_outputs is None:
            lm_logits, encoder_last_hidden_state = self.encoder(
                input_features=input_features, decoder_input_ids=decoder_input_ids
            )
        else:
            # Pad `decoder_input_ids` to the static sequence length of the compiled decoder.
            decoder_input_ids_length = decoder_input_ids.shape[1]
            compiled_length = self.neuron_configs["decoder"].sequence_length
            pad_size = compiled_length - decoder_input_ids_length
            if pad_size < 0:
                raise ValueError(
                    f"`decoder_input_ids` has length {decoder_input_ids_length}, which exceeds the sequence "
                    f"length ({compiled_length}) the decoder was compiled for."
                )
            # NOTE(review): assumes self.preprocessors[0] is the tokenizer and
            # exposes `pad_token_id` — verify against the export pipeline.
            decoder_input_ids = torch.nn.functional.pad(
                decoder_input_ids, (0, pad_size), "constant", self.preprocessors[0].pad_token_id
            )
            lm_logits, encoder_last_hidden_state = self.decoder(
                decoder_input_ids=decoder_input_ids,
                encoder_hidden_states=encoder_outputs[0],
            )
            # Unpad: keep only the logits for the real (unpadded) positions.
            lm_logits = lm_logits[:, :decoder_input_ids_length, :]
        return Seq2SeqLMOutput(
            logits=lm_logits,
            encoder_last_hidden_state=encoder_last_hidden_state,
        )