# orchestration/strategies/multimodal_agent_strategy.py

import base64
import json
import logging
import re
from typing import Annotated, Sequence

from pydantic import BaseModel

from autogen_agentchat.agents import AssistantAgent, BaseChatAgent
from autogen_agentchat.base._chat_agent import Response
from autogen_agentchat.messages import (
    AgentEvent,
    ChatMessage,
    MultiModalMessage,
    TextMessage,
    ToolCallSummaryMessage,
)
from autogen_core import CancellationToken, Image
from autogen_core.model_context import BufferedChatCompletionContext
from autogen_core.tools import FunctionTool

from connectors import BlobClient
from tools import get_time, get_today_date, multimodal_vector_index_retrieve
from tools.ragindex.types import MultimodalVectorIndexRetrievalResult

from ..constants import Strategy
from .base_agent_strategy import BaseAgentStrategy

# Scheme + account host of an Azure Blob Storage URL. Removing it from a blob URL
# leaves the "/container/path" part, which is stored on each Image as `filepath`.
_BLOB_HOST_PATTERN = re.compile(r'https://[^/]+\.blob\.core\.windows\.net')


class MultimodalMessageCreator(BaseChatAgent):
    """
    A custom agent that constructs a MultiModalMessage from the
    vector_index_retrieve_wrapper tool results.

    The actual tool call is done by another agent (``triage_agent`` in
    MultimodalAgentStrategy) in the same SelectorGroupChat. This agent simply
    scans the conversation for the tool's result, parses text + image URLs,
    downloads the images and returns a MultiModalMessage.
    """

    # Maximum number of images attached to a single MultiModalMessage, counted
    # across ALL retrieved documents (the original comment cites this as the
    # Azure OpenAI GPT-4o limit). Can be overridden on the class or an instance.
    MAX_IMAGES = 50

    def __init__(self, name: str, system_prompt: str, model_context: BufferedChatCompletionContext):
        super().__init__(
            name=name,
            description="An agent that creates `MultiModalMessage` objects from the results of `vector_index_retrieve_wrapper`, executed by an `AssistantAgent` called `triage_agent`.",
        )
        self._last_multimodal_result = None
        # Prefixed to the retrieved texts in every MultiModalMessage.
        self.system_prompt = system_prompt + "\n\n"
        # Stored for interface parity; not read anywhere in this class.
        self._model_context = model_context

    @property
    def produced_message_types(self):
        """
        Return the message types this agent can produce: a MultiModalMessage
        normally, or a TextMessage when no retrieval result is found.
        """
        return (MultiModalMessage, TextMessage)

    async def on_messages(
        self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
    ) -> Response:
        """
        Build a MultiModalMessage from the most recent vector-retrieval tool result.

        Args:
            messages: Messages delivered to this agent. They are scanned newest-first
                for a ToolCallSummaryMessage whose content is a JSON object with
                "texts" and/or "images" keys.
            cancellation_token: Not used here; part of the BaseChatAgent interface.

        Returns:
            A Response wrapping either a MultiModalMessage (system prompt + joined
            texts, followed by the downloaded images) or, when no retrieval result
            is found, a TextMessage saying so.
        """
        retrieval_data = self._find_retrieval_data(messages)
        if not retrieval_data:
            # Fallback response when no relevant data is found
            fallback_msg = TextMessage(
                content="No vector retrieval data was found in the conversation.",
                source=self.name,
            )
            return Response(chat_message=fallback_msg)

        # `or []` also covers keys that are present but null.
        texts = retrieval_data.get("texts") or []
        image_urls_list = retrieval_data.get("images") or []
        captions_lists = retrieval_data.get("captions") or []

        # Parenthesized so the system prompt is kept even when there are no text
        # results (without parentheses the conditional expression discarded it).
        combined_text = self.system_prompt + ("\n\n".join(texts) if texts else "No text results")
        logging.debug(f"[multimodal_agent_strategy] combined_text: {combined_text}")

        image_objects = self._load_images(image_urls_list, captions_lists)

        multimodal_msg = MultiModalMessage(
            content=[combined_text, *image_objects],
            source=self.name,
        )
        return Response(chat_message=multimodal_msg)

    @staticmethod
    def _find_retrieval_data(messages):
        """Return the newest ToolCallSummaryMessage payload that looks like a retrieval result, else None."""
        for msg in reversed(messages):
            if not isinstance(msg, ToolCallSummaryMessage):
                continue
            try:
                parsed_content = json.loads(msg.content)
            except json.JSONDecodeError as e:
                logging.warning(f"Failed to parse message content as JSON: {e}")
                continue
            # Only a JSON object can be a retrieval result. A bare JSON string that
            # happens to contain "texts" would otherwise pass the substring `in` test.
            if isinstance(parsed_content, dict) and ("texts" in parsed_content or "images" in parsed_content):
                return parsed_content
        return None

    def _load_images(self, image_urls_list, captions_lists):
        """
        Download the images referenced by a retrieval result.

        `image_urls_list` is iterated as one list of URLs per source document.
        `captions_lists[d][i]` is read as the caption of URL `i` of document `d`
        (assumed parallel to the URL lists, as in the original indexing; confirm
        against the retrieval tool). Loading stops once `MAX_IMAGES` images have
        been loaded in total. Images that fail to load are logged and skipped.
        """
        image_objects = []
        for document_index, document_urls in enumerate(image_urls_list):
            for image_index, url in enumerate(document_urls):
                if len(image_objects) >= self.MAX_IMAGES:
                    logging.info(f"[multimodal_agent_strategy] Reached the maximum image limit of {self.MAX_IMAGES}. Stopping further downloads.")
                    return image_objects
                # Caption is chosen by URL position, so a failed download does not
                # shift the captions of the images that follow it.
                caption = self._caption_for(captions_lists, document_index, image_index)
                pil_img = self._load_image(url, caption)
                if pil_img is not None:
                    image_objects.append(pil_img)
        return image_objects

    @staticmethod
    def _caption_for(captions_lists, document_index, image_index):
        """Return captions_lists[document_index][image_index], or None when that entry does not exist."""
        try:
            return captions_lists[document_index][image_index]
        except (IndexError, KeyError, TypeError):
            return None

    @staticmethod
    def _load_image(url, caption):
        """Download one blob and wrap it as an autogen Image; return None (after logging) on any failure."""
        try:
            blob_client = BlobClient(blob_url=url)
            logging.debug(f"[multimodal_agent_strategy] Initialized BlobClient for URL: {url}")
            # NOTE(review): download_blob() is called synchronously inside an async
            # agent, so each download blocks the event loop; consider asyncio.to_thread.
            blob_data = blob_client.download_blob()
            logging.debug(f"[multimodal_agent_strategy] Downloaded blob data for URL: {url}")
            pil_img = Image.from_base64(base64.b64encode(blob_data).decode('utf-8'))
            logging.debug(f"[multimodal_agent_strategy] Opened image from URL: {url}")
            uri = _BLOB_HOST_PATTERN.sub('', url)
            pil_img.filepath = uri
            logging.debug(f"[multimodal_agent_strategy] Filepath (uri): {uri}")
            pil_img.caption = caption
        except Exception as e:
            # Best-effort: one bad image must not abort the whole message.
            logging.error(f"[multimodal_agent_strategy] Could not load image from {url}: {e}")
            return None
        logging.info(f"[multimodal_agent_strategy] Successfully loaded image from {url}")
        return pil_img

    async def on_reset(self, cancellation_token: CancellationToken) -> None:
        """
        Reset the agent state if needed.
        In this basic example, we clear the internal variable.
        """
        self._last_multimodal_result = None


