# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import logging
import os
from pathlib import Path
from typing import Dict, Optional, Tuple, Union

import numpy as np
import openvino
import torch
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from torch import nn
from transformers import (
AutoConfig,
AutoModelForTextToSpectrogram,
GenerationConfig,
PretrainedConfig,
)
from transformers.file_utils import add_start_docstrings
from transformers.utils import ModelOutput

from ...exporters.openvino.stateful import model_has_state
from .configuration import OVConfig, OVWeightQuantizationConfig
from .modeling_base import OVBaseModel, OVModelPart
from .modeling_seq2seq import (
INPUTS_DOCSTRING,
OVModelForSeq2SeqLM,
)
from .utils import TemporaryDirectory

logger = logging.getLogger(__name__)


class OVTextToSpeechEncoder(OVModelPart):
_model_name = "encoder"

    def __init__(self, model: openvino.Model, parent_model: OVBaseModel) -> None:
super().__init__(model, parent_model, model_name=self._model_name)
self.output_dtypes = {key.get_any_name(): key.get_element_type().get_type_name() for key in self.model.outputs}
self.output_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.outputs)}
self.input_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.inputs)}
self.hidden_states_output_names = []
self._main_input = list(self.input_names.keys())[0]

    def forward(self, input_ids, **kwargs):
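        # Run the exported text encoder: it returns the encoder's last hidden state together with the
        # attention mask expected by the decoder.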
self._compile()
inputs = {self._main_input: input_ids}
result = self.request(inputs)
last_hidden_state = torch.from_numpy(result[0])
encoder_attention_mask = torch.from_numpy(result[1])
return ModelOutput(last_hidden_state=last_hidden_state, encoder_attention_mask=encoder_attention_mask)


class OVTextToSpeechDecoder(OVModelPart):
_model_name = "decoder"

    def __init__(self, model: openvino.Model, parent_model: OVBaseModel) -> None:
super().__init__(model, parent_model, model_name=self._model_name)
self.output_dtypes = {key.get_any_name(): key.get_element_type().get_type_name() for key in self.model.outputs}
self.output_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.outputs)}
self.input_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.inputs)}
self.hidden_states_output_names = []
if len(self.model.outputs) > 2:
self.hidden_states_output_names = [
key.get_any_name() for key in self.model.outputs[2:] if "hidden_states" in key.get_any_name()
]

    def forward(self, inputs_embeds, speaker_embeddings, encoder_last_hidden_state, encoder_attention_mask, **kwargs):
self._compile()
bsz = inputs_embeds.size(0)
inputs = {
"inputs_embeds": inputs_embeds,
"speaker_embeddings": speaker_embeddings,
"encoder_hidden_states": encoder_last_hidden_state,
"encoder_attention_mask": encoder_attention_mask,
"beam_idx": np.arange(bsz, dtype=np.int32),
}
result = self.request(inputs)
output_sequence_out = torch.from_numpy(result[0])
spectrum = torch.from_numpy(result[1])
prob = torch.from_numpy(result[2])
return ModelOutput(output_sequence_out=output_sequence_out, spectrum=spectrum, prob=prob)


class OVTextToSpeechPostNet(OVModelPart):
_model_name = "postnet"

    def __init__(self, model: openvino.Model, parent_model: OVBaseModel) -> None:
super().__init__(model, parent_model, model_name=self._model_name)
self.output_dtypes = {key.get_any_name(): key.get_element_type().get_type_name() for key in self.model.outputs}
self.output_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.outputs)}
self.input_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.inputs)}
self.hidden_states_output_names = []
if len(self.model.outputs) > 2:
self.hidden_states_output_names = [
key.get_any_name() for key in self.model.outputs[2:] if "hidden_states" in key.get_any_name()
]

    def forward(self, spectrograms, **kwargs):
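        # Refine the raw mel spectrogram predicted by the decoder with the exported postnet.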
self._compile()
inputs = {
"raw_spectrogram": spectrograms,
}
result = self.request(inputs)
postnet_spectrogram = torch.from_numpy(result[0])
return ModelOutput(postnet_spectrogram=postnet_spectrogram)


class OVTextToSpeechVocoder(OVModelPart):
_model_name = "vocoder"

    def __init__(self, model: openvino.Model, parent_model: OVBaseModel) -> None:
super().__init__(model, parent_model, model_name=self._model_name)
self.output_dtypes = {key.get_any_name(): key.get_element_type().get_type_name() for key in self.model.outputs}
self.output_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.outputs)}
self.input_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.inputs)}
self.hidden_states_output_names = []
if len(self.model.outputs) > 2:
self.hidden_states_output_names = [
key.get_any_name() for key in self.model.outputs[2:] if "hidden_states" in key.get_any_name()
]

    def forward(self, spectrogram, **kwargs):
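        # Convert the refined mel spectrogram into an audio waveform with the exported vocoder
        # (HiFi-GAN in the case of SpeechT5).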
self._compile()
inputs = {
"spectrogram": spectrogram,
}
result = self.request(inputs)
waveform = torch.from_numpy(result[0])
return ModelOutput(waveform=waveform)


@add_start_docstrings(
"""
    This class provides an interface for exporting and running inference on text-to-speech models using OpenVINO.
""",
INPUTS_DOCSTRING,
)
class OVModelForTextToSpeechSeq2Seq(OVModelForSeq2SeqLM):
auto_model_class = AutoModelForTextToSpectrogram
export_feature = "text-to-audio"
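
    # Illustrative usage sketch (assumes a SpeechT5 TTS model already exported to OpenVINO and a precomputed
    # x-vector speaker embedding; exact loading arguments may differ depending on your setup):
    #
    #     from transformers import SpeechT5Processor
    #     from optimum.intel import OVModelForTextToSpeechSeq2Seq
    #
    #     processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
    #     model = OVModelForTextToSpeechSeq2Seq.from_pretrained(exported_model_dir)
    #     inputs = processor(text="Hello world", return_tensors="pt")
    #     speech = model.generate(input_ids=inputs["input_ids"], speaker_embeddings=speaker_embeddings)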

    @classmethod
def _from_pretrained(
cls,
model_id: Union[str, Path],
config: "PretrainedConfig",
**kwargs,
):
if "SpeechT5ForTextToSpeech" in config.architectures:
return _OVModelForSpeechT5ForTextToSpeech._from_pretrained(model_id, config, **kwargs)
else:
raise ValueError(f"{config.architectures} are not supported text-to-audio model using OpenVINO")
return super()._from_pretrained(model_id, config, **kwargs)


class _OVModelForSpeechT5ForTextToSpeech(OVModelForTextToSpeechSeq2Seq):
"""
    This class implements its own `generate` method because the pipeline is split into four
    compact parts: encoder, decoder, postnet, and vocoder.
"""

    main_input_name = "input_ids"
OV_ENCODER_MODEL_NAME = "openvino_encoder_model.xml"
OV_DECODER_MODEL_NAME = "openvino_decoder_model.xml"
OV_POSTNET_MODEL_NAME = "openvino_postnet.xml"
OV_VOCODER_MODEL_NAME = "openvino_vocoder.xml"
_supports_cache_class = True

    def __init__(
self,
encoder: openvino.Model,
decoder: openvino.Model,
postnet: openvino.Model,
vocoder: openvino.Model,
config: PretrainedConfig = None,
device: str = "CPU",
dynamic_shapes: bool = True,
ov_config: Optional[Dict[str, str]] = None,
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
quantization_config: Union[OVWeightQuantizationConfig, Dict] = None,
**kwargs,
):
self.config = config
self.use_cache = model_has_state(decoder)
self._model_save_dir = model_save_dir
self._device = device.upper()
self.is_dynamic = dynamic_shapes
self.ov_config = {} if ov_config is None else {**ov_config}
self.preprocessors = kwargs.get("preprocessors", [])
self._supports_cache_class = False
self.main_input_name = "input_ids"
self._compile_only = kwargs.get("compile_only", False)
enable_compilation = kwargs.get("compile", True)
self.generation_config = kwargs.get("generation_config", GenerationConfig.from_model_config(config))
self._openvino_config = None
if quantization_config:
self._openvino_config = OVConfig(quantization_config=quantization_config)
self._set_ov_config_parameters()
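        # Wrap each exported OpenVINO sub-model so that it can be compiled and inferred independently.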
self.encoder = OVTextToSpeechEncoder(encoder, self)
self.decoder = OVTextToSpeechDecoder(decoder, self)
self.postnet = OVTextToSpeechPostNet(postnet, self)
self.vocoder = OVTextToSpeechVocoder(vocoder, self)
if enable_compilation and not self._compile_only:
self.compile()
# Avoid warnings when creating a transformers pipeline
AutoConfig.register(self.base_model_prefix, AutoConfig)
try:
self.auto_model_class.register(AutoConfig, self.__class__)
except AttributeError:
pass

    def clear_requests(self):
if self._compile_only:
raise ValueError(
"`clear_requests()` is not supported with `compile_only` mode, please initialize model without this option"
)
for _, component in self.components.items():
component.clear_requests()

    def compile(self):
for _, component in self.components.items():
if isinstance(component, OVModelPart):
component._compile()
else:
component.compile()

    @property
def _ov_submodel_names(self):
component_names = ["encoder", "decoder", "postnet", "vocoder"]
return component_names

    @property
def components(self):
return {component_name: getattr(self, component_name) for component_name in self._ov_submodel_names}

    @property
def ov_submodels(self) -> Dict[str, openvino.Model]:
return {component_name: getattr(self, component_name).model for component_name in self._ov_submodel_names}

    def _save_pretrained(self, save_directory: Union[str, Path]):
"""
Saves the model to the OpenVINO IR format so that it can be re-loaded using the
[`~optimum.intel.openvino.modeling.OVModel.from_pretrained`] class method.
Arguments:
save_directory (`str` or `Path`):
The directory where to save the model files.
"""
src_models = list(self.ov_submodels.values())
dst_file_names = [
self.OV_ENCODER_MODEL_NAME,
self.OV_DECODER_MODEL_NAME,
self.OV_POSTNET_MODEL_NAME,
self.OV_VOCODER_MODEL_NAME,
]
for src_model, dst_file_name in zip(src_models, dst_file_names):
dst_path = os.path.join(save_directory, dst_file_name)
openvino.save_model(src_model, dst_path, compress_to_fp16=False)
self._save_openvino_config(save_directory)
if self.generation_config is not None:
try:
self.generation_config.save_pretrained(save_directory)
except Exception as exception:
logger.warning(
f"The generation config will not be saved, saving failed with following error:\n{exception}"
)

    @classmethod
def _from_pretrained(
cls,
model_id: Union[str, Path],
config: "PretrainedConfig",
token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
force_download: bool = False,
cache_dir: str = HUGGINGFACE_HUB_CACHE,
local_files_only: bool = False,
load_in_8bit: bool = False,
quantization_config: Union[OVWeightQuantizationConfig, Dict] = None,
**kwargs,
):
device = kwargs.pop("device", "CPU")
dynamic_shapes = kwargs.pop("dynamic_shapes", True)
ov_config = kwargs.pop("ov_config", None)
generation_config = kwargs.pop("generation_config", None)
preprocessors = kwargs.pop("preprocessors", [])
compile_only = kwargs.pop("compile_only", False)
enable_compilation = kwargs.pop("compile", True)
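        # Each pipeline stage (encoder, decoder, postnet, vocoder) is stored as a separate OpenVINO IR,
        # i.e. an .xml topology file with its companion .bin weights file.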
model_file_names = {
"encoder_model": cls.OV_ENCODER_MODEL_NAME,
"encoder_model_bin": cls.OV_ENCODER_MODEL_NAME.replace(".xml", ".bin"),
"decoder_model": cls.OV_DECODER_MODEL_NAME,
"decoder_model_bin": cls.OV_DECODER_MODEL_NAME.replace(".xml", ".bin"),
"postnet_model": cls.OV_POSTNET_MODEL_NAME,
"postnet_model_bin": cls.OV_POSTNET_MODEL_NAME.replace(".xml", ".bin"),
"vocoder_model": cls.OV_VOCODER_MODEL_NAME,
"vocoder_model_bin": cls.OV_VOCODER_MODEL_NAME.replace(".xml", ".bin"),
}
if os.path.isdir(model_id):
# Load model from a local directory
model_save_dir = Path(model_id)
file_names = {k: os.path.join(model_id, model_file_names[k]) for k in model_file_names}
else:
file_names = {}
for name, file_name in model_file_names.items():
model_cache_path = hf_hub_download(
repo_id=model_id,
filename=file_name,
token=token,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
local_files_only=local_files_only,
)
file_names[name] = model_cache_path
model_save_dir = Path(model_cache_path).parent
if not compile_only:
encoder_model = OVBaseModel.load_model(file_names["encoder_model"])
decoder_model = OVBaseModel.load_model(file_names["decoder_model"])
postnet_model = OVBaseModel.load_model(file_names["postnet_model"])
vocoder_model = OVBaseModel.load_model(file_names["vocoder_model"])
else:
encoder_model = OVBaseModel._compile_model(
file_names["encoder_model"],
device,
ov_config,
model_save_dir,
)
decoder_model = OVBaseModel._compile_model(
file_names["decoder_model"],
device,
ov_config,
model_save_dir,
)
postnet_model = OVBaseModel._compile_model(
file_names["postnet_model"],
device,
ov_config,
model_save_dir,
)
vocoder_model = OVBaseModel._compile_model(
file_names["vocoder_model"],
device,
ov_config,
model_save_dir,
)
if generation_config is None:
try:
generation_config = GenerationConfig.from_pretrained(
model_id,
token=token,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
local_files_only=local_files_only,
)
except Exception:
pass
quantization_config = OVBaseModel._prepare_quantization_config(quantization_config, load_in_8bit)
to_quantize = not compile_only and quantization_config is not None
if to_quantize:
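            # Postpone compilation so that weight quantization can be applied to the uncompiled sub-models first.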
enable_compilation = False
model = _OVModelForSpeechT5ForTextToSpeech(
encoder=encoder_model,
decoder=decoder_model,
postnet=postnet_model,
vocoder=vocoder_model,
config=config,
device=device,
dynamic_shapes=dynamic_shapes,
ov_config=ov_config,
model_save_dir=model_save_dir,
quantization_config=quantization_config,
preprocessors=preprocessors,
compile_only=compile_only,
compile=enable_compilation,
generation_config=generation_config,
)
if to_quantize:
from optimum.intel.openvino.quantization import OVQuantizer
quantization_config_copy = copy.deepcopy(quantization_config)
quantization_config_copy.tokenizer = quantization_config.tokenizer or model_id
OVQuantizer(model).quantize(ov_config=OVConfig(quantization_config=quantization_config_copy))
return model

    # Adapted from https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/speecht5/modeling_speecht5.py#L2464
    # Some decoder parts (prenet, wrapper_decoder, and feat_out) are combined into a single decoder model,
    # so the pipeline is split into four parts: encoder, decoder, postnet, and vocoder.
def generate(
self,
input_ids: torch.LongTensor,
attention_mask: Optional[torch.LongTensor] = None,
speaker_embeddings: Optional[torch.FloatTensor] = None,
threshold: float = 0.5,
minlenratio: float = 0.0,
maxlenratio: float = 20.0,
vocoder: Optional[nn.Module] = None,
output_cross_attentions: bool = False,
return_output_lengths: bool = False,
**kwargs,
) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, torch.FloatTensor]]:
if speaker_embeddings is None:
raise ValueError(
"""`speaker_embeddings` must be specified. For example, you can use a speaker embeddings by following
the code snippet provided in this link:
https://huggingface.co/datasets/Matthijs/cmu-arctic-xvectors
"""
)
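        # For reference, a minimal sketch of obtaining such speaker embeddings (assumes the `datasets`
        # library is available; the dataset index is only illustrative):
        #     from datasets import load_dataset
        #     embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
        #     speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)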
input_values = input_ids
if attention_mask is None:
encoder_attention_mask = 1 - (input_values == self.config.pad_token_id).int()
else:
encoder_attention_mask = attention_mask
bsz = input_values.size(0)
encoder_out = self.encoder(input_values)
encoder_last_hidden_state = encoder_out.last_hidden_state
encoder_attention_mask = encoder_out.encoder_attention_mask
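        # The minimum and maximum numbers of decoding steps are derived from the encoder output length,
        # scaled by the given ratios and divided by the reduction factor (mel frames produced per step).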
maxlen = int(encoder_last_hidden_state.size(1) * maxlenratio / self.config.reduction_factor)
minlen = int(encoder_last_hidden_state.size(1) * minlenratio / self.config.reduction_factor)
# Start the output sequence with a mel spectrum that is all zeros.
output_sequence = encoder_last_hidden_state.new_zeros(bsz, 1, self.config.num_mel_bins)
spectrogram = []
cross_attentions = []
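        # Note: the exported OpenVINO decoder does not expose cross-attentions, so this list stays empty;
        # it mirrors the structure of the original transformers implementation.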
idx = 0
result_spectrogram = {}
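        # Autoregressive loop: at each step the decoder consumes the spectrogram generated so far and
        # predicts the next `reduction_factor` mel frames together with a stop probability.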
while True:
idx += 1
decoder_out = self.decoder(
inputs_embeds=output_sequence,
speaker_embeddings=speaker_embeddings,
encoder_last_hidden_state=encoder_last_hidden_state,
encoder_attention_mask=encoder_attention_mask,
)
spectrum = decoder_out.spectrum
spectrogram.append(spectrum)
output_sequence = decoder_out.output_sequence_out
prob = decoder_out.prob
if idx < minlen:
continue
else:
                # If the current step is below the maximum length, collect the batch items whose stop probability has
                # met the threshold. Otherwise, assume all remaining items have met it and finalize their spectrograms.
if idx < maxlen:
meet_thresholds = torch.sum(prob, dim=-1) >= threshold
meet_indexes = torch.where(meet_thresholds)[0].tolist()
else:
meet_indexes = range(len(prob))
meet_indexes = [i for i in meet_indexes if i not in result_spectrogram]
if len(meet_indexes) > 0:
spectrograms = torch.stack(spectrogram)
spectrograms = self.postnet(spectrograms)
spectrograms = spectrograms.postnet_spectrogram
for meet_index in meet_indexes:
result_spectrogram[meet_index] = spectrograms[meet_index]
if len(result_spectrogram) >= bsz:
break
spectrograms = [result_spectrogram[i] for i in range(len(result_spectrogram))]
if not return_output_lengths:
spectrogram = (
spectrograms[0].unsqueeze(0)
if bsz == 1
else torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True)
)
if self.vocoder is not None:
outputs = self.vocoder(spectrogram)
outputs = outputs.waveform
else:
outputs = spectrogram
if output_cross_attentions:
cross_attentions = torch.cat(cross_attentions, dim=2)
if bsz > 1:
cross_attentions = cross_attentions.view(
bsz, int(cross_attentions.size(0) / bsz), *cross_attentions.size()[-3:]
)
outputs = (outputs, cross_attentions)
else:
# batched return values should also include the spectrogram/waveform lengths
spectrogram_lengths = []
for i in range(bsz):
spectrogram_lengths.append(spectrograms[i].size(0))
if vocoder is None:
spectrograms = torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True)
outputs = (spectrograms, spectrogram_lengths)
else:
waveforms = []
spectrograms = torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True)
waveforms = vocoder(spectrograms)
waveform_lengths = [int(waveforms.size(1) / max(spectrogram_lengths)) * i for i in spectrogram_lengths]
outputs = (waveforms, waveform_lengths)
if output_cross_attentions:
cross_attentions = torch.cat(cross_attentions, dim=2)
cross_attentions = cross_attentions.view(
bsz, int(cross_attentions.size(0) / bsz), *cross_attentions.size()[-3:]
)
outputs = (*outputs, cross_attentions)
return outputs