core/speech_to_text.py
# Copyright 2025 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Handles extracting speech from streaming audio content parts.
Uses Google Cloud Speech API to transcribe audio parts into text parts.
install google cloud speech client with:
```python
pip install --upgrade google-cloud-speech
```
See the `speech_to_text_cli.py` script for a usage example and how to test it
locally. It is recommended to test the quality of the transcription with
different models and recognizers.
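
A minimal usage sketch (assuming `audio_parts` is an async iterable of audio
`ProcessorPart`s in `audio/l16` format at the configured sample rate, and
'my-project' is a placeholder Google Cloud project ID):

```python
from genai_processors.core import speech_to_text

stt = speech_to_text.SpeechToText(project_id='my-project')

async def print_transcription(audio_parts):
  async for part in stt(audio_parts):
    if part.substream_name == speech_to_text.TRANSCRIPTION_SUBSTREAM_NAME:
      print(part.text)
```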
"""
import asyncio
from collections.abc import AsyncIterable
import dataclasses
import time
from absl import logging
import dataclasses_json
from genai_processors import content_api
from genai_processors import processor
from genai_processors import streams
from google.cloud import speech_v2
import grpc
_SILENT_AUDIO_DELAY_SECONDS = 1
RecognizeStream = grpc.aio.StreamStreamCall[
speech_v2.types.StreamingRecognizeRequest,
speech_v2.types.StreamingRecognizeResponse,
]
DEFAULT_SAMPLE_RATE_HZ = 24000
# The streaming_recognize RPC has a limit on its duration and has to be
# restarted periodically. Instead of waiting for the deadline, we try to
# restart it at moments when that won't cause hiccups.
STREAMING_HARD_LIMIT_SEC = (
240 # 4 minutes / restart stream even when user is speaking.
)
STREAMING_LIMIT_SEC = (
180 # 3 minutes / restart stream when user is not speaking.
)
ProcessorPart = content_api.ProcessorPart
TRANSCRIPTION_SUBSTREAM_NAME = 'input_transcription'
ENDPOINTING_SUBSTREAM_NAME = 'input_endpointing'
@dataclasses_json.dataclass_json
@dataclasses.dataclass(frozen=True)
class StartOfSpeech:
"""Start of speech event."""
pass
@dataclasses_json.dataclass_json
@dataclasses.dataclass(frozen=True)
class EndOfSpeech:
"""End of speech event."""
pass
class AddSilentPartMaybe(processor.Processor):
"""Adds silent audio parts if no activity after `silent_part_duration_sec`.
If the stream is empty after a few seconds, the Speech API will close the
connection. This processor adds silent audio parts to the output stream to
keep the connection alive.
"""
def __init__(
self,
silent_part_duration_sec: float = 1,
sample_rate: int = DEFAULT_SAMPLE_RATE_HZ,
):
self._silent_part_duration_sec = silent_part_duration_sec
self._sample_rate = sample_rate
async def call(
self, content: AsyncIterable[ProcessorPart]
) -> AsyncIterable[ProcessorPart]:
logging.info('Transcriber: _process_audio started.')
last_streamed_audio_time_sec = time.perf_counter()
async def _insert_silent_audio() -> AsyncIterable[ProcessorPart]:
"""Sends silent audio to the Speech API to keep the stream alive."""
nonlocal last_streamed_audio_time_sec
while True:
await asyncio.sleep(self._silent_part_duration_sec)
delta_time_sec = time.perf_counter() - last_streamed_audio_time_sec
if delta_time_sec > self._silent_part_duration_sec:
          yield ProcessorPart(
              # l16 audio is 16-bit PCM, i.e. 2 bytes per sample of silence.
              value=b'\0' * (2 * round(self._sample_rate * delta_time_sec)),
              mimetype=f'audio/l16; rate={self._sample_rate}',
          )
last_streamed_audio_time_sec = time.perf_counter()
audio_stream = streams.merge(
[content, _insert_silent_audio()], stop_on_first=True
)
async for part in audio_stream:
last_streamed_audio_time_sec = time.perf_counter()
yield part
logging.info('Transcriber: _process_audio finished.')
class _Transcriber(processor.Processor):
"""Transcribes streaming audio using the Cloud Speech API."""
def __init__(
self,
project_id: str,
recognition_config: speech_v2.types.RecognitionConfig,
with_endpointing: bool = True,
substream_endpointing: str = ENDPOINTING_SUBSTREAM_NAME,
strict_endpointing: bool = True,
with_interim_results: bool = True,
substream_transcription: str = TRANSCRIPTION_SUBSTREAM_NAME,
passthrough_audio: bool = False,
):
"""Transcribes audio parts using the Cloud Speech API.
Args:
project_id: The project ID to use for the Speech API.
recognition_config: The recognition config to use for the Speech API. Set
it up to adjust the sample rate, languages or the recognition model.
      with_endpointing: Whether to yield endpointing events. Endpointing events
        are `StartOfSpeech` and `EndOfSpeech` dataclass parts. They are yielded
        in the substream defined by `substream_endpointing`.
substream_endpointing: The substream name to use for the endpointing
events.
      strict_endpointing: Whether to send endpointing events only once interim
        results have been found. This avoids yielding endpointing events when
        the user speech is not recognized (e.g. no events are emitted for
        noise, laughter, coughing, etc.).
with_interim_results: Whether to yield interim results. If set to False,
the processor will only yield the final transcription.
substream_transcription: The substream name to use for the transcription.
      passthrough_audio: Whether to pass the audio parts through to the output
        stream unchanged. Their substream name is set to the default one: ''.
"""
self._config = speech_v2.types.StreamingRecognitionConfig(
config=recognition_config,
streaming_features=speech_v2.types.StreamingRecognitionFeatures(
interim_results=True,
enable_voice_activity_events=True,
),
)
self._sample_rate = (
self._config.config.explicit_decoding_config.sample_rate_hertz
or DEFAULT_SAMPLE_RATE_HZ
)
self._with_endpointing = with_endpointing
self._substream_endpointing = substream_endpointing
self._strict_endpointing = strict_endpointing
self._with_interim_results = with_interim_results
self._substream_transcription = substream_transcription
self._project_id = project_id
self._passthrough_audio = passthrough_audio
def _make_setup_request(self) -> speech_v2.types.StreamingRecognizeRequest:
return speech_v2.types.StreamingRecognizeRequest(
streaming_config=self._config,
recognizer=(
f'projects/{self._project_id}/locations/global/recognizers/_'
),
)
async def call(
self,
content: AsyncIterable[ProcessorPart],
) -> AsyncIterable[ProcessorPart]:
"""Transcribes streaming audio using the Cloud Speech API."""
    # The output queue collects everything yielded downstream: transcriptions,
    # endpointing events, non-audio parts, and the audio parts unchanged when
    # self._passthrough_audio is True.
output_queue = asyncio.Queue[ProcessorPart | None]()
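    # Shared mutable state between the request producer (request_stream) and
    # the response consumer (send_audio_to_speech_api): when the current RPC
    # started, whether it should be restarted, whether the user is currently
    # speaking, and whether the overall loop should keep running.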
stream_state: dict[str, bool | float] = {
'start_time_sec': time.perf_counter(),
'restart_stream': False,
'user_speaking': False,
'stream_is_on': True,
}
async def request_stream(
request_queue: asyncio.Queue[
speech_v2.types.StreamingRecognizeRequest | None
],
):
try:
request_queue.put_nowait(self._make_setup_request())
async for part in content:
if not content_api.is_audio(part.mimetype):
output_queue.put_nowait(part)
continue
if self._passthrough_audio:
output_queue.put_nowait(part)
if part.part.inline_data is None:
continue
if not part.mimetype.lower().startswith(
'audio/l16'
) or not part.mimetype.lower().endswith(f'rate={self._sample_rate}'):
raise ValueError(
f'Unsupported audio mimetype: {part.mimetype}. Expected'
f' audio/l16;[.*]rate={self._sample_rate}.'
)
request_queue.put_nowait(
speech_v2.types.StreamingRecognizeRequest(
audio=part.part.inline_data.data,
)
)
delta_time_sec = time.perf_counter() - stream_state['start_time_sec']
if (
(delta_time_sec > STREAMING_LIMIT_SEC)
and not stream_state['user_speaking']
) or (delta_time_sec > STREAMING_HARD_LIMIT_SEC):
stream_state['restart_stream'] = True
break
finally:
request_queue.put_nowait(None)
async def send_audio_to_speech_api():
# Instantiates a client.
try:
logging.debug('Transcriber: (re)creating client')
client = speech_v2.SpeechAsyncClient()
last_endpointing_event = None
while stream_state['stream_is_on']:
request_queue = asyncio.Queue[
speech_v2.types.StreamingRecognizeRequest | None
]()
populate_request_queue = processor.create_task(
request_stream(request_queue)
)
response_stream = await client.streaming_recognize(
requests=streams.dequeue(request_queue)
)
async for response in response_stream:
if response == grpc.aio.EOF:
break
if (
response.speech_event_type
== speech_v2.types.StreamingRecognizeResponse.SpeechEventType.SPEECH_ACTIVITY_BEGIN
):
last_endpointing_event = StartOfSpeech()
stream_state['user_speaking'] = True
if self._with_endpointing and not self._strict_endpointing:
last_endpointing_event = 'SPEECH_ACTIVITY_BEGIN_SENT'
output_queue.put_nowait(
ProcessorPart.from_dataclass(
dataclass=StartOfSpeech(),
substream_name=self._substream_endpointing,
)
)
if response.results and response.results[0].alternatives:
if (
isinstance(last_endpointing_event, StartOfSpeech)
and self._strict_endpointing
):
# We have not sent the SPEECH_ACTIVITY_BEGIN event yet, we
# waited for the first transcript to appear.
last_endpointing_event = 'SPEECH_ACTIVITY_BEGIN_SENT'
output_queue.put_nowait(
ProcessorPart.from_dataclass(
dataclass=StartOfSpeech(),
substream_name=self._substream_endpointing,
)
)
if text := response.results[0].alternatives[0].transcript:
metadata = {
'is_final': response.results[0].is_final,
}
if self._with_interim_results or response.results[0].is_final:
output_queue.put_nowait(
ProcessorPart(
text,
role='user',
metadata=metadata,
substream_name=self._substream_transcription,
)
)
if (
response.speech_event_type
== speech_v2.types.StreamingRecognizeResponse.SpeechEventType.SPEECH_ACTIVITY_END
):
stream_state['user_speaking'] = False
if (
self._with_endpointing
and last_endpointing_event == 'SPEECH_ACTIVITY_BEGIN_SENT'
):
output_queue.put_nowait(
ProcessorPart.from_dataclass(
dataclass=EndOfSpeech(),
substream_name=self._substream_endpointing,
)
)
last_endpointing_event = None
if stream_state['restart_stream']:
stream_state['restart_stream'] = False
stream_state['stream_is_on'] = True
stream_state['start_time_sec'] = time.perf_counter()
client = speech_v2.SpeechAsyncClient()
populate_request_queue.cancel()
else:
stream_state['stream_is_on'] = False
finally:
output_queue.put_nowait(None)
send_task = processor.create_task(send_audio_to_speech_api())
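    # Drain the output queue until the None sentinel signals that
    # send_audio_to_speech_api has finished.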
while part := await output_queue.get():
yield part
await send_task
class SpeechToText(processor.Processor):
"""Converts audio parts into text parts."""
def __init__(
self,
project_id: str,
recognition_config: speech_v2.types.RecognitionConfig | None = None,
audio_passthrough: bool = False,
with_endpointing: bool = True,
substream_endpointing: str = ENDPOINTING_SUBSTREAM_NAME,
strict_endpointing: bool = True,
with_interim_results: bool = True,
substream_transcription: str = TRANSCRIPTION_SUBSTREAM_NAME,
maintain_connection_active_with_silent_audio: bool = False,
):
"""Initializes the SpeechToText processor.
    The speech processor uses the Cloud Speech API to transcribe audio parts
    into text parts. It can inject silent audio parts to keep the stream
    alive when no audio is received (see
    maintain_connection_active_with_silent_audio) and it restarts the
    connection automatically after 3-4 minutes to avoid the stream being
    closed by the server.
The processor yields endpointing events when the user starts and stops
speaking. If with_endpointing is False, the endpointing events are not
yielded. The endpointing events are yielded in the substream defined by
substream_endpointing. When strict_endpointing is True, the endpointing
events are yielded only when interim results have been found. This avoids
yielding endpointing events when the user speech is not recognized (e.g.
short noise or sound).
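    A minimal sketch of a non-default configuration (using a placeholder
    project ID and assuming 16 kHz French audio as input):

    ```python
    recognition_config = speech_v2.types.RecognitionConfig(
        explicit_decoding_config=speech_v2.types.ExplicitDecodingConfig(
            sample_rate_hertz=16000,
            encoding=speech_v2.types.ExplicitDecodingConfig.AudioEncoding.LINEAR16,
            audio_channel_count=1,
        ),
        language_codes=['fr-FR'],
        model='latest_long',
    )
    stt = SpeechToText(
        project_id='my-project',
        recognition_config=recognition_config,
        with_interim_results=False,
    )
    ```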
Args:
project_id: The project ID to use for the Speech API.
recognition_config: The recognition config to use for the Speech API. Set
it up to adjust the sample rate, languages or the recognition model.
      audio_passthrough: Whether to pass the audio parts through to the output
        stream unchanged. Their substream name is set to the default one: ''.
      with_endpointing: Whether to yield endpointing events. Endpointing events
        are `StartOfSpeech` and `EndOfSpeech` dataclass parts. They are yielded
        in the substream defined by `substream_endpointing`.
substream_endpointing: The substream name to use for the endpointing
events.
      strict_endpointing: Whether to send endpointing events only once interim
        results have been found. This avoids yielding endpointing events when
        the user speech is not recognized (e.g. no events are emitted for
        noise, laughter, coughing, etc.).
with_interim_results: Whether to yield interim results. If set to False,
the processor will only yield the final transcription.
substream_transcription: The substream name to use for the transcription.
      maintain_connection_active_with_silent_audio: Whether to keep the
        connection alive with silent audio. If set to True, the processor
        injects silent audio parts to keep the stream alive when it does not
        receive any audio part. This can be needed if the Speech API closes
        the stream when it does not receive any audio for a long time.
"""
recognition_config = recognition_config or speech_v2.types.RecognitionConfig(
explicit_decoding_config=speech_v2.types.ExplicitDecodingConfig(
sample_rate_hertz=DEFAULT_SAMPLE_RATE_HZ,
encoding=speech_v2.types.ExplicitDecodingConfig.AudioEncoding.LINEAR16,
audio_channel_count=1,
),
language_codes=['en-US'],
model='latest_long',
)
self._processor = _Transcriber(
project_id=project_id,
recognition_config=recognition_config,
with_endpointing=with_endpointing,
substream_endpointing=substream_endpointing,
strict_endpointing=strict_endpointing,
with_interim_results=with_interim_results,
substream_transcription=substream_transcription,
passthrough_audio=audio_passthrough,
)
if maintain_connection_active_with_silent_audio:
sample_rate = (
recognition_config.explicit_decoding_config.sample_rate_hertz
or DEFAULT_SAMPLE_RATE_HZ
)
self._processor = (
AddSilentPartMaybe(
silent_part_duration_sec=_SILENT_AUDIO_DELAY_SECONDS,
sample_rate=sample_rate,
)
+ self._processor
)
async def call(
self,
content: AsyncIterable[ProcessorPart],
) -> AsyncIterable[ProcessorPart]:
async for part in self._processor(content):
yield part