gemini/agents/genai-experience-concierge/langgraph-demo/backend/concierge/utils.py (134 lines of code) (raw):
# Copyright 2025 Google. This software is provided as-is, without warranty or
# representation for any use or purpose. Your use of it is subject to your
# agreement with Google.
"""Utilities for Gen AI SDK."""
import asyncio
import inspect
import logging
from typing import (
Any,
AsyncGenerator,
AsyncIterator,
Awaitable,
Callable,
Mapping,
ParamSpec,
TypeVar,
)
from concierge import schemas
from google import genai
from google.genai import errors as genai_errors
from google.genai import types as genai_types
from langgraph import graph
import pydantic
import requests
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Generic placeholders used by `default_retry` so that the decorated callable
# keeps its original parameter list (P) and return type (T) for type checkers.
P = ParamSpec("P")
T = TypeVar("T")
def load_graph(
    schema: type,
    nodes: list[schemas.Node],
    entry_point: schemas.Node,
) -> graph.StateGraph:
    """Build a state graph from a list of nodes.

    Every node is registered under its ``name`` with its ``fn`` as the node
    callable, and ``entry_point`` is set as the graph's entry point. No edges
    are added and the graph is returned without being compiled.

    Note: This function only works with "edgeless" graphs, which use the
    Command object to specify the next node to transition to.

    Args:
        schema: The state schema type handed to ``StateGraph``.
        nodes: The nodes to register on the graph.
        entry_point: The node whose name is used as the entry point.

    Returns:
        The populated ``StateGraph``.
    """
    builder = graph.StateGraph(state_schema=schema)

    for registered_node in nodes:
        builder.add_node(registered_node.name, registered_node.fn)

    builder.set_entry_point(entry_point.name)

    return builder
def load_user_content(current_turn: schemas.BaseTurn) -> genai_types.Content:
    """Wrap the current turn's user input in a user-role Content object.

    Args:
        current_turn: The turn mapping; its ``"user_input"`` entry is read.

    Returns:
        A ``Content`` with role ``"user"`` holding a single text part.

    Raises:
        AssertionError: If the turn has no ``"user_input"`` value.
    """
    text = current_turn.get("user_input")
    assert text is not None, "user input must be set"

    text_part = genai_types.Part.from_text(text=text)
    return genai_types.Content(role="user", parts=[text_part])
def is_retryable_error(exception: BaseException) -> bool:
    """Decide whether a failed call is worth retrying.

    An exception is retryable when it is either:

    * a Gen AI ``APIError`` whose status code is 429, 502, 503 or 504, or
    * a ``requests`` connection error.

    Args:
        exception: The exception to evaluate.

    Returns:
        True if the exception is retryable, False otherwise.
    """
    retryable_status_codes = (429, 502, 503, 504)

    # API errors are judged by status code alone; any other code is final.
    if isinstance(exception, genai_errors.APIError):
        return exception.code in retryable_status_codes

    return isinstance(exception, requests.exceptions.ConnectionError)
def default_retry(func: Callable[P, T]) -> Callable[P, T]:
    """Wrap ``func`` with the default retry policy for Gemini invocations.

    The policy retries only exceptions accepted by ``is_retryable_error``,
    waits with exponential backoff (``min=1``, ``max=10``), stops after three
    attempts, and re-raises the original exception once attempts run out.

    Args:
        func: The callable to wrap.

    Returns:
        The wrapped callable.
    """
    with_backoff = retry(
        retry=retry_if_exception(is_retryable_error),
        wait=wait_exponential(min=1, max=10),
        stop=stop_after_attempt(3),
        reraise=True,
    )
    return with_backoff(func)
