pyrit/prompt_converter/azure_speech_text_to_audio_converter.py (85 lines of code) (raw):

# Copyright (c) Microsoft Corporation. # Licensed under the MIT license. import logging from typing import TYPE_CHECKING, Literal if TYPE_CHECKING: import azure.cognitiveservices.speech as speechsdk # noqa: F401 from pyrit.common import default_values from pyrit.models import PromptDataType, data_serializer_factory from pyrit.prompt_converter import ConverterResult, PromptConverter logger = logging.getLogger(__name__) class AzureSpeechTextToAudioConverter(PromptConverter): """ The AzureSpeechTextToAudio takes a prompt and generates a wave file. https://learn.microsoft.com/en-us/azure/ai-services/speech-service/text-to-speech Args: azure_speech_region (str): The name of the Azure region. azure_speech_key (str): The API key for accessing the service. synthesis_language (str): Synthesis voice language synthesis_voice_name (str): Synthesis voice name, see URL For more details see the following link for synthesis language and synthesis voice: https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support filename (str): File name to be generated. Please include either .wav or .mp3 output_format (str): Either wav or mp3. Must match the file prefix. """ AZURE_SPEECH_REGION_ENVIRONMENT_VARIABLE: str = "AZURE_SPEECH_REGION" AZURE_SPEECH_KEY_ENVIRONMENT_VARIABLE: str = "AZURE_SPEECH_KEY" AzureSpeachAudioFormat = Literal["wav", "mp3"] def __init__( self, azure_speech_region: str = None, azure_speech_key: str = None, synthesis_language: str = "en_US", synthesis_voice_name: str = "en-US-AvaNeural", output_format: AzureSpeachAudioFormat = "wav", ) -> None: self._azure_speech_region: str = default_values.get_required_value( env_var_name=self.AZURE_SPEECH_REGION_ENVIRONMENT_VARIABLE, passed_value=azure_speech_region ) self._azure_speech_key: str = default_values.get_required_value( env_var_name=self.AZURE_SPEECH_KEY_ENVIRONMENT_VARIABLE, passed_value=azure_speech_key ) self._synthesis_language = synthesis_language self._synthesis_voice_name = synthesis_voice_name self._output_format = output_format def input_supported(self, input_type: PromptDataType) -> bool: return input_type == "text" def output_supported(self, output_type: PromptDataType) -> bool: return output_type == "audio_path" async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text") -> ConverterResult: try: import azure.cognitiveservices.speech as speechsdk # noqa: F811 except ModuleNotFoundError as e: logger.error( "Could not import azure.cognitiveservices.speech. " + "You may need to install it via 'pip install pyrit[speech]'" ) raise e if not self.input_supported(input_type): raise ValueError("Input type not supported") if prompt.strip() == "": raise ValueError("Prompt was empty. Please provide valid input prompt.") audio_serializer = data_serializer_factory( category="prompt-memory-entries", data_type="audio_path", extension=self._output_format ) audio_serializer_file = None try: speech_config = speechsdk.SpeechConfig( subscription=self._azure_speech_key, region=self._azure_speech_region, ) speech_config.speech_synthesis_language = self._synthesis_language speech_config.speech_synthesis_voice_name = self._synthesis_voice_name if self._output_format == "mp3": speech_config.set_speech_synthesis_output_format( speechsdk.SpeechSynthesisOutputFormat.Audio16Khz32KBitRateMonoMp3 ) speech_synthesizer = speechsdk.SpeechSynthesizer(speech_config=speech_config) result = speech_synthesizer.speak_text_async(prompt).get() if result.reason == speechsdk.ResultReason.SynthesizingAudioCompleted: audio_data = result.audio_data await audio_serializer.save_data(audio_data) audio_serializer_file = str(audio_serializer.value) logger.info( "Speech synthesized for text [{}], and the audio was saved to [{}]".format( prompt, audio_serializer_file ) ) elif result.reason == speechsdk.ResultReason.Canceled: cancellation_details = result.cancellation_details logger.info("Speech synthesis canceled: {}".format(cancellation_details.reason)) if cancellation_details.reason == speechsdk.CancellationReason.Error: logger.error("Error details: {}".format(cancellation_details.error_details)) raise RuntimeError( "Speech synthesis canceled: {}".format(cancellation_details.reason) + "Error details: {}".format(cancellation_details.error_details) ) except Exception as e: logger.error("Failed to convert prompt to audio: %s", str(e)) raise return ConverterResult(output_text=audio_serializer_file, output_type="audio_path")