gemini/agents/genai-experience-concierge/langgraph-demo/backend/concierge/nodes/guardrails.py (112 lines of code) (raw):

# Copyright 2025 Google. This software is provided as-is, without warranty or
# representation for any use or purpose. Your use of it is subject to your
# agreement with Google.
"""Node to classify user input and determine the next action."""

import logging
from typing import Literal, TypedDict

from concierge import schemas, utils
from google import genai
from google.genai import types as genai_types
from langchain_core.runnables import config as lc_config
from langgraph import types as lg_types
from langgraph.config import get_stream_writer
import pydantic

logger = logging.getLogger(__name__)

DEFAULT_FALLBACK_RESPONSE = "I apologize, but I am unable to assist with this query as it falls outside the scope of my knowledge base."  # pylint: disable=line-too-long

# User-facing message used when the guardrail model call or output parsing
# fails. The node fails closed: such errors block the turn.
_CLASSIFICATION_ERROR_RESPONSE = (
    "An error occurred during response generation." " Please try again later."
)


class InputGuardrails(pydantic.BaseModel):
    """
    Represents the classification of a user request by the guardrails system.

    This model is also passed as the ``response_schema`` of the guardrail model
    call, so changing field names or descriptions may change what the model is
    asked to produce.

    Attributes:
        blocked: Indicates whether the request should be blocked.
        reason: The reason for the classification decision.
        guardrail_response: A fallback message to be returned if the request is blocked.
    """

    # pylint: disable=line-too-long

    blocked: bool = pydantic.Field(
        description="The classification decision on whether the request should be blocked.",
    )
    """Boolean indicating whether the request should be blocked."""

    reason: str = pydantic.Field(
        description="Reason why the response was given the classification value.",
    )
    """Explanation of why the request was classified as blocked or allowed."""

    guardrail_response: str = pydantic.Field(
        description=(
            "Guardrail fallback message if the response is blocked."
            " Should be safe to surface to users."
        ),
    )
    """A safe message to display to the user if their request is blocked."""


class GuardrailTurn(schemas.BaseTurn):
    """Represents a single turn in a conversation with guardrails."""

    input_guardrails: InputGuardrails | None
    """The guardrail classification for this turn, if any."""


class GuardrailState(TypedDict, total=False):
    """Stores the active turn and conversation history."""

    current_turn: GuardrailTurn | None
    """The current turn being processed."""

    turns: list[GuardrailTurn]
    """List of all turns in the session."""


class GuardrailConfig(pydantic.BaseModel):
    """Configuration settings for the guardrails node."""

    project: str
    region: str
    guardrail_model_name: str


async def _classify_contents(
    client: genai.Client,
    model_name: str,
    system_prompt: str,
    contents: list,
) -> InputGuardrails:
    """Ask the guardrail model to classify ``contents``, failing closed on error.

    Args:
        client: Vertex AI GenAI client used for the request.
        model_name: Name of the guardrail model.
        system_prompt: System instruction for the guardrail model.
        contents: Conversation history followed by the current user content.

    Returns:
        The parsed classification. If the model call raises, or its output does
        not validate as ``InputGuardrails`` (including an empty response), a
        blocked classification is returned instead, with the error text as its
        ``reason``. Nothing is raised from the call/parse step.
    """
    try:
        # Single non-streaming, deterministic call with JSON structured output.
        response = await client.aio.models.generate_content(
            model=model_name,
            contents=contents,
            config=genai_types.GenerateContentConfig(
                system_instruction=system_prompt,
                candidate_count=1,
                temperature=0,
                seed=0,
                response_mime_type="application/json",
                response_schema=InputGuardrails,
            ),
        )
        # An empty response yields "", which fails validation and lands in the
        # fail-closed branch below.
        return InputGuardrails.model_validate_json(
            response.text.strip() if response.text else ""
        )
    except Exception as e:  # pylint: disable=broad-exception-caught
        logger.exception(e)
        return InputGuardrails(
            blocked=True,
            reason=str(e),
            guardrail_response=_CLASSIFICATION_ERROR_RESPONSE,
        )


