# Copyright 2025 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Wraps the Gemini API into a Processor.

## Example usage

```py
p = GenaiModel(
    api_key=API_KEY,
    model_name="gemini-2.0-flash-exp-image-generation",
    generate_content_config=genai.types.GenerateContentConfig(
      response_modalities=['Text', 'Image']
    )
)
```
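
The examples below assume roughly this setup (a minimal sketch; `API_KEY` and
the notebook `display` helper are placeholders from your environment):

```py
from google import genai
from genai_processors import processor

ProcessorPart = processor.ProcessorPart
API_KEY = ...  # Gemini Developer API key (placeholder)
```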

### Sync Execution

```py
INPUT_PROMPT = 'Create an image of two dalmatians & a cute poem'

genai_content = processor.apply_sync(p, [ProcessorPart(INPUT_PROMPT)])
for part in genai_content:
  if part.text:
    print(part.text)
  elif part.pil_image:
    display(part.pil_image)
```

### Async Execution

```py
input_stream = processor.stream_content([ProcessorPart(INPUT_PROMPT)])
async for part in p(input_stream):
  if part.text:
    print(part.text)
  elif part.pil_image:
    display(part.pil_image)
```
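
Each yielded part also carries response-level metadata (see
`genai_response_to_metadata` below). A sketch, assuming `ProcessorPart.metadata`
exposes the dict attached by this processor:

```py
async for part in p(input_stream):
  print(part.metadata)  # includes 'usage_metadata', 'model_version', etc.
```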
"""

from collections.abc import AsyncIterable
from typing import Any
from genai_processors import content_api
from genai_processors import processor
from google.genai import client
from google.genai import types as genai_types


def genai_response_to_metadata(
    response: genai_types.GenerateContentResponse,
) -> dict[str, Any]:
  """Converts a Genai response to metadata, to be attached to a ProcessorPart."""
  return {
      "create_time": response.create_time,
      "response_id": response.response_id,
      "model_version": response.model_version,
      "prompt_feedback": response.prompt_feedback,
      "usage_metadata": response.usage_metadata,
      "automatic_function_calling_history": (
          response.automatic_function_calling_history
      ),
      "parsed": response.parsed,
  }


class GenaiModel(processor.Processor):
  """`Processor` that calls the Genai API in turn-based fashion.

  Note: All content is buffered prior to calling the Genai API.
  """

  def __init__(
      self,
      api_key: str,
      model_name: str,
      generate_content_config: (
          genai_types.GenerateContentConfigOrDict | None
      ) = None,
      debug_config: client.DebugConfig | None = None,
      http_options: (
          genai_types.HttpOptions | genai_types.HttpOptionsDict | None
      ) = None,
  ):
    """Initializes the GenaiModel.

    Args:
      api_key: The `API key <https://ai.google.dev/gemini-api/docs/api-key>`_ to
        use for authentication. Applies to the Gemini Developer API only.
      model_name: The name of the model to use.
      generate_content_config: The configuration for generating content.
      debug_config: Config settings that control network behavior of the client.
        This is typically used when running test code.
      http_options: HTTP options for the underlying client. These options are
        applied to all requests made by the client, for example
        `http_options=genai_types.HttpOptions(api_version='v1')`.

    ## Model Name Usage

    Supported formats for **Vertex AI API** include:

      * The Gemini model ID, for example: 'gemini-2.0-flash'
      * The full resource name, starting with 'projects/', for example:
      'projects/my-project-id/locations/us-central1/publishers/google/models/gemini-2.0-flash'
      * The partial resource name with 'publishers/', for example:
      'publishers/google/models/gemini-2.0-flash' or
      'publishers/meta/models/llama-3.1-405b-instruct-maas'
      * The '/'-separated publisher and model name, for example:
      'google/gemini-2.0-flash' or 'meta/llama-3.1-405b-instruct-maas'

    Supported formats for **Gemini API** include:

      * The Gemini model ID, for example: 'gemini-2.0-flash'
      * The model name starting with 'models/', for example:
      'models/gemini-2.0-flash'
      * For tuned models, the model name starts with 'tunedModels/', for
      example: 'tunedModels/1234567890123456789'
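
    For example, a Gemini Developer API model ID can be passed directly (a
    minimal sketch; the key is a placeholder):

    ```py
    p = GenaiModel(api_key=API_KEY, model_name='gemini-2.0-flash')
    ```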
    """
    self._client = client.Client(
        api_key=api_key,
        debug_config=debug_config,
        http_options=http_options,
    )
    self._model_name = model_name
    self._generate_content_config = generate_content_config

  async def call(
      self, content: AsyncIterable[content_api.ProcessorPartTypes]
  ) -> AsyncIterable[content_api.ProcessorPartTypes]:
    # Buffer the full input stream; the Genai API is called once per turn.
    contents = []
    async for content_part in content:
      contents.append(content_api.to_genai_part(content_part))

    if not contents:
      return

    # Stream the model response back as `ProcessorPart`s, attaching the
    # response-level metadata to each part.
    async for res in await self._client.aio.models.generate_content_stream(
        model=self._model_name,
        contents=contents,
        config=self._generate_content_config,
    ):
      if res.candidates:
        candidate_content = res.candidates[0].content
        if candidate_content and candidate_content.parts:
          for part in candidate_content.parts:
            yield processor.ProcessorPart(
                part,
                metadata=genai_response_to_metadata(res),
                role=candidate_content.role or "model",
            )
