# Copyright 2025 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Streams audio content parts from text parts.

Uses Google Text-To-Speech API to generate audio parts from text parts.

Install the Google Cloud Text-To-Speech client with:

```sh
pip install --upgrade google-cloud-texttospeech
```

See the `text_to_speech_cli.py` script for a usage example and how to test it
locally.
"""

import asyncio
from collections.abc import AsyncIterable

from genai_processors import content_api
from genai_processors import processor
from genai_processors import streams
from google.cloud import texttospeech_v1 as texttospeech

# Module-level alias so `ProcessorPart` can be referenced without the
# `content_api.` prefix in annotations and constructor calls.
ProcessorPart = content_api.ProcessorPart


class TextToSpeech(processor.Processor):
  """Streams audio content parts from text parts.

  Text parts of the input stream are sent to the Google Text-To-Speech
  streaming API and the synthesized audio is emitted as `ProcessorPart`s with
  mimetype `audio/l16;rate=24000` and role `model`. Non-text parts are always
  passed through; text parts are passed through only when
  `with_text_passthrough` is True.
  """

  def __init__(
      self,
      project_id: str,
      language_code: str = 'en-US',
      voice_name: str = 'en-US-Chirp3-HD-Charon',
      with_text_passthrough: bool = True,
  ):
    """Initializes the TextToSpeech processor.

    This processor uses the Google Text-To-Speech API to generate audio parts
    from text parts. The audio parts are yielded as `ProcessorPart` objects
    whose content is the audio bytes returned by the API and whose mimetype is
    `audio/l16;rate=24000` (PCM encoding, 24000 Hz sample rate).

    Args:
      project_id: The project ID to use for the Text-To-Speech API. It is
        stored on the instance. NOTE(review): it is not referenced anywhere
        else in this class (the client is created with default arguments) -
        confirm whether it should be forwarded to the client.
      language_code: The language code to use for the Text-To-Speech API.
      voice_name: The voice name to use for the Text-To-Speech API. See list of
        voices here: https://cloud.google.com/text-to-speech/docs/chirp3-hd
      with_text_passthrough: Whether to passthrough the text parts to the output
        stream. When set to True, the text parts are yielded back.
    """
    streaming_config = texttospeech.StreamingSynthesizeConfig(
        streaming_audio_config=texttospeech.StreamingAudioConfig(
            audio_encoding=texttospeech.AudioEncoding.PCM,
            sample_rate_hertz=24000,
        ),
        voice=texttospeech.VoiceSelectionParams(
            name=voice_name,
            language_code=language_code,
        ),
    )
    # The configuration request is built once here and reused as the first
    # request of every streaming session started by `call`.
    self._config_request = texttospeech.StreamingSynthesizeRequest(
        streaming_config=streaming_config
    )
    self._project_id = project_id
    self._with_text_passthrough = with_text_passthrough

  async def call(
      self, content: AsyncIterable[ProcessorPart]
  ) -> AsyncIterable[ProcessorPart]:
    """Streams audio content parts from text parts.

    The order between TTS-processed parts and pass-through parts is not
    maintained. This processor treats all its inputs as realtime and sends them
    to the output as soon as possible. Order within TTS-processed and
    pass-through parts is maintained, but we don't wait for the TTS result to
    emit the next pass-through part.

    Args:
      content: The input stream of content to process. Non-text parts are passed
        through unchanged.

    Yields:
      The audio parts generated by the Text-To-Speech API from the text parts in
      `content`. If `with_text_passthrough` is True, the text parts are yielded
      back as well. All non-text parts are yielded unchanged.
    """
    # Single output queue fed by two producers: `request_stream` pushes the
    # pass-through parts and `send_text_to_speech_requests` pushes the
    # synthesized audio parts. `None` marks the end of the output stream.
    output_queue = asyncio.Queue[ProcessorPart | None]()
    # Set once the first non-empty text part has been queued for synthesis, or
    # when the input stream ends, whichever comes first.
    first_chunk_received = asyncio.Event()

    async def request_stream(
        request_queue: asyncio.Queue[
            texttospeech.StreamingSynthesizeRequest | None
        ],
    ) -> None:
      """Reads `content`, forwarding parts and queuing synthesis requests."""
      try:
        # The configuration request must be the first one in the stream.
        request_queue.put_nowait(self._config_request)
        async for part in content:
          if (
              not content_api.is_text(part.mimetype)
              or self._with_text_passthrough
          ):
            output_queue.put_nowait(part)
          # Only non-empty text parts are sent for synthesis.
          if not content_api.is_text(part.mimetype) or not part.text:
            continue
          first_chunk_received.set()
          request_queue.put_nowait(
              texttospeech.StreamingSynthesizeRequest(
                  input=texttospeech.StreamingSynthesisInput(
                      text=part.text,
                  )
              )
          )
      finally:
        # Unblock the sender even if no text part was ever received, and
        # terminate the request stream.
        first_chunk_received.set()
        request_queue.put_nowait(None)

    async def send_text_to_speech_requests() -> None:
      """Runs the streaming synthesis and queues the resulting audio parts."""
      enqueue_request_task = None
      try:
        request_queue = asyncio.Queue[
            texttospeech.StreamingSynthesizeRequest | None
        ]()
        enqueue_request_task = processor.create_task(
            request_stream(request_queue)
        )
        # Wait until the first request is sent to the Speech API to avoid the
        # client being created before the first request is sent.
        # The client can indeed only stay up for 5 seconds without any request
        # and then it needs to be re-created.
        await first_chunk_received.wait()
        client = texttospeech.TextToSpeechAsyncClient()
        streaming_responses = await client.streaming_synthesize(
            requests=streams.dequeue(request_queue)
        )

        async for response in streaming_responses:
          output_queue.put_nowait(
              ProcessorPart(
                  response.audio_content,
                  mimetype='audio/l16;rate=24000',
                  role='model',
              )
          )

        # Propagate any error raised while reading the input stream.
        await enqueue_request_task
      finally:
        # On failure or cancellation, stop reading the input: nothing will
        # consume the parts it would keep pushing to the output queue.
        if enqueue_request_task is not None and not enqueue_request_task.done():
          enqueue_request_task.cancel()
        output_queue.put_nowait(None)

    send_task = processor.create_task(send_text_to_speech_requests())
    try:
      # Compare against the `None` sentinel explicitly rather than relying on
      # the truthiness of the part, so that no part can end the stream early.
      while (part := await output_queue.get()) is not None:
        yield part
      # Surface any exception raised by the background task.
      await send_task
    finally:
      # If the consumer stops iterating early or this generator is cancelled,
      # make sure the background task does not outlive the call.
      if not send_task.done():
        send_task.cancel()
