gemini/agents/genai-experience-concierge/langgraph-demo/backend/concierge/nodes/router.py (109 lines of code) (raw):

# Copyright 2025 Google. This software is provided as-is, without warranty or # representation for any use or purpose. Your use of it is subject to your # agreement with Google. """Semantic router node for routing user queries to sub-agents.""" import enum import logging from typing import Literal, TypedDict, TypeVar from concierge import schemas, utils from google import genai from google.genai import types as genai_types from langchain_core.runnables import config as lc_config from langgraph import types as lg_types from langgraph.config import get_stream_writer import pydantic logger = logging.getLogger(__name__) RouterTarget = TypeVar("RouterTarget", bound=enum.StrEnum) class RouterClassification(pydantic.BaseModel): """Structured classification output for routing user queries.""" reason: str """Explanation of why the query was classified to a specific target.""" target: str """The target node to route the query to.""" class RouterTurn(schemas.BaseTurn): """Represents a single turn in a conversation.""" router_classification: RouterClassification | None """The router classification for the current turn.""" class RouterState(TypedDict, total=False): """Stores the active turn and conversation history.""" current_turn: RouterTurn | None """The current conversation turn.""" turns: list[RouterTurn] """List of all conversation turns in the session.""" class RouterConfig(pydantic.BaseModel): """Configuration settings for the router node.""" project: str """The Google Cloud project ID.""" region: str """The Google Cloud region.""" router_model_name: str """The name of the Gemini intent detection model.""" max_router_turn_history: int """The maximum number of prior turns to include in the conversation history.""" def build_semantic_router_node( node_name: str, system_prompt: str, class_node_mapping: dict[str, str], ) -> schemas.Node: """ Builds a LangGraph node that can dynamically route between sub-agents based on user intent. """ # ignore typing errors, this creates a valid literal type NextNodeT = Literal[*class_node_mapping.values()] # type: ignore response_schema = genai_types.Schema( properties={ "target": genai_types.Schema( type=genai_types.Type.STRING, enum=class_node_mapping.keys(), description="The target node to route the query to.", nullable=False, ), "reason": genai_types.Schema( type=genai_types.Type.STRING, description="Reason for classifying the latest user query.", nullable=False, ), }, required=["reason", "target"], type=genai_types.Type.OBJECT, property_ordering=["reason", "target"], ) async def ainvoke( state: RouterState, config: lc_config.RunnableConfig, ) -> lg_types.Command[NextNodeT]: """ Asynchronously invokes the router node to classify user input and determine the next action. This function takes the current conversation state and configuration, interacts with the Gemini model to classify the user's input based on predefined categories, and determines which sub-agent should handle the request. Runtime configuration should be passed in `config.configurable.router_config`. Args: state: The current state of the conversation session, including user input and history. config: The LangChain RunnableConfig containing agent-specific configurations. Returns: A Command object that specifies the next node to transition to and the updated conversation state. This state includes the router classification. """ router_config = RouterConfig.model_validate( config.get("configurable", {}).get("router_config", {}) ) stream_writer = get_stream_writer() current_turn = state.get("current_turn") assert current_turn is not None, "current turn must be set" # Initialize generate model client = genai.Client( vertexai=True, project=router_config.project, location=router_config.region, ) user_content = utils.load_user_content(current_turn=current_turn) turns = state.get("turns", []) contents = [ content for turn in turns[-router_config.max_router_turn_history :] # noqa: E203 for content in turn.get("messages", []) ] + [user_content] # generate streaming response response = await client.aio.models.generate_content( model=router_config.router_model_name, contents=contents, config=genai_types.GenerateContentConfig( candidate_count=1, temperature=0.2, seed=0, system_instruction=system_prompt, response_mime_type="application/json", response_schema=response_schema, ), ) router_classification = RouterClassification.model_validate_json( response.text or "" ) stream_writer( { "router_classification": { "target": router_classification.target, "reason": router_classification.reason, } } ) current_turn["router_classification"] = router_classification next_node = None for target_value, target_node in class_node_mapping.items(): if router_classification.target == target_value: next_node = target_node break else: raise RuntimeError( f"Unhandled router classification target: {router_classification.target}" ) return lg_types.Command( update=RouterState(current_turn=current_turn), goto=next_node, ) return schemas.Node(name=node_name, fn=ainvoke)