projects/speech2speech_translation/s2s_common.py (173 lines of code) (raw):

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Common functions for speech2speech clients."""

import datetime
import re
import sys

from google.api_core.client_options import ClientOptions
import google.api_core.exceptions
from google.cloud import storage
from google.cloud import texttospeech
from google.cloud import translate_v2 as translate
from google.cloud.speech_v2 import SpeechClient
from google.cloud.speech_v2.types import cloud_speech


def parse_gcs_url(url):
    """Parses a GCS URL and extracts the bucket name and path.

    Args:
        url: The GCS URL to parse, e.g. "gs://bucket/some/path/".

    Returns:
        A dictionary containing the following keys:
            - bucket: The name of the bucket.
            - path: The path to the object within the bucket (may be empty).

    Raises:
        ValueError: If the URL is not in the expected format.
    """
    match = re.match(r"^gs://(?P<bucket>[^/]+)/(?P<path>.*)$", url)
    if match:
        return match.groupdict()
    raise ValueError("Invalid GCS URL: {}".format(url))


def upload_file_to_gcs(project_id, gcs, filepath, filename, logger):
    """Uploads a local file to Google Cloud Storage.

    The local path and the object name are built by plain string
    concatenation (`filepath + filename`, `gcs["path"] + filename`), so
    directory-like values of `filepath` and `gcs["path"]` should end with "/".

    Args:
        project_id: Google Cloud project ID.
        gcs: A dictionary containing the GCS bucket name and path.
        filepath: The path (directory prefix) of the local file to upload.
        filename: The name of the file to upload.
        logger: Logging object.

    Returns:
        The Google Cloud Storage URI of the uploaded file.
    """
    storage_client = storage.Client(project=project_id)
    bucket = storage_client.get_bucket(gcs["bucket"])
    gcs_uri_input_audio = "gs://" + gcs["bucket"] + "/" + gcs["path"] + filename
    blob = bucket.blob(gcs["path"] + filename)
    logger.info("Uploading %s to %s", filepath + filename, gcs_uri_input_audio)
    blob.upload_from_filename(filepath + filename)
    return gcs_uri_input_audio


def upload_variable_to_gcs(project_id, gcs, object_name, data, content_type):
    """Uploads data to Google Cloud Storage from memory.

    Args:
        project_id: Google Cloud project ID.
        gcs: A dictionary containing the GCS bucket name and path.
        object_name: The name of the object to upload.
        data: The data to upload (bytes or string; strings are UTF-8 encoded).
        content_type: The content type of the data.

    Returns:
        The Google Cloud Storage URI of the uploaded file.
    """
    storage_client = storage.Client(project=project_id)
    bucket = storage_client.get_bucket(gcs["bucket"])
    blob = bucket.blob(gcs["path"] + object_name)
    if isinstance(data, str):
        data = data.encode("utf-8")
    blob.upload_from_string(data, content_type=content_type)
    gcs_uri = "gs://" + gcs["bucket"] + "/" + gcs["path"] + object_name
    return gcs_uri


def speech_to_text(
    project_id,
    location,
    source_language_code,
    stt_model,
    stt_timeout,
    gcs_uri_input_audio,
    logger,
):
    """Transcribes speech audio to text with a Speech-to-Text v2 batch job.

    Args:
        project_id: Google Cloud project ID.
        location: Google Cloud location to use (also selects the regional
            API endpoint).
        source_language_code: Source language code for STT.
        stt_model: Speech to Text model to use.
        stt_timeout: Timeout for STT operation in seconds.
        gcs_uri_input_audio: Google Cloud Storage URI of the input audio file.
        logger: Logger object.

    Returns:
        Speech-to-Text response object (results are returned inline).
    """
    client = SpeechClient(
        client_options=ClientOptions(
            api_endpoint=f"{location}-speech.googleapis.com",
        )
    )
    recognition_features = cloud_speech.RecognitionFeatures(
        enable_automatic_punctuation=True,
        enable_spoken_punctuation=True,
        enable_word_time_offsets=True,
        profanity_filter=False,
        max_alternatives=1,
    )
    recognition_config = cloud_speech.RecognitionConfig(
        auto_decoding_config=cloud_speech.AutoDetectDecodingConfig(),
        language_codes=[source_language_code],
        model=stt_model,
        features=recognition_features,
    )
    metadata = cloud_speech.BatchRecognizeFileMetadata(uri=gcs_uri_input_audio)
    request = cloud_speech.BatchRecognizeRequest(
        # "_" selects the default (implicit) recognizer for the location.
        recognizer=(
            f"projects/{project_id}/locations/" f"{location}/recognizers/_"
        ),
        config=recognition_config,
        files=[metadata],
        recognition_output_config=cloud_speech.RecognitionOutputConfig(
            inline_response_config=cloud_speech.InlineOutputConfig(),
        ),
    )
    operation = client.batch_recognize(request=request)
    logger.info(
        "Waiting for STT operation to complete... timeout: %ss", stt_timeout
    )
    response = operation.result(timeout=stt_timeout)
    return response


def parse_stt_response(stt_response, uri, alt, logger):
    """Parses Speech-to-Text response and chooses alternative per sentence.

    Args:
        stt_response: Speech-to-Text response object.
        uri: Google Cloud Storage URI of the input audio file (the key of the
            per-file results map).
        alt: Index of the alternative to choose from each result.
        logger: Logger object.

    Returns:
        Transcript string: the chosen alternative of every result,
        concatenated in order. Results that do not have alternative `alt`
        are skipped.
    """
    transcript = []
    for result in stt_response.results[uri].transcript.results:
        try:
            text = result.alternatives[alt].transcript
        except IndexError:
            # A result may carry fewer alternatives than requested (e.g. a
            # segment without recognized speech); skip it instead of
            # aborting the whole transcript.
            logger.debug("Skipping STT result without alternative %s", alt)
            continue
        # The previous bytes check ran on the list and could never match;
        # decode per segment so a bytes transcript cannot break the join.
        if isinstance(text, bytes):
            text = text.decode("utf-8")
        transcript.append(text)
    logger.debug("STT Transcript for alternative %s: %s", alt, transcript)
    return "".join(transcript)


def translate_text(target_language, transcript, logger):
    """Translates text to target language.

    Args:
        target_language: Target language.
        transcript: Transcript string.
        logger: Logger object.

    Returns:
        Translate response object as returned by the Translation v2 client.
    """
    translate_client = translate.Client()
    logger.info("Translating text to: %s", target_language)
    translate_result = translate_client.translate(
        transcript, target_language=target_language
    )
    return translate_result


def text_to_speech(
    project_id,
    location,
    target_voice,
    target_voice_gender,
    target_language_code,
    text,
    tts_timeout,
    gcs,
    output_audio_file_name,
    logger,
    prefix=None,
):
    """Converts text to speech with a long-audio synthesis job.

    The LINEAR16 audio is written by the service to
    gs://<bucket>/<path><prefix><output_audio_file_name>.

    Args:
        project_id: Google Cloud project ID.
        location: Google Cloud location to use.
        target_voice: TTS voice to use.
        target_voice_gender: Gender to use with TTS voice ("female" selects
            FEMALE; any other value selects MALE).
        target_language_code: Target language code.
        text: Text to convert to speech.
        tts_timeout: Timeout for TTS operation in seconds.
        gcs: A dictionary containing the GCS bucket name and path.
        output_audio_file_name: Output audio file name.
        logger: Logger object.
        prefix: Optional file name prefix; None or "" means no prefix.

    Raises:
        SystemExit: With status 1 if the TTS API call fails.
    """
    client = texttospeech.TextToSpeechLongAudioSynthesizeClient()
    input_text = texttospeech.SynthesisInput(text=text)
    gender = (
        texttospeech.SsmlVoiceGender.FEMALE
        if target_voice_gender == "female"
        else texttospeech.SsmlVoiceGender.MALE
    )
    voice = texttospeech.VoiceSelectionParams(
        language_code=target_language_code,
        name=target_voice,
        ssml_gender=gender,
    )
    audio_config = texttospeech.AudioConfig(
        audio_encoding=texttospeech.AudioEncoding.LINEAR16,
        speaking_rate=1.0,
    )
    parent = f"projects/{project_id}/locations/{location}"
    gcs_path = "gs://" + gcs["bucket"] + "/" + gcs["path"]
    # generate_filename_prefix() returns None when no prefix is configured;
    # without this, "None" would be embedded in the output object name.
    output_gcs_uri = f"{gcs_path}{prefix or ''}{output_audio_file_name}"
    request = texttospeech.SynthesizeLongAudioRequest(
        parent=parent,
        input=input_text,
        audio_config=audio_config,
        voice=voice,
        output_gcs_uri=output_gcs_uri,
    )
    logger.info(
        "Waiting for TTS operation to complete... timeout: %ss", tts_timeout
    )
    try:
        operation = client.synthesize_long_audio(request=request)
        result = operation.result(timeout=tts_timeout)
        logger.info("TTS output written to: %s", output_gcs_uri)
        if result:
            logger.info("TTS operation result: %s", result)
    except google.api_core.exceptions.GoogleAPICallError as e:
        logger.error("TTS operation failed: %s. Exiting", e)
        # sys.exit instead of the site-provided exit(), which is not
        # guaranteed to exist (e.g. under `python -S` or when frozen).
        sys.exit(1)


def list_tts_voices(logger):
    """Lists available voices for text-to-speech.

    Args:
        logger: Logging object; the whole listing is emitted as one INFO
            record.
    """
    client = texttospeech.TextToSpeechClient()
    voices = client.list_voices()
    output = []
    for voice in voices.voices:
        output.append(f"\nName: {voice.name}\n")
        for language_code in voice.language_codes:
            output.append(f"Supported language: {language_code}\n")
        ssml_gender = texttospeech.SsmlVoiceGender(voice.ssml_gender)
        output.append(f"SSML Voice Gender: {ssml_gender.name}\n")
        output.append(
            f"Natural Sample Rate Hz: {voice.natural_sample_rate_hertz}\n"
        )
    logger.info("".join(output))


def list_translate_languages(logger):
    """Lists available languages for translation.

    Args:
        logger: Logging object; the whole listing is emitted as one INFO
            record.
    """
    translate_client = translate.Client()
    results = translate_client.get_languages()
    output = ["\n"]
    for language in results:
        output.append("Name: " + language["name"] + "\n")
        output.append("Language: " + language["language"] + "\n")
    logger.info("".join(output))


def generate_filename_prefix(filename_prefix):
    """Generates filename prefix based on config.

    Args:
        filename_prefix: String of preferred prefix type; a value containing
            "timestamp" enables a local-time ISO-8601 timestamp prefix.

    Returns:
        Filename prefix ending in "_", or None when no prefix is requested
        (including when `filename_prefix` is None or empty).
    """
    if not filename_prefix or "timestamp" not in filename_prefix:
        return None
    now = datetime.datetime.now()
    prefix = now.isoformat() + "_"
    return prefix