class MultimodalAgentStrategy(BaseAgentStrategy):
    """
    Multimodal RAG strategy. A triage agent calls the retrieval tool, a
    MultimodalMessageCreator packages the retrieved texts and images into a
    MultiModalMessage, the main assistant answers, and a chat-closure agent
    finishes the conversation.
    """

    def __init__(self):
        super().__init__()
        self.strategy_type = Strategy.MULTIMODAL_RAG

    async def create_agents(self, history, client_principal=None, access_token=None, output_mode=None, output_format=None):
        """
        Multimodal RAG creation strategy that creates the basic agents and registers functions.

        Parameters:
        - history: The conversation history, passed to `_get_model_context` to build the
          model context shared by the triage agent and the multimodal creator.
        - client_principal: Used to derive the security ids passed to the vector search.
        - access_token: Not used by this strategy.
        - output_mode, output_format: Forwarded to `_create_chat_closure_agent`.

        Returns:
        - agent_configuration: The result of `_get_agents_configuration()`; per the original
          docs, a dictionary that includes the agents team, default model client,
          termination conditions and selector function.

        Note: To use a different model for a specific agent, instantiate a separate
        AzureOpenAIChatCompletionClient and assign it instead of using self._get_model_client().
        """
        # Model context shared by the triage agent and the multimodal creator.
        shared_context = await self._get_model_context(history)

        # Wrapper Functions for Tools
        async def vector_index_retrieve_wrapper(
            input: Annotated[str, "An optimized query string based on the user's ask and conversation history, when available"]
        ) -> MultimodalVectorIndexRetrievalResult:
            # FunctionTool derives the tool's argument schema from this signature,
            # so the parameter keeps the name `input` even though it shadows a builtin.
            return await multimodal_vector_index_retrieve(input, self._generate_security_ids(client_principal))

        vector_index_retrieve_tool = FunctionTool(
            vector_index_retrieve_wrapper,
            name="vector_index_retrieve",
            description="Performs a vector search using Azure AI Search fetching text and related images get relevant sources for answering the user's query.",
        )

        # Agents

        ## Triage Agent
        # reflect_on_tool_use=False: the agent's turn ends with the raw tool result
        # (a ToolCallSummaryMessage), which the selector below routes to multimodal_creator.
        triage_prompt = await self._read_prompt("triage_agent")
        triage_agent = AssistantAgent(
            name="triage_agent",
            system_message=triage_prompt,
            model_client=self._get_model_client(),
            tools=[vector_index_retrieve_tool],
            reflect_on_tool_use=False,
            model_context=shared_context,
        )

        ## Multimodal Message Creator
        multimodal_rag_message_prompt = await self._read_prompt("multimodal_rag_message")
        multimodal_creator = MultimodalMessageCreator(
            name="multimodal_creator",
            system_prompt=multimodal_rag_message_prompt,
            model_context=shared_context,
        )

        ## Assistant Agent
        main_assistant_prompt = await self._read_prompt("main_assistant")
        main_assistant = AssistantAgent(
            name="main_assistant",
            system_message=main_assistant_prompt,
            model_client=self._get_model_client(),
            reflect_on_tool_use=True,
        )

        ## Chat Closure Agent
        # NOTE(review): called as (output_format, output_mode), the reverse of this
        # method's parameter order; confirm it matches _create_chat_closure_agent.
        chat_closure = await self._create_chat_closure_agent(output_format, output_mode)

        # Agent Configuration
        # Optional: Override the termination condition for the assistant. Set None to disable each termination condition.
        # self.max_rounds = int(os.getenv('MAX_ROUNDS', 8))
        # self.terminate_message = "TERMINATE"

        # Fixed hand-offs for every source except triage_agent, whose successor also
        # depends on the message type (handled inside custom_selector_func).
        next_agent_by_source = {
            "user": "triage_agent",
            "multimodal_creator": "main_assistant",
            "main_assistant": "chat_closure",
        }

        def custom_selector_func(messages):
            """
            Select the next agent based on the source (and type) of the last message.

            Transition Rules:
                user -> triage_agent
                triage_agent (ToolCallSummaryMessage) -> multimodal_creator
                triage_agent (any other message) -> chat_closure
                multimodal_creator -> main_assistant
                main_assistant -> chat_closure
                Other / no messages -> None (SelectorGroupChat will handle transition)
            """
            if not messages:
                logging.debug("[multimodal_agent_strategy] selected None")
                return None

            last_msg = messages[-1]
            logging.debug(f"[multimodal_agent_strategy] last message: {last_msg}")

            if last_msg.source == "triage_agent":
                selected_agent = "multimodal_creator" if isinstance(last_msg, ToolCallSummaryMessage) else "chat_closure"
            else:
                selected_agent = next_agent_by_source.get(last_msg.source)

            if selected_agent:
                logging.debug(f"[multimodal_agent_strategy] selected {selected_agent} agent")
                return selected_agent

            logging.debug("[multimodal_agent_strategy] selected None")
            return None

        self.selector_func = custom_selector_func
        self.agents = [triage_agent, multimodal_creator, main_assistant, chat_closure]

        return self._get_agents_configuration()