# Runtime_env/app/orchestration/agent.py
# Copyright 2025 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=C0301, W0107, W0622, R0917
"""Module used to define and interact with agent orchestrators."""
from abc import ABC, abstractmethod
from typing import AsyncGenerator, Generator, Dict, Any, Optional
import uuid
from google.api_core import exceptions
from langchain.agents import (
AgentExecutor,
create_react_agent as langchain_create_react_agent
)
from langchain_core.messages import AIMessage
from langchain_google_vertexai import ChatVertexAI
from langchain_google_vertexai.model_garden import ChatAnthropicVertex
from langgraph.prebuilt import create_react_agent as langgraph_create_react_agent
from vertexai.preview import reasoning_engines # TODO: update this when it becomes agent engine
# LlamaIndex
from llama_index.core import Settings
from llama_index.core.agent import ReActAgent
from llama_index.llms.langchain import LangChainLLM
from app.orchestration.constants import GEMINI_FLASH_20_LATEST
from app.orchestration.enums import OrchestrationFramework
from app.orchestration.tools import (
get_tools,
get_llamaindex_tools
)
class BaseAgentManager(ABC):
"""
Abstract base class for Agent Managers. Defines the common interface
for creating and managing agent executors with different orchestration
frameworks.
"""
def __init__(
self,
prompt: str,
industry_type: str,
location: str,
orchestration_framework: str,
agent_engine_resource_id: str,
model_name: str,
max_retries: int,
max_output_tokens: int,
temperature: float,
top_p: float,
top_k: int,
return_steps: bool,
verbose: bool
):
"""
Initializes the BaseAgentManager with common configurations.
Args:
prompt: System instructions to give to the agent.
industry_type: The agent industry type to use. Correlates to tool configs.
location: The GCP location to run the chat model.
orchestration_framework: The type of agent framework to use.
agent_engine_resource_id: The Resource ID of the deployed AE Agent (if using AE).
model_name: The valid name of the LLM to use for the agent.
max_retries: Maximum number of times to retry the query on a failure.
max_output_tokens: Maximum amount of text output from one prompt.
temperature: Temperature to use for the agent.
            top_p: Top p value. Chooses words based on a cumulative probability threshold.
            top_k: Top k value. Chooses from the top k most likely words.
return_steps: Whether to return the agent's trajectory of intermediate
steps at the end in addition to the final output.
            verbose: Whether or not to run in verbose mode.
"""
self.prompt = prompt
self.industry_type = industry_type
self.location = location
self.orchestration_framework = orchestration_framework
self.agent_engine_resource_id = agent_engine_resource_id
self.model_name = model_name
self.max_retries = max_retries
self.max_output_tokens = max_output_tokens
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
self.return_steps = return_steps
self.verbose = verbose
self.model_obj = self.get_model_obj()
self.tools = self.get_tools()
self.agent_executor = self.create_agent_executor()
@abstractmethod
def create_agent_executor(self):
"""
Abstract method to create the specific agent executor based on the
orchestration framework. This must be implemented by subclasses.
Returns:
The initialized agent executor instance.
"""
pass
def get_tools(self):
"""
Helper method to retrieve tools based on the industry type.
Returns:
A list of tools for the agent to use, based on the industry type.
"""
return get_tools(
self.industry_type,
self.orchestration_framework
)
def get_model_obj(self):
"""
Helper method to retrieve the model object based on the model name
and config.
Returns:
An LLM object for the Agent to use.
        Raises:
            exceptions.NotFound: The model resource is not found.
            RuntimeError: The model fails to initialize.
"""
try:
if "claude" in self.model_name:
return ChatAnthropicVertex(
model_name=self.model_name,
max_retries=self.max_retries,
## causes issues with the pydantic model with default None value:
# max_output_tokens=self.max_output_tokens,
## for now, claude is only available in us-east5 and europe-west1
# location = self.location,
location="us-east5",
temperature=self.temperature,
top_p=self.top_p,
top_k=self.top_k,
verbose=self.verbose
)
else:
return ChatVertexAI(
model_name=self.model_name,
location=self.location,
max_retries=self.max_retries,
max_output_tokens=self.max_output_tokens,
temperature=self.temperature,
top_p=self.top_p,
top_k=self.top_k,
verbose=self.verbose
)
        except exceptions.NotFound as e:
            raise exceptions.NotFound(f"Resource not found. {e}") from e
        except Exception as e:
            raise RuntimeError(f"Error encountered initializing model resource. {e}") from e
@abstractmethod
async def astream(
self,
input: Dict[str, Any],
) -> AsyncGenerator[Dict, Any]:
"""
Abstract method to asynchronously stream the Agent output.
This should be implemented by subclasses to handle the specific
streaming logic of their agent executor.
"""
pass
class LangChainPrebuiltAgentManager(BaseAgentManager):
"""
AgentManager subclass for LangChain Agent orchestration.
"""
def __init__(
self,
prompt: str,
industry_type: str,
location: Optional[str] = "us-central1",
agent_engine_resource_id: Optional[str] = None,
model_name: Optional[str] = GEMINI_FLASH_20_LATEST,
max_retries: Optional[int] = 6,
max_output_tokens: Optional[int] = None,
temperature: Optional[float] = 0,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
return_steps: Optional[bool] = False,
verbose: Optional[bool] = True
):
super().__init__(
prompt=prompt,
industry_type=industry_type,
location=location,
orchestration_framework=OrchestrationFramework.LANGCHAIN_PREBUILT_AGENT.value,
agent_engine_resource_id=agent_engine_resource_id,
model_name=model_name,
max_retries=max_retries,
max_output_tokens=max_output_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
return_steps=return_steps,
verbose=verbose
)
def create_agent_executor(self):
"""
Creates a Langchain React Agent executor.
"""
react_agent = langchain_create_react_agent(
prompt=self.prompt,
llm=self.model_obj,
tools=self.tools,
)
return AgentExecutor(
agent=react_agent,
tools=self.tools,
return_intermediate_steps=self.return_steps,
verbose=self.verbose
)
async def astream(
self,
input: Dict[str, Any],
) -> AsyncGenerator[Dict, Any]:
"""Asynchronously event streams the Agent output.
Args:
input: The list of messages to send to the model as input.
Yields:
Dictionaries representing the streamed agent output.
Exception:
An error is encountered during streaming.
"""
try:
async for chunk in self.agent_executor.astream(
{
"input": input["messages"][-1],
"chat_history": input["messages"][:-1],
}
):
# Organize response object to be consistent with other Agents (e.g. Agent Engine)
if "output" in chunk:
response_obj = {
"agent": {
"output": chunk["output"],
"messages": [
{
"lc": 1,
"type": "constructor",
"id": ["langchain", "schema", "messages", "AIMessage"],
"kwargs": {
"content": chunk["output"],
"type": "ai",
"tool_calls": [],
"invalid_tool_calls": [],
"id": input["run_id"]
}
}
]
}
}
yield response_obj
return # Exit the loop if successful
except Exception as e:
raise RuntimeError(f"Unexpected error. {e}") from e
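# A minimal usage sketch for LangChainPrebuiltAgentManager -- an illustration only,
# not part of this module's API. It assumes Application Default Credentials are
# configured for Vertex AI, and "finance" is a hypothetical industry_type; valid
# values depend on the tool configs behind app.orchestration.tools.get_tools.
#
#   import asyncio
#
#   async def _demo():
#       manager = LangChainPrebuiltAgentManager(
#           prompt="You are a helpful assistant.",
#           industry_type="finance",  # hypothetical value
#       )
#       request = {
#           "messages": ["What can you help me with?"],
#           "run_id": str(uuid.uuid4()),
#       }
#       async for chunk in manager.astream(request):
#           print(chunk)
#
#   asyncio.run(_demo())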
class LangGraphPrebuiltAgentManager(BaseAgentManager):
"""
AgentManager subclass for LangGraph Agent orchestration.
"""
def __init__(
self,
prompt: str,
industry_type: str,
location: Optional[str] = "us-central1",
agent_engine_resource_id: Optional[str] = None,
model_name: Optional[str] = GEMINI_FLASH_20_LATEST,
max_retries: Optional[int] = 6,
max_output_tokens: Optional[int] = None,
temperature: Optional[float] = 0,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
return_steps: Optional[bool] = False,
verbose: Optional[bool] = True
):
super().__init__(
prompt=prompt,
industry_type=industry_type,
location=location,
orchestration_framework=OrchestrationFramework.LANGGRAPH_PREBUILT_AGENT.value,
agent_engine_resource_id=agent_engine_resource_id,
model_name=model_name,
max_retries=max_retries,
max_output_tokens=max_output_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
return_steps=return_steps,
verbose=verbose
)
def create_agent_executor(self):
"""
Creates a LangGraph React Agent executor.
"""
return langgraph_create_react_agent(
prompt=self.prompt,
model=self.model_obj,
tools=self.tools,
debug=self.verbose
)
async def astream(
self,
input: Dict[str, Any],
) -> AsyncGenerator[Dict, Any]:
"""Asynchronously event streams the Agent output.
Args:
input: The list of messages to send to the model as input.
Yields:
Dictionaries representing the streamed agent output.
Exception:
An error is encountered during streaming.
"""
try:
async for chunk in self.agent_executor.astream(input, stream_mode="values"):
message = chunk["messages"][-1]
if isinstance(message, AIMessage):
# Organize response object to be consistent with other Agents (e.g. Agent Engine)
response_obj = {
"agent": {
"messages": [
{
"lc": 1,
"type": "constructor",
"id": ["langgraph", "schema", "messages", "AIMessage"],
"kwargs": {
"content": message.content,
"type": "ai",
"tool_calls": [],
"invalid_tool_calls": [],
"id": input["run_id"]
},
"usage_metadata": message.usage_metadata
}
]
}
}
yield response_obj
return # Exit the loop if successful
except Exception as e:
raise RuntimeError(f"Unexpected error. {e}") from e
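# Usage mirrors the LangChain manager above, but LangGraph's prebuilt React agent
# receives the input payload directly (stream_mode="values"), so messages should
# follow LangChain's message-dict convention. A hedged sketch; the exact message
# shape your caller sends may differ:
#
#   async def _demo():
#       manager = LangGraphPrebuiltAgentManager(
#           prompt="You are a helpful assistant.",
#           industry_type="finance",  # hypothetical value
#       )
#       request = {
#           "messages": [{"role": "user", "content": "Hello!"}],
#           "run_id": str(uuid.uuid4()),
#       }
#       async for chunk in manager.astream(request):
#           print(chunk)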
class LangChainVertexAIAgentEngineAgentManager(BaseAgentManager):
"""
AgentManager subclass for Vertex AI Agent Engine LangChain orchestration.
"""
def __init__(
self,
prompt: str,
industry_type: str,
location: Optional[str] = "us-central1",
agent_engine_resource_id: Optional[str] = None,
model_name: Optional[str] = GEMINI_FLASH_20_LATEST,
max_retries: Optional[int] = 6,
max_output_tokens: Optional[int] = None,
temperature: Optional[float] = 0,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
return_steps: Optional[bool] = False,
verbose: Optional[bool] = True
):
super().__init__(
prompt=prompt,
industry_type=industry_type,
location=location,
orchestration_framework=OrchestrationFramework.LANGCHAIN_VERTEX_AI_AGENT_ENGINE_AGENT.value,
agent_engine_resource_id=agent_engine_resource_id,
model_name=model_name,
max_retries=max_retries,
max_output_tokens=max_output_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
return_steps=return_steps,
verbose=verbose
)
def create_agent_executor(self):
"""
Creates a Vertex AI Agent Engine Langchain Agent executor.
"""
# If agent_engine_resource_id is provided, use the deployed Agent Engine
if self.agent_engine_resource_id:
langchain_agent = reasoning_engines.ReasoningEngine(self.agent_engine_resource_id)
else:
langchain_agent = reasoning_engines.LangchainAgent(
# prompt=self.prompt, # Custom prompt seems to have an issue for some reason
model=self.model_name,
tools=self.tools,
agent_executor_kwargs={
"return_intermediate_steps": self.return_steps,
"verbose": self.verbose
},
enable_tracing=True
)
langchain_agent.set_up()
return langchain_agent
async def astream(
self,
input: Dict[str, Any],
) -> AsyncGenerator[Dict, Any]:
"""
Asynchronously streams the Agent output using Vertex AI Agent engine.
Args:
input: The list of messages to send to the model as input.
Yields:
Dictionaries representing the streamed agent output.
Exception:
An error is encountered during streaming.
"""
# Convert the messages into a string
content = "\n".join(f"[{msg['type']}]: {msg['content']}" for msg in input["messages"])
try:
for chunk in self.agent_executor.stream_query(input=content):
chunk["messages"] = [{**msg, "kwargs": {**msg["kwargs"], "id": input["run_id"]}} for msg in chunk["messages"]]
yield {"agent": chunk}
except Exception as e:
raise RuntimeError(f"Unexpected error. {e}") from e
class LangGraphVertexAIAgentEngineAgentManager(BaseAgentManager):
"""
AgentManager subclass for Vertex AI Agent Engine LangGraph orchestration.
"""
def __init__(
self,
prompt: str,
industry_type: str,
location: Optional[str] = "us-central1",
agent_engine_resource_id: Optional[str] = None,
model_name: Optional[str] = GEMINI_FLASH_20_LATEST,
max_retries: Optional[int] = 6,
max_output_tokens: Optional[int] = None,
temperature: Optional[float] = 0,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
return_steps: Optional[bool] = False,
verbose: Optional[bool] = True
):
super().__init__(
prompt=prompt,
industry_type=industry_type,
location=location,
orchestration_framework=OrchestrationFramework.LANGGRAPH_VERTEX_AI_AGENT_ENGINE_AGENT.value,
agent_engine_resource_id=agent_engine_resource_id,
model_name=model_name,
max_retries=max_retries,
max_output_tokens=max_output_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
return_steps=return_steps,
verbose=verbose
)
def create_agent_executor(self):
"""
Creates a Vertex AI Agent Engine LangGraph Agent executor.
"""
# If agent_engine_resource_id is provided, use the deployed Agent Engine
if self.agent_engine_resource_id:
langgraph_agent = reasoning_engines.ReasoningEngine(self.agent_engine_resource_id)
else:
langgraph_agent = reasoning_engines.LanggraphAgent(
model=self.model_name,
tools=self.tools,
runnable_kwargs={"prompt": self.prompt, "debug": self.verbose},
enable_tracing=True
)
langgraph_agent.set_up()
return langgraph_agent
async def astream(
self,
input: Dict[str, Any],
) -> AsyncGenerator[Dict, Any]:
"""
Asynchronously streams the Agent output using Vertex AI Agent Engine.
Args:
input: The list of messages to send to the model as input.
Yields:
Dictionaries representing the streamed agent output.
Exception:
An error is encountered during streaming.
"""
try:
for chunk in self.agent_executor.stream_query(
input=input,
stream_mode="values"
):
                # Override each of the run_ids with the one from the server
chunk["messages"] = [{**msg, "kwargs": {**msg["kwargs"], "id": input["run_id"]}} for msg in chunk["messages"]]
if self.agent_engine_resource_id:
if chunk["messages"][-1]["id"][-1] == "AIMessage":
yield {"agent": chunk}
else:
yield {"agent": chunk}
except Exception as e:
raise RuntimeError(f"Unexpected error. {e}") from e
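# For the Agent Engine managers, pointing at a previously deployed engine only
# requires the resource ID; the streaming interface is unchanged. A sketch with
# an illustrative (not real) resource ID format:
#
#   manager = LangGraphVertexAIAgentEngineAgentManager(
#       prompt="You are a helpful assistant.",
#       industry_type="finance",  # hypothetical value
#       agent_engine_resource_id="projects/PROJECT/locations/us-central1/reasoningEngines/ENGINE_ID",
#   )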
class LlamaIndexAgentManager(BaseAgentManager):
"""
    AgentManager subclass for LlamaIndex Agent orchestration.
"""
def __init__(
self,
prompt: str,
industry_type: str,
location: Optional[str] = "us-central1",
agent_engine_resource_id: Optional[str] = None,
model_name: Optional[str] = GEMINI_FLASH_20_LATEST,
max_retries: Optional[int] = 6,
max_output_tokens: Optional[int] = None,
temperature: Optional[float] = 0,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
return_steps: Optional[bool] = False,
verbose: Optional[bool] = True
):
super().__init__(
prompt=prompt,
industry_type=industry_type,
location=location,
orchestration_framework=OrchestrationFramework.LLAMAINDEX_AGENT.value,
agent_engine_resource_id=agent_engine_resource_id,
model_name=model_name,
max_retries=max_retries,
max_output_tokens=max_output_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
return_steps=return_steps,
verbose=verbose
)
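        # Re-wrap the LangChain chat model so LlamaIndex components (e.g. Settings.llm
        # in stream_query) receive a LlamaIndex-compatible LLM. Note that
        # create_agent_executor already ran inside super().__init__, so it wraps the
        # raw model itself.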
self.model_obj = LangChainLLM(self.model_obj)
def get_tools(self):
"""
Helper method to retrieve tools based on the industry type.
Returns:
A list of tools for the agent to use, based on the industry type.
"""
return get_llamaindex_tools(self.industry_type)
def create_agent_executor(self):
"""
Creates a LlamaIndex React Agent executor.
"""
        # Set up the index/query LLM for Vertex AI Search
        Settings.llm = LangChainLLM(self.model_obj)
        llamaindex_agent = ReActAgent.from_tools(
            prompt=self.prompt,
            llm=LangChainLLM(self.model_obj),
tools=self.tools,
verbose=self.verbose
)
return llamaindex_agent
def get_response_obj(
self,
content: str,
run_id: str
) -> Dict:
"""Returns a structure dictionary response object.
The response object needs to be organized to be
consistent with other Agents (e.g. Agent Engine)
Args:
content: The string reponse from the LLM.
run_id: The run_id.
Returns:
Structured dictionary reponse object.
"""
response_obj = {
"agent": {
"messages": [
{
"lc": 1,
"type": "constructor",
"id": ["llamaindex", "schema", "messages", "AIMessage"],
"kwargs": {
"content": content,
"type": "ai",
"tool_calls": [],
"invalid_tool_calls": [],
"id": run_id
},
}
]
}
}
return response_obj
async def astream(
self,
input: Dict[str, Any],
) -> AsyncGenerator[Dict, Any]:
"""Asynchronously event streams the Agent output. Needs to be
an Iterable for Agent Engine deployments
Args:
input: The list of messages to send to the model as input.
Yields:
Dictionaries representing the streamed agent output.
Exception:
An error is encountered during streaming.
"""
try:
if self.agent_engine_resource_id:
llamaindex_agent = reasoning_engines.ReasoningEngine(self.agent_engine_resource_id)
# Uses the stream_query function below
response = llamaindex_agent.stream_query(input=input)
for chunk in response:
yield chunk
else:
# Convert the messages into a string
content = "\n".join(f"[{msg['type']}]: {msg['content']}" for msg in input["messages"])
response = await self.agent_executor.aquery(content)
yield self.get_response_obj(
content=response.response,
run_id=input["run_id"]
)
except Exception as e:
raise RuntimeError(f"Unexpected error. {e}") from e
def stream_query(
self,
input: Dict[str, Any]
) -> Generator:
"""Synchronously event streams the Agent output.
This function is an alias for astream for Agent Engine deployments
Args:
input: The list of messages to send to the model as input.
Yields:
Dictionaries representing the streamed agent output.
Exception:
An error is encountered during processing.
"""
        run_id = str(uuid.uuid4())
# If using Agent Engine, explicitly re-set the index/query llm for Vertex Search
Settings.llm = self.model_obj
try:
# Convert the messages into a string
content = "\n".join(f"[{msg['type']}]: {msg['content']}" for msg in input["messages"])
response = self.agent_executor.query(content)
yield self.get_response_obj(
content=response.response,
run_id=run_id
)
except Exception as e:
raise RuntimeError(f"Unexpected error. {e}") from e
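# A hedged sketch of calling stream_query directly (Agent Engine invokes it the
# same way after deployment); message dicts use the "type"/"content" keys this
# method reads:
#
#   manager = LlamaIndexAgentManager(
#       prompt="You are a helpful assistant.",
#       industry_type="finance",  # hypothetical value
#   )
#   for chunk in manager.stream_query(
#       input={"messages": [{"type": "human", "content": "Hello!"}]}
#   ):
#       print(chunk)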