# Copyright 2025 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Handles extracting speech from streaming audio content parts.

Uses Google Cloud Speech API to transcribe audio parts into text parts.

install google cloud speech client with:

```python
pip install --upgrade google-cloud-speech
```

See the `speech_to_text_cli.py` script for a usage example and how to test it
locally. It is recommended to test the quality of the transcription with
different models and recognizers.
"""

import asyncio
from collections.abc import AsyncIterable
import dataclasses
import time

from absl import logging
import dataclasses_json
from genai_processors import content_api
from genai_processors import processor
from genai_processors import streams
from google.cloud import speech_v2
import grpc


_SILENT_AUDIO_DELAY_SECONDS = 1

RecognizeStream = grpc.aio.StreamStreamCall[
    speech_v2.types.StreamingRecognizeRequest,
    speech_v2.types.StreamingRecognizeResponse,
]


DEFAULT_SAMPLE_RATE_HZ = 24000
# streaming_recognize RPC has limit on the duration and has to be restarted
# periodically. Instead of waiting for the deadline we try to restart it at
# the moments when that won't cause hiccups.
STREAMING_HARD_LIMIT_SEC = (
    240  # 4 minutes / restart stream even when user is speaking.
)
STREAMING_LIMIT_SEC = (
    180  # 3 minutes / restart stream when user is not speaking.
)

ProcessorPart = content_api.ProcessorPart

TRANSCRIPTION_SUBSTREAM_NAME = 'input_transcription'
ENDPOINTING_SUBSTREAM_NAME = 'input_endpointing'


@dataclasses_json.dataclass_json
@dataclasses.dataclass(frozen=True)
class StartOfSpeech:
  """Start of speech event."""

  pass


@dataclasses_json.dataclass_json
@dataclasses.dataclass(frozen=True)
class EndOfSpeech:
  """End of speech event."""

  pass


class AddSilentPartMaybe(processor.Processor):
  """Adds silent audio parts if no activity after `silent_part_duration_sec`.

  If the stream is empty after a few seconds, the Speech API will close the
  connection. This processor adds silent audio parts to the output stream to
  keep the connection alive.
  """

  def __init__(
      self,
      silent_part_duration_sec: float = 1,
      sample_rate: int = DEFAULT_SAMPLE_RATE_HZ,
  ):
    self._silent_part_duration_sec = silent_part_duration_sec
    self._sample_rate = sample_rate

  async def call(
      self, content: AsyncIterable[ProcessorPart]
  ) -> AsyncIterable[ProcessorPart]:

    logging.info('Transcriber: _process_audio started.')
    last_streamed_audio_time_sec = time.perf_counter()

    async def _insert_silent_audio() -> AsyncIterable[ProcessorPart]:
      """Sends silent audio to the Speech API to keep the stream alive."""
      nonlocal last_streamed_audio_time_sec

      while True:
        await asyncio.sleep(self._silent_part_duration_sec)
        delta_time_sec = time.perf_counter() - last_streamed_audio_time_sec
        if delta_time_sec > self._silent_part_duration_sec:
          yield ProcessorPart(
              value=b'\0' * round(self._sample_rate * delta_time_sec),
              mimetype=f'audio/l16; rate={self._sample_rate}',
          )
          last_streamed_audio_time_sec = time.perf_counter()

    audio_stream = streams.merge(
        [content, _insert_silent_audio()], stop_on_first=True
    )

    async for part in audio_stream:
      last_streamed_audio_time_sec = time.perf_counter()
      yield part

    logging.info('Transcriber: _process_audio finished.')


class _Transcriber(processor.Processor):
  """Transcribes streaming audio using the Cloud Speech API."""

  def __init__(
      self,
      project_id: str,
      recognition_config: speech_v2.types.RecognitionConfig,
      with_endpointing: bool = True,
      substream_endpointing: str = ENDPOINTING_SUBSTREAM_NAME,
      strict_endpointing: bool = True,
      with_interim_results: bool = True,
      substream_transcription: str = TRANSCRIPTION_SUBSTREAM_NAME,
      passthrough_audio: bool = False,
  ):
    """Transcribes audio parts using the Cloud Speech API.

    Args:
      project_id: The project ID to use for the Speech API.
      recognition_config: The recognition config to use for the Speech API. Set
        it up to adjust the sample rate, languages or the recognition model.
      with_endpointing: Whether to yield endpointing events. Endpointing events
        are text parts with the value set to one of the
        `speech_to_text.SpeechEventType` string enums. The endpointing events
        are yielded in the substream defined by substream_endpointing.
      substream_endpointing: The substream name to use for the endpointing
        events.
      strict_endpointing: Whether to send endpointing events only when interim
        results have been found. This avoids yielding endpointing events when
        the user speech is not recognized (e.g. does not return endpointing for
        noise or laughs or coughing, etc.).
      with_interim_results: Whether to yield interim results. If set to False,
        the processor will only yield the final transcription.
      substream_transcription: The substream name to use for the transcription.
      passthrough_audio: Whether to passthrough the audio parts to the output
        stream. The substream name is set to the default one: ''.
    """
    self._config = speech_v2.types.StreamingRecognitionConfig(
        config=recognition_config,
        streaming_features=speech_v2.types.StreamingRecognitionFeatures(
            interim_results=True,
            enable_voice_activity_events=True,
        ),
    )
    self._sample_rate = (
        self._config.config.explicit_decoding_config.sample_rate_hertz
        or DEFAULT_SAMPLE_RATE_HZ
    )
    self._with_endpointing = with_endpointing
    self._substream_endpointing = substream_endpointing
    self._strict_endpointing = strict_endpointing
    self._with_interim_results = with_interim_results
    self._substream_transcription = substream_transcription
    self._project_id = project_id
    self._passthrough_audio = passthrough_audio

  def _make_setup_request(self) -> speech_v2.types.StreamingRecognizeRequest:
    return speech_v2.types.StreamingRecognizeRequest(
        streaming_config=self._config,
        recognizer=(
            f'projects/{self._project_id}/locations/global/recognizers/_'
        ),
    )

  async def call(
      self,
      content: AsyncIterable[ProcessorPart],
  ) -> AsyncIterable[ProcessorPart]:
    """Transcribes streaming audio using the Cloud Speech API."""

    # The output queue is used to yield the audio parts unchanged in the output
    # stream when self._passthrough_audio is True.
    output_queue = asyncio.Queue[ProcessorPart | None]()

    stream_state: dict[str, bool | float] = {
        'start_time_sec': time.perf_counter(),
        'restart_stream': False,
        'user_speaking': False,
        'stream_is_on': True,
    }

    async def request_stream(
        request_queue: asyncio.Queue[
            speech_v2.types.StreamingRecognizeRequest | None
        ],
    ):
      try:
        request_queue.put_nowait(self._make_setup_request())
        async for part in content:
          if not content_api.is_audio(part.mimetype):
            output_queue.put_nowait(part)
            continue
          if self._passthrough_audio:
            output_queue.put_nowait(part)
          if part.part.inline_data is None:
            continue
          if not part.mimetype.lower().startswith(
              'audio/l16'
          ) or not part.mimetype.lower().endswith(f'rate={self._sample_rate}'):
            raise ValueError(
                f'Unsupported audio mimetype: {part.mimetype}. Expected'
                f' audio/l16;[.*]rate={self._sample_rate}.'
            )
          request_queue.put_nowait(
              speech_v2.types.StreamingRecognizeRequest(
                  audio=part.part.inline_data.data,
              )
          )
          delta_time_sec = time.perf_counter() - stream_state['start_time_sec']
          if (
              (delta_time_sec > STREAMING_LIMIT_SEC)
              and not stream_state['user_speaking']
          ) or (delta_time_sec > STREAMING_HARD_LIMIT_SEC):
            stream_state['restart_stream'] = True
            break
      finally:
        request_queue.put_nowait(None)

    async def send_audio_to_speech_api():
      # Instantiates a client.
      try:
        logging.debug('Transcriber: (re)creating client')
        client = speech_v2.SpeechAsyncClient()
        last_endpointing_event = None
        while stream_state['stream_is_on']:
          request_queue = asyncio.Queue[
              speech_v2.types.StreamingRecognizeRequest | None
          ]()
          populate_request_queue = processor.create_task(
              request_stream(request_queue)
          )
          response_stream = await client.streaming_recognize(
              requests=streams.dequeue(request_queue)
          )
          async for response in response_stream:
            if response == grpc.aio.EOF:
              break
            if (
                response.speech_event_type
                == speech_v2.types.StreamingRecognizeResponse.SpeechEventType.SPEECH_ACTIVITY_BEGIN
            ):
              last_endpointing_event = StartOfSpeech()
              stream_state['user_speaking'] = True
              if self._with_endpointing and not self._strict_endpointing:
                last_endpointing_event = 'SPEECH_ACTIVITY_BEGIN_SENT'
                output_queue.put_nowait(
                    ProcessorPart.from_dataclass(
                        dataclass=StartOfSpeech(),
                        substream_name=self._substream_endpointing,
                    )
                )
            if response.results and response.results[0].alternatives:
              if (
                  isinstance(last_endpointing_event, StartOfSpeech)
                  and self._strict_endpointing
              ):
                # We have not sent the SPEECH_ACTIVITY_BEGIN event yet, we
                # waited for the first transcript to appear.
                last_endpointing_event = 'SPEECH_ACTIVITY_BEGIN_SENT'
                output_queue.put_nowait(
                    ProcessorPart.from_dataclass(
                        dataclass=StartOfSpeech(),
                        substream_name=self._substream_endpointing,
                    )
                )
              if text := response.results[0].alternatives[0].transcript:
                metadata = {
                    'is_final': response.results[0].is_final,
                }
                if self._with_interim_results or response.results[0].is_final:
                  output_queue.put_nowait(
                      ProcessorPart(
                          text,
                          role='user',
                          metadata=metadata,
                          substream_name=self._substream_transcription,
                      )
                  )
            if (
                response.speech_event_type
                == speech_v2.types.StreamingRecognizeResponse.SpeechEventType.SPEECH_ACTIVITY_END
            ):
              stream_state['user_speaking'] = False
              if (
                  self._with_endpointing
                  and last_endpointing_event == 'SPEECH_ACTIVITY_BEGIN_SENT'
              ):
                output_queue.put_nowait(
                    ProcessorPart.from_dataclass(
                        dataclass=EndOfSpeech(),
                        substream_name=self._substream_endpointing,
                    )
                )
              last_endpointing_event = None
          if stream_state['restart_stream']:
            stream_state['restart_stream'] = False
            stream_state['stream_is_on'] = True
            stream_state['start_time_sec'] = time.perf_counter()
            client = speech_v2.SpeechAsyncClient()
            populate_request_queue.cancel()
          else:
            stream_state['stream_is_on'] = False
      finally:
        output_queue.put_nowait(None)

    send_task = processor.create_task(send_audio_to_speech_api())
    while part := await output_queue.get():
      yield part
    await send_task


class SpeechToText(processor.Processor):
  """Converts audio parts into text parts."""

  def __init__(
      self,
      project_id: str,
      recognition_config: speech_v2.types.RecognitionConfig | None = None,
      audio_passthrough: bool = False,
      with_endpointing: bool = True,
      substream_endpointing: str = ENDPOINTING_SUBSTREAM_NAME,
      strict_endpointing: bool = True,
      with_interim_results: bool = True,
      substream_transcription: str = TRANSCRIPTION_SUBSTREAM_NAME,
      maintain_connection_active_with_silent_audio: bool = False,
  ):
    """Initializes the SpeechToText processor.

    The speech processor uses the Cloud Speech API to transcribe audio parts
    into text parts. It injects silent audio parts to keep the stream alive
    when the user is not speaking and restarts the connection automatically
    after 3-4 minutes to avoid the stream being closed by the server.

    The processor yields endpointing events when the user starts and stops
    speaking. If with_endpointing is False, the endpointing events are not
    yielded. The endpointing events are yielded in the substream defined by
    substream_endpointing. When strict_endpointing is True, the endpointing
    events are yielded only when interim results have been found. This avoids
    yielding endpointing events when the user speech is not recognized (e.g.
     short noise or sound).

    Args:
      project_id: The project ID to use for the Speech API.
      recognition_config: The recognition config to use for the Speech API. Set
        it up to adjust the sample rate, languages or the recognition model.
      audio_passthrough: Whether to passthrough the audio parts to the output
        stream. The substream name is set to the default one: ''.
      with_endpointing: Whether to yield endpointing events. Endpointing events
        are text parts with the value set to one of the
        `speech_to_text.SpeechEventType` string enums. The endpointing events
        are yielded in the substream defined by substream_endpointing.
      substream_endpointing: The substream name to use for the endpointing
        events.
      strict_endpointing: Whether to send endpointing events only when interim
        results have been found. This avoids yielding endpointing events when
        the user speech is not recognized (e.g. does not return endpointing for
        noise or laughs or coughing, etc.).
      with_interim_results: Whether to yield interim results. If set to False,
        the processor will only yield the final transcription.
      substream_transcription: The substream name to use for the transcription.
      maintain_connection_active_with_silent_audio: Whether to maintain the
        connection active with silent audio. If set to True, the processor will
        inject silent audio parts to keep the stream alive when the processor
        does not receive any audio part. This can be needed if the Speech API
        closes the stream when it does not receive any audio for a long time.
    """
    recognition_config = recognition_config or speech_v2.types.RecognitionConfig(
        explicit_decoding_config=speech_v2.types.ExplicitDecodingConfig(
            sample_rate_hertz=DEFAULT_SAMPLE_RATE_HZ,
            encoding=speech_v2.types.ExplicitDecodingConfig.AudioEncoding.LINEAR16,
            audio_channel_count=1,
        ),
        language_codes=['en-US'],
        model='latest_long',
    )
    self._processor = _Transcriber(
        project_id=project_id,
        recognition_config=recognition_config,
        with_endpointing=with_endpointing,
        substream_endpointing=substream_endpointing,
        strict_endpointing=strict_endpointing,
        with_interim_results=with_interim_results,
        substream_transcription=substream_transcription,
        passthrough_audio=audio_passthrough,
    )
    if maintain_connection_active_with_silent_audio:
      sample_rate = (
          recognition_config.explicit_decoding_config.sample_rate_hertz
          or DEFAULT_SAMPLE_RATE_HZ
      )
      self._processor = (
          AddSilentPartMaybe(
              silent_part_duration_sec=_SILENT_AUDIO_DELAY_SECONDS,
              sample_rate=sample_rate,
          )
          + self._processor
      )

  async def call(
      self,
      content: AsyncIterable[ProcessorPart],
  ) -> AsyncIterable[ProcessorPart]:
    async for part in self._processor(content):
      yield part
