core/genai_model.py (63 lines of code) (raw):

# Copyright 2025 DeepMind Technologies Limited. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Wraps the Gemini API into a Processor. ## Example usage ```py p = GenaiModel( api_key=API_KEY, model_name="gemini-2.0-flash-exp-image-generation", generate_content_config=genai.types.GenerateContentConfig( response_modalities=['Text', 'Image'] ) ) ``` ### Sync Execution ```py INPUT_PROMPT = 'Create an image of two dalmatians & a cute poem' genai_content = processors.apply_sync(p, [ProcessorPart(INPUT_PROMPT)]) for part in genai_content: if part.text: print(part.text) elif part.pil_image: display(part.pil_image) ``` ### Async Execution ```py input_stream = processor.stream_content([ProcessorPart(INPUT_PROMPT)]) async for part in p(input_stream): if part.text: print(part.text) elif part.pil_image: display(part.pil_image) ``` """ from collections.abc import AsyncIterable from typing import Any from genai_processors import content_api from genai_processors import processor from google.genai import client from google.genai import types as genai_types def genai_response_to_metadata( response: genai_types.GenerateContentResponse, ) -> dict[str, Any]: """Converts a Genai response to metadata, to be attached to a ProcessorPart.""" return { "create_time": response.create_time, "response_id": response.response_id, "model_version": response.model_version, "prompt_feedback": response.prompt_feedback, "usage_metadata": response.usage_metadata, "automatic_function_calling_history": ( response.automatic_function_calling_history ), "parsed": response.parsed, } class GenaiModel(processor.Processor): """`Processor` that calls the Genai API in turn-based fashion. Note: All content is buffered prior to calling the Genai API. """ def __init__( self, api_key: str, model_name: str, generate_content_config: ( genai_types.GenerateContentConfigOrDict | None ) = None, debug_config: client.DebugConfig | None = None, http_options: ( genai_types.HttpOptions | genai_types.HttpOptionsDict | None ) = None, ): """Initializes the GenaiModel. Args: api_key: The `API key <https://ai.google.dev/gemini-api/docs/api-key>`_ to use for authentication. Applies to the Gemini Developer API only. model_name: The name of the model to use. generate_content_config: The configuration for generating content. debug_config: Config settings that control network behavior of the client. This is typically used when running test code. http_options: Http options to use for the client. These options will be applied to all requests made by the client. Example usage: `client = genai.Client(http_options=types.HttpOptions(api_version='v1'))`. Returns: A `Processor` that calls the Genai API in turn-based fashion. ## Model Name Usage Supported formats for **Vertex AI API** include: * The Gemini model ID, for example: 'gemini-2.0-flash' * The full resource name starts with 'projects/', for example: 'projects/my-project-id/locations/us-central1/publishers/google/models/gemini-2.0-flash' * The partial resource name with 'publishers/', for example: 'publishers/google/models/gemini-2.0-flash' or 'publishers/meta/models/llama-3.1-405b-instruct-maas' / separated publisher and model name, for example: 'google/gemini-2.0-flash' or 'meta/llama-3.1-405b-instruct-maas' Supported formats for **Gemini API** include: * The Gemini model ID, for example: 'gemini-2.0-flash' * The model name starts with 'models/', for example: 'models/gemini-2.0-flash' * For tuned models, the model name starts with 'tunedModels/', for example: 'tunedModels/1234567890123456789' """ self._client = client.Client( api_key=api_key, debug_config=debug_config, http_options=http_options, ) self._model_name = model_name self._generate_content_config = generate_content_config async def call( self, content: AsyncIterable[content_api.ProcessorPartTypes] ) -> AsyncIterable[content_api.ProcessorPartTypes]: contents = [] async for content_part in content: contents.append(content_api.to_genai_part(content_part)) if not contents: return async for res in await self._client.aio.models.generate_content_stream( model=self._model_name, contents=contents, config=self._generate_content_config, ): res: genai_types.GenerateContentResponse = res if res.candidates: content = res.candidates[0].content if content and content.parts: for part in content.parts: yield processor.ProcessorPart( part, metadata=genai_response_to_metadata(res), role=content.role or "model", )