# orchestration/strategies/multimodal_agent_strategy.py
import base64
import json
import logging
import re
from typing import Annotated, Sequence
from autogen_agentchat.agents import AssistantAgent, BaseChatAgent
from autogen_agentchat.base import Response
from autogen_agentchat.messages import (
    ChatMessage,
    MultiModalMessage,
    TextMessage,
    ToolCallSummaryMessage,
)
from autogen_core import CancellationToken, Image
from autogen_core.model_context import BufferedChatCompletionContext
from autogen_core.tools import FunctionTool
from connectors import BlobClient
from tools import multimodal_vector_index_retrieve
from tools.ragindex.types import MultimodalVectorIndexRetrievalResult
from ..constants import Strategy
from .base_agent_strategy import BaseAgentStrategy
class MultimodalMessageCreator(BaseChatAgent):
"""
A custom agent that constructs a MultiModalMessage from the vector_index_retrieve_wrapper
tool results. The actual tool call is done by another agent
(retrieval_agent) in the same SelectorGroupChat.
This agent simply scans the conversation for the tool's result, parses
text + image URLs, and returns a MultiModalMessage.
"""
def __init__(self, name: str, system_prompt: str, model_context: BufferedChatCompletionContext):
super().__init__(
name=name,
description="An agent that creates `MultiModalMessage` objects from the results of `vector_index_retrieve_wrapper`, executed by an `AssistantAgent` called `retrieval_agent`."
)
self._last_multimodal_result = None
self.system_prompt = system_prompt + "\n\n"
self._model_context = model_context
@property
def produced_message_types(self):
"""
Return the message types this agent can produce.
We produce a MultiModalMessage.
"""
return (MultiModalMessage,)
async def on_messages(
self,
messages: Sequence[ChatMessage],
cancellation_token: CancellationToken
) -> Response:
"""
Handles incoming messages to process vector index retrieval results
and construct a MultiModalMessage response.
"""
retrieval_data = None
# Iterate through messages in reverse to find the latest relevant tool output
for msg in reversed(messages):
if isinstance(msg, ToolCallSummaryMessage):
try:
msg_content = msg.content
parsed_content = json.loads(msg_content)
if "texts" in parsed_content or "images" in parsed_content:
retrieval_data = parsed_content
break
except json.JSONDecodeError as e:
logging.warning(f"Failed to parse message content as JSON: {e}")
continue
if not retrieval_data:
# Fallback response when no relevant data is found
fallback_msg = TextMessage(
content="No vector retrieval data was found in the conversation.",
source=self.name
)
return Response(chat_message=fallback_msg)
# Extract text and image data
texts = retrieval_data.get("texts", [])
image_urls_list = retrieval_data.get("images", [])
captions_lists = retrieval_data.get("captions", [])
# Combine text snippets into a single string
        # Parenthesized so the system prompt is included even when no texts were retrieved
        combined_text = self.system_prompt + ("\n\n".join(texts) if texts else "No text results")
logging.debug(f"[multimodal_agent_strategy] combined_text: {combined_text}")
# Fetch images from URLs
image_objects = []
        max_images = 50  # maximum number of images to process (Azure OpenAI GPT-4o limit)
image_count = 0
document_count = 0
        for url_sublist in image_urls_list:  # each item is the list of image URLs for one document
            image_count = 0
            for url in url_sublist:
                if image_count >= max_images:
                    logging.info(f"[multimodal_agent_strategy] Reached the maximum image limit of {max_images}. Stopping further downloads.")
                    break  # Stop processing more URLs for this document
                try:
                    # Initialize BlobClient with the blob URL and download the image bytes
                    blob_client = BlobClient(blob_url=url)
                    logging.debug(f"[multimodal_agent_strategy] Initialized BlobClient for URL: {url}")
                    blob_data = blob_client.download_blob()
                    logging.debug(f"[multimodal_agent_strategy] Downloaded blob data for URL: {url}")
                    # Wrap the raw bytes in an autogen_core Image
                    base64_str = base64.b64encode(blob_data).decode('utf-8')
                    image = Image.from_base64(base64_str)
                    logging.debug(f"[multimodal_agent_strategy] Loaded image from URL: {url}")
                    # Strip the storage-account host so only the blob path remains
                    uri = re.sub(r'https://[^/]+\.blob\.core\.windows\.net', '', url)
                    image.filepath = uri
                    logging.debug(f"[multimodal_agent_strategy] Filepath (uri): {uri}")
                    # Attach the matching caption, guarding against ragged caption lists
                    if document_count < len(captions_lists) and image_count < len(captions_lists[document_count]):
                        image.caption = captions_lists[document_count][image_count]
                    else:
                        image.caption = None
                    image_objects.append(image)
                    logging.info(f"[multimodal_agent_strategy] Successfully loaded image from {url}")
                except Exception as e:
                    logging.error(f"[multimodal_agent_strategy] Could not load image from {url}: {e}")
                image_count += 1  # Advance the index even on failure so caption indices stay aligned
            document_count += 1  # Move on to the next document's URL list
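        # Note: MultiModalMessage.content accepts a mixed list of strings and
        # autogen_core Image objects, so the combined text and images are sent together.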
# Construct and return the MultiModalMessage response
multimodal_msg = MultiModalMessage(
content=[combined_text, *image_objects],
source=self.name
)
return Response(chat_message=multimodal_msg)
async def on_reset(self, cancellation_token: CancellationToken) -> None:
"""
Reset the agent state if needed.
In this basic example, we clear the internal variable.
"""
self._last_multimodal_result = None
class MultimodalAgentStrategy(BaseAgentStrategy):
def __init__(self):
super().__init__()
self.strategy_type = Strategy.MULTIMODAL_RAG
async def create_agents(self, history, client_principal=None, access_token=None, output_mode=None, output_format=None):
"""
Multimodal RAG creation strategy that creates the basic agents and registers functions.
Parameters:
- history: The conversation history, which will be summarized to provide context for the assistant's responses.
Returns:
- agent_configuration: A dictionary that includes the agents team, default model client, termination conditions and selector function.
Note:
To use a different model for an specific agent, instantiate a separate AzureOpenAIChatCompletionClient and assign it instead of using self._get_model_client().
"""
# Model Context
shared_context = await self._get_model_context(history)
# Wrapper Functions for Tools
async def vector_index_retrieve_wrapper(
input: Annotated[str, "An optimized query string based on the user's ask and conversation history, when available"]
) -> MultimodalVectorIndexRetrievalResult:
return await multimodal_vector_index_retrieve(input, self._generate_security_ids(client_principal))
        vector_index_retrieve_tool = FunctionTool(
            vector_index_retrieve_wrapper,
            name="vector_index_retrieve",
            description="Performs a vector search using Azure AI Search, fetching text and related images to gather relevant sources for answering the user's query.",
        )
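        # FunctionTool builds the tool's parameter schema from the wrapper's type
        # hints, so the Annotated description above is what the model sees for `input`.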
# Agents
## Triage Agent
triage_prompt = await self._read_prompt("triage_agent")
triage_agent = AssistantAgent(
name="triage_agent",
system_message=triage_prompt,
model_client=self._get_model_client(),
tools=[vector_index_retrieve_tool],
reflect_on_tool_use=False,
model_context=shared_context
)
## Multimodal Message Creator
multimodal_rag_message_prompt = await self._read_prompt("multimodal_rag_message")
multimodal_creator = MultimodalMessageCreator(name="multimodal_creator", system_prompt=multimodal_rag_message_prompt, model_context=shared_context)
## Assistant Agent
main_assistant_prompt = await self._read_prompt("main_assistant")
main_assistant = AssistantAgent(
name="main_assistant",
system_message=main_assistant_prompt,
model_client=self._get_model_client(),
reflect_on_tool_use=True
)
## Chat Closure Agent
chat_closure = await self._create_chat_closure_agent(output_format, output_mode)
# Agent Configuration
# Optional: Override the termination condition for the assistant. Set None to disable each termination condition.
# self.max_rounds = int(os.getenv('MAX_ROUNDS', 8))
# self.terminate_message = "TERMINATE"
def custom_selector_func(messages):
"""
Selects the next agent based on the source of the last message.
Transition Rules:
user -> triage_agent
triage_agent (ToolCallSummaryMessage) -> multimodal_creator
multimodal_creator -> assistant
Other -> None (SelectorGroupChat will handle transition)
"""
last_msg = messages[-1]
logging.debug(f"[multimodal_agent_strategy] last message: {last_msg}")
agent_selection = {
"user": "triage_agent",
"triage_agent": "multimodal_creator" if isinstance(last_msg, ToolCallSummaryMessage) else None,
"multimodal_creator": "main_assistant",
"main_assistant": "chat_closure",
}
selected_agent = agent_selection.get(last_msg.source)
if selected_agent is None and last_msg.source == "triage_agent":
selected_agent = "chat_closure"
if selected_agent:
logging.debug(f"[multimodal_agent_strategy] selected {selected_agent} agent")
return selected_agent
logging.debug("[multimodal_agent_strategy] selected None")
return None
self.selector_func = custom_selector_func
self.agents = [triage_agent, multimodal_creator, main_assistant, chat_closure]
return self._get_agents_configuration()
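# --- Illustrative usage (a minimal sketch, not part of the strategy) ---------
# Assumes `_get_agents_configuration()` (defined in BaseAgentStrategy) returns
# the pieces created above under the keys shown; the actual return shape may differ.
#
#   from autogen_agentchat.teams import SelectorGroupChat
#
#   strategy = MultimodalAgentStrategy()
#   config = await strategy.create_agents(history=[])
#   team = SelectorGroupChat(
#       participants=config["agents"],
#       model_client=config["model_client"],
#       termination_condition=config["termination_condition"],
#       selector_func=config["selector_func"],
#   )
#   result = await team.run(task="What does the latest report say?")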