# pylint: disable=too-many-arguments,too-many-positional-arguments
async def generate_content_stream(
    model: str,
    contents: list[genai_types.Content],
    config: genai_types.GenerateContentConfig,
    client: genai.Client,
    max_recursion_depth: int = 3,
    fn_map: dict[str, Callable] | None = None,
) -> AsyncGenerator[genai_types.Content, None]:
    """Stream content from a Gemini model, resolving function calls inline.

    Each streamed chunk's content is yielded as it arrives. When a chunk
    carries function calls, the named functions are looked up in ``fn_map``,
    executed concurrently, and their results are yielded as a single
    function-response Content. Generation is then continued by calling this
    function again with the extended conversation history, and the follow-up
    content is yielded as well.

    Args:
        model: The name of the Gemini model to use.
        contents: The Content objects representing the conversation history.
            The list is not mutated; follow-up requests receive a new list.
        config: The GenerateContentConfig for the model.
        client: The Gemini client.
        max_recursion_depth: Remaining number of follow-up requests allowed.
            Generation stops (with a warning) once this drops below zero, so
            a value of N permits the initial request plus N follow-ups.
        fn_map: Maps function names to the callables that implement them.

    Yields:
        Content objects: model output chunks and function-response contents.

    Raises:
        RuntimeError: If the model requests a function that is not present
            in ``fn_map``.
    """
    fn_map = fn_map or {}

    if max_recursion_depth < 0:
        logger.warning("Maximum depth reached, stopping generation.")
        return

    response: AsyncIterator[genai_types.GenerateContentResponse] = (
        await client.aio.models.generate_content_stream(
            model=model,
            contents=contents,
            config=config,
        )
    )

    async for chunk in response:
        # `not chunk.candidates` covers both None and an empty list; indexing
        # an empty list would otherwise raise IndexError.
        if not chunk.candidates or chunk.candidates[0].content is None:
            logger.warning("no candidates or content, skipping chunk.")
            continue

        # yield current chunk content (assume only one candidate)
        content = chunk.candidates[0].content
        yield content

        if not chunk.function_calls:
            continue

        # Validate every call before starting any work, so that a missing
        # function cannot leave already-started tasks orphaned.
        runnable_calls = []
        for function_call in chunk.function_calls:
            if function_call.name is None:
                logger.warning("skipping function call without name")
                continue

            if function_call.name not in fn_map:
                raise RuntimeError(
                    f"Function not provided in fn_map: {function_call.name}"
                )

            runnable_calls.append(function_call)

        # Nothing executable was requested: there is no function response to
        # feed back to the model, so just move on to the next chunk.
        if not runnable_calls:
            continue

        # Execute all calls concurrently; results come back in call order.
        fn_results = await asyncio.gather(
            *(
                run_function_async(fn_map[call.name], call.args or {})
                for call in runnable_calls
            )
        )

        # Pair each result with the call that produced it. Pairing against
        # `runnable_calls` (not `chunk.function_calls`) keeps names and
        # results aligned when unnamed calls were skipped.
        fn_response_content = genai_types.Content(
            role="user",
            parts=[
                genai_types.Part.from_function_response(
                    name=call.name, response=fn_result
                )
                for call, fn_result in zip(runnable_calls, fn_results, strict=True)
            ],
        )
        yield fn_response_content

        # Continue generation with the function-call chunk and its responses
        # appended to the history.
        # NOTE: only the chunk that carried the function calls is appended;
        # earlier chunks of this same response are not part of the history.
        follow_up_contents = contents + [
            content.model_copy(deep=True),
            fn_response_content.model_copy(deep=True),
        ]
        async for follow_up_content in generate_content_stream(
            model=model,
            contents=follow_up_contents,
            config=config,
            client=client,
            max_recursion_depth=max_recursion_depth - 1,
            fn_map=fn_map,
        ):
            yield follow_up_content
# pylint: enable=too-many-arguments,too-many-positional-arguments
async def run_function_async(
    function: Callable[..., pydantic.BaseModel | Awaitable[pydantic.BaseModel]],
    function_kwargs: Mapping[str, Any],
) -> dict[str, str | dict]:
    """Run a function and wrap its outcome for a google-genai FunctionResponse.

    Coroutine functions are awaited directly. Any other callable is run in
    the event loop's default executor; if it turns out to return an awaitable
    (for example an object with ``async def __call__``, which
    ``inspect.iscoroutinefunction`` does not recognise), that awaitable is
    awaited as well.

    Note: Sync functions run in the default threadpool executor, so any sync
    function passed here should be thread-safe.

    Args:
        function: The function to execute. Its result must provide
            ``model_dump(mode="json")``.
        function_kwargs: The keyword arguments to pass to the function.

    Returns:
        ``{"result": <JSON-mode model dump>}`` on success, or
        ``{"error": <message>}`` if the function (or dumping its result)
        raised an ``Exception``. Failures are reported, not raised.
    """
    try:
        if inspect.iscoroutinefunction(function):
            fn_result = await function(**function_kwargs)
        else:
            loop = asyncio.get_running_loop()
            fn_result = await loop.run_in_executor(
                None,
                lambda: function(**function_kwargs),
            )
            # Calling an undetected async callable only creates its awaitable;
            # await it here on the event loop to obtain the real result.
            if inspect.isawaitable(fn_result):
                fn_result = await fn_result

        return {"result": fn_result.model_dump(mode="json")}
    except Exception as e:  # pylint: disable=broad-exception-caught
        # Some exceptions stringify to "", which gives the caller nothing to
        # act on; fall back to the repr so the exception type is visible.
        return {"error": str(e) or repr(e)}