def build_guardrail_node(
    node_name: str,
    allowed_next_node: str,
    blocked_next_node: str,
    system_prompt: str,
    guardrail_fallback_response: str = DEFAULT_FALLBACK_RESPONSE,
) -> schemas.Node:
    """Builds a LangGraph node for classifying user input as blocked or allowed.

    Args:
        node_name: Name given to the returned ``schemas.Node``.
        allowed_next_node: Node to route to when the input is allowed.
        blocked_next_node: Node to route to when the input is blocked.
        system_prompt: System instruction for the guardrail model.
        guardrail_fallback_response: Response used for blocked turns when the
            model supplies no non-blank guardrail message. It is also written
            to ``current_turn["response"]`` on allowed turns (presumably
            overwritten by a downstream node - confirm).

    Returns:
        A ``schemas.Node`` wrapping the async node function.
    """
    # Literal built from the two routing targets at runtime. It is used only
    # to annotate the Command returned by the node; static type checkers
    # cannot evaluate it, hence the ignore.
    NextNodeT = Literal[allowed_next_node, blocked_next_node]  # type: ignore

    async def ainvoke(
        state: GuardrailState,
        config: lc_config.RunnableConfig,
    ) -> lg_types.Command[NextNodeT]:
        """Classify the current user input and route to the next node.

        Sends the conversation history plus the current user message to the
        guardrail model and parses its structured output. Any failure while
        calling the model or parsing its output blocks the turn (fail closed).

        The classification is streamed as a ``guardrail_classification`` event
        and stored on the current turn. Allowed turns route to
        ``allowed_next_node``. Blocked turns get a user-facing response, which
        is also streamed as a ``text`` event, and route to
        ``blocked_next_node``. That response is the model's
        ``guardrail_response`` when it is non-blank, otherwise
        ``guardrail_fallback_response``.

        Args:
            state: Graph state; ``current_turn`` must be set.
            config: Runnable config whose ``configurable.guardrail_config``
                entry is validated as a ``GuardrailConfig``.

        Returns:
            A Command that updates ``current_turn`` and routes to the next node.

        Raises:
            AssertionError: If ``state`` has no ``current_turn`` (when
                assertions are enabled).
        """
        stream_writer = get_stream_writer()

        guardrail_config = GuardrailConfig.model_validate(
            config.get("configurable", {}).get("guardrail_config", {})
        )

        current_turn = state.get("current_turn")
        assert current_turn is not None, "current turn must be set"

        client = genai.Client(
            vertexai=True,
            project=guardrail_config.project,
            location=guardrail_config.region,
        )

        # Classify the new user message in the context of all prior messages.
        user_content = utils.load_user_content(current_turn=current_turn)
        contents = [
            content
            for turn in state.get("turns", [])
            for content in turn.get("messages", [])
        ] + [user_content]

        guardrail_classification = await _classify_contents(
            client=client,
            model_name=guardrail_config.guardrail_model_name,
            system_prompt=system_prompt,
            contents=contents,
        )

        stream_writer(
            {
                "guardrail_classification": guardrail_classification.model_dump(
                    mode="json"
                )
            }
        )

        # Record the classification and default the response to the static fallback.
        current_turn["response"] = guardrail_fallback_response
        current_turn["input_guardrails"] = guardrail_classification

        if not guardrail_classification.blocked:
            return lg_types.Command(
                update=GuardrailState(current_turn=current_turn),
                goto=allowed_next_node,
            )

        # Prefer the model-generated message. guardrail_response is a required
        # str, so a None check can never trigger; check for blank text instead,
        # so an empty message does not replace the static fallback.
        if guardrail_classification.guardrail_response.strip():
            current_turn["response"] = guardrail_classification.guardrail_response

        stream_writer({"text": current_turn["response"]})

        return lg_types.Command(
            update=GuardrailState(current_turn=current_turn),
            goto=blocked_next_node,
        )

    return schemas.Node(name=node_name, fn=ainvoke)