Autogen_v0.4/rag_agent/agnext_bot.py (257 lines of code) (raw):

"""Multi-agent RAG chat bot built on the AutoGen core runtime.

Message flow for a text question:
    GroupChatManager -> ToolUseAgent -> QAAgent -> EvalutorAgent -> caller

Message flow for a question with an image:
    ImageAnalysisAgent -> caller

Final answers are handed back to ``start_multiagent_chat`` through the
module-level ``llm_results_dict`` guarded by ``condition``.
"""

import asyncio
import os
import random
import time
import uuid
from dataclasses import dataclass
from typing import List, Optional

from autogen_core.application import SingleThreadedAgentRuntime
from autogen_core.base import AgentId, MessageContext
from autogen_core.components import (
    DefaultTopicId,
    Image,
    RoutedAgent,
    default_subscription,
    message_handler,
    type_subscription,
)
from autogen_core.components.models import (
    AssistantMessage,
    AzureOpenAIChatCompletionClient,
    ChatCompletionClient,
    FunctionExecutionResult,
    LLMMessage,
    SystemMessage,
    UserMessage,
)
from autogen_core.components.tool_agent import ToolAgent, tool_agent_caller_loop
from autogen_core.components.tools import FunctionTool, Tool, ToolSchema
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
from dotenv import load_dotenv
from pydantic import BaseModel
from typing_extensions import Annotated

# Load .env before anything reads the environment, so that Azure credential
# variables defined there are visible when the credential is constructed.
load_dotenv()

# Imported after load_dotenv() on purpose: the original file imported this
# module only after the environment had been loaded.
import search_helper  # noqa: E402

token_provider = get_bearer_token_provider(
    DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
)

endpoint = os.environ["AZURE_OPENAI_ENDPOINT"]
deployed_model = os.environ["DEPLOYMENT_NAME"]


def get_model_client() -> AzureOpenAIChatCompletionClient:
    """Return a new Azure OpenAI chat client for the configured deployment."""
    return AzureOpenAIChatCompletionClient(
        model=deployed_model,
        api_version="2024-02-01",
        azure_endpoint=endpoint,
        azure_ad_token_provider=token_provider,
        # NOTE(review): "streaming", "max_tokens" and "temperature" look like
        # generation settings rather than capability flags -- confirm whether
        # the client honours them when passed here.
        model_capabilities={
            "vision": True,
            "function_calling": True,
            "json_output": True,
            "streaming": True,
            "max_tokens": 1000,
            "temperature": 0.0,
        },
    )


class Message(BaseModel):
    """Generic wrapper around a single LLM message."""

    body: LLMMessage


class StreamResponse(BaseModel):
    """Wrapper around a single LLM message for streamed responses."""

    body: LLMMessage


class GroupChatMessage(BaseModel):
    """User question entering the pipeline, tagged with its conversation."""

    body: LLMMessage
    conversation_id: str


class ToolResult(BaseModel):
    """Question plus tool-derived context, sent to the QA agent."""

    body: List[LLMMessage]
    conversation_id: str


class ToolAgentMessage(BaseModel):
    """Question forwarded to the tool-use agent."""

    body: LLMMessage
    conversation_id: str


class FinalResult(BaseModel):
    """Answer handed back to the caller of ``start_multiagent_chat``."""

    body: LLMMessage


class IntermediateResult(BaseModel):
    """Question, context and draft answer, sent to the evaluator agent."""

    body: List[LLMMessage]
    conversation_id: str


class ImageAnalysisMessage(BaseModel):
    """Multimodal user message sent to the image analysis agent."""

    body: LLMMessage
    conversation_id: str


class ImageAnalysisResult(BaseModel):
    """Result of an image analysis, tagged with its conversation."""

    body: LLMMessage
    conversation_id: str


common_system_message = SystemMessage(
    " You can answer questions only about stock prices, machine learning, and azure services. \n"
    "If the question is outside of the above domain, say that you can't answer the question. "
    "Start your response with 'FinalResponse'. "
)

queue = asyncio.Queue[FinalResult]()
# conversation_id -> FinalResult; all access is guarded by ``condition``.
llm_results_dict = {}
condition = asyncio.Condition()

runtime = SingleThreadedAgentRuntime()


async def _publish_final_result(conversation_id: str, content, source: str) -> None:
    """Store the final answer for a conversation and wake every waiter."""
    async with condition:
        llm_results_dict[conversation_id] = FinalResult(
            body=AssistantMessage(content=content, source=source)
        )
        condition.notify_all()


@default_subscription
@type_subscription("qa_agent")
class QAAgent(RoutedAgent):
    """Drafts an answer from the question and the tool-provided context."""

    def __init__(self, model_client: ChatCompletionClient) -> None:
        super().__init__("A Q&A Agent")
        self._model_client = model_client
        # Holds only the system prompt; per-request messages are never stored
        # here, so one conversation cannot leak into the next.
        self._chat_history: List[LLMMessage] = [
            SystemMessage(
                " You are AI assistant. You need to get the context needed to answer the question."
                " Answer ONLY with the facts listed in the context below. <Context> </Context>"
                " Provide citations and sources using the pageNumber and docName in the context"
                " in the format [Source: docName, Page No.: pageNumber]"
                " Do not generate answers that don't use the context below. "
            )
        ]

    @message_handler
    async def handle_any_group_message(self, message: ToolResult, ctx: MessageContext) -> None:
        """Answer from the supplied context and publish an IntermediateResult."""
        print("\033[32m" + "-" * 20 + "\033[0m")
        print(f"Received by QA Agent:{message.body}")
        # Build the prompt per request instead of extending shared history.
        prompt = self._chat_history + list(message.body)
        completion = await self._model_client.create(prompt)
        im_result = message.body + [
            AssistantMessage(content=f"Response:{completion.content}", source="qa_agent")
        ]
        await self.publish_message(
            IntermediateResult(body=im_result, conversation_id=message.conversation_id),
            DefaultTopicId(),
        )


@default_subscription
@type_subscription("evaluator_agent")
class EvalutorAgent(RoutedAgent):
    """Checks the draft answer against the context and posts the final result."""

    def __init__(self, model_client: ChatCompletionClient) -> None:
        super().__init__("A Evalutor Agent")
        self._model_client = model_client
        # Holds only the system prompt; see handle_any_group_message.
        self._chat_history: List[LLMMessage] = [
            SystemMessage(
                " You are an Evaluator. You will be provided user question , context and Response. \n"
                "If the response is correct given the context and question, then format the response"
                " as Markdown and send it. If the response is not correct, state that answer cannot"
                " be provided. "
            )
        ]

    @message_handler
    async def handle_any_group_message(self, message: IntermediateResult, ctx: MessageContext) -> None:
        """Evaluate the draft answer and store the verdict for the waiting caller."""
        print("\033[32m" + "-" * 20 + "\033[0m")
        print(f"Received by Evalutor Agent:{message.body}")
        # A per-request prompt replaces the former extend-then-truncate of the
        # shared history, which was exposed to interleaving across the await.
        prompt = self._chat_history + list(message.body)
        completion = await self._model_client.create(prompt)
        await _publish_final_result(message.conversation_id, completion.content, "evaluator_agent")


@default_subscription
@type_subscription("tool_use_agent")
class ToolUseAgent(RoutedAgent):
    """Lets the model call a tool to fetch context, then forwards it to the QA agent."""

    def __init__(
        self,
        model_client: AzureOpenAIChatCompletionClient,
        tool_schema: List[ToolSchema],
        tool_agent_type: str,
    ) -> None:
        super().__init__("An agent with tools")
        self._system_messages: List[LLMMessage] = [
            SystemMessage(
                "You are AI assistant. You should not answer the question directly."
                " You only need to call the tool provided to get the context needed to answer"
                " the question. For stock prices, use the tool get_stock_price."
                " For Autocad architecture, machine learning and azure services, use the tool"
                " retrieve_search_results. Do not answer the question directly. \n"
            )
        ]
        self._model_client = model_client
        self._tool_schema = tool_schema
        self._tool_agent_id = AgentId(tool_agent_type, self.id.key)

    @message_handler
    async def handle_any_group_message(self, message: ToolAgentMessage, ctx: MessageContext) -> None:
        """Run the tool-calling loop and send question + context to the QA agent."""
        print("\033[32m" + "-" * 20 + "\033[0m")
        print(f"Received by tool_use_agent:{message.body}")
        session: List[LLMMessage] = self._system_messages + [message.body]
        messages = await tool_agent_caller_loop(
            self,
            tool_agent_id=self._tool_agent_id,
            model_client=self._model_client,
            input_messages=session,
            tool_schema=self._tool_schema,
            cancellation_token=ctx.cancellation_token,
        )

        first_reply = messages[0]
        if isinstance(first_reply.content, str):
            # The model answered in plain text without calling a tool, so
            # there is no tool call to index below. Hand the text back as the
            # final result; otherwise the caller would wait indefinitely.
            await _publish_final_result(message.conversation_id, first_reply.content, "tool_use_agent")
            return

        # Indexing as in the original: first message carries the tool call,
        # second-to-last carries the tool execution result.
        tool_call = first_reply.content[0]
        tool_output = messages[-2].content[0].content
        formated_message = f"<Context>\n Tool name is: {tool_call.name}, tool input is: {tool_call.arguments} and tool result is: {tool_output}\n</Context>"
        # Local, not an instance attribute: this is per-request data.
        tool_result_history = [message.body] + [
            UserMessage(content=formated_message, source="tool_use_agent")
        ]
        await runtime.send_message(
            ToolResult(body=tool_result_history, conversation_id=message.conversation_id),
            AgentId("qa_agent", "default"),
        )


@default_subscription
@type_subscription("ia_agent")
class ImageAnalysisAgent(RoutedAgent):
    """Describes an image supplied with the user's message."""

    def __init__(self, model_client: ChatCompletionClient) -> None:
        super().__init__("A Image Analysis Agent")
        self._model_client = model_client
        # Holds only the system prompt; images from one conversation are not
        # carried into the next.
        self._chat_history: List[LLMMessage] = [
            SystemMessage(
                "You are building and construction archtitect. \n"
                "You need to analyze the image and describe the details like layout, size"
            )
        ]

    @message_handler
    async def handle_any_group_message(self, message: ImageAnalysisMessage, ctx: MessageContext) -> None:
        """Analyse the image message and store the description for the caller."""
        prompt = self._chat_history + [message.body]
        completion = await self._model_client.create(prompt)
        print(f"Image Analysis Response:{completion.content}")
        await _publish_final_result(message.conversation_id, completion.content, "ia_agent")


@default_subscription
@type_subscription("group_chat_manager")
class GroupChatManager(RoutedAgent):
    """Entry point for text questions; forwards them to the tool-use agent."""

    def __init__(self, participant_topic_types: List[str]) -> None:
        super().__init__("Group chat manager")
        self._num_rounds = 0
        self._participant_topic_types = participant_topic_types
        self._chat_history: List[GroupChatMessage] = []

    @message_handler
    async def handle_message(self, message: GroupChatMessage, ctx: MessageContext) -> None:
        """Record the user message and republish it as a ToolAgentMessage."""
        print("\033[32m" + "-" * 20 + "\033[0m")
        print(f"Received by GroupChatManager:{message.body}")
        self._chat_history.append(message)
        assert isinstance(message.body, UserMessage)
        self._num_rounds += 1
        await runtime.publish_message(
            ToolAgentMessage(body=message.body, conversation_id=message.conversation_id),
            DefaultTopicId(),
        )


async def get_stock_price(ticker: str, date: Annotated[str, "Date in YYYY/MM/DD"]) -> float:
    """Return a random stock price; both arguments are ignored (demo only)."""
    # Returns a random stock price for demonstration purposes.
    return random.uniform(10, 200)


tools: List[Tool] = [
    FunctionTool(get_stock_price, description="Get the stock price."),
    FunctionTool(
        search_helper.retrieve_search_results,
        description=(
            " The index_name for the retrieve_search_results tool should either be the following"
            " and nothing else: For Machine learning and Autocad Architecture related questions."
            " index_name: aml_index_with_suggester search_query: user question or"
            " For Azure Services related questions like Azure Functions."
            " index_name: vectest search_query: user question"
            " retrieve search results for user questions on machine learning and azure services."
        ),
    ),
]


async def register_agents():
    """Register every agent type of the pipeline with the shared runtime."""
    await ToolAgent.register(
        runtime, "tool_executor_agent", lambda: ToolAgent("tool executor agent", tools)
    )
    await ToolUseAgent.register(
        runtime,
        "tool_use_agent",
        lambda: ToolUseAgent(
            get_model_client(), [tool.schema for tool in tools], "tool_executor_agent"
        ),
    )
    await QAAgent.register(runtime, "qa_agent", lambda: QAAgent(get_model_client()))
    await EvalutorAgent.register(
        runtime, "evaluator_agent", lambda: EvalutorAgent(get_model_client())
    )
    await ImageAnalysisAgent.register(
        runtime, "ia_agent", lambda: ImageAnalysisAgent(get_model_client())
    )
    await GroupChatManager.register(
        runtime,
        "group_chat_manager",
        lambda: GroupChatManager(
            participant_topic_types=["qa_agent", "tool_use_agent", "evaluator_agent", "ia_agent"]
        ),
    )


async def start_multiagent_chat(
    user_message: str,
    image_url: Optional[str] = None,
    timeout: Optional[float] = None,
) -> str:
    """Run one question through the agents and return the final answer text.

    Args:
        user_message: The user's question.
        image_url: If given, the image is fetched and the question is sent
            straight to the image analysis agent.
        timeout: Maximum number of seconds to wait for the final answer.
            ``None`` (the default) waits indefinitely, as before.

    Returns:
        The final answer content, or a fixed error string if waiting for the
        answer failed or timed out.
    """
    try:
        start_time = time.time()
        runtime.start()
        end_time = time.time()
        elapsed_time_ms = (end_time - start_time) * 1000  # Convert to milliseconds
        print(f"runtime.start() took {elapsed_time_ms:.2f} ms.")
    except Exception as e:
        # Best effort, as in the original: report the failure and carry on.
        print(f"Error starting runtime: {e}")

    conversation_id = str(uuid.uuid4())
    group_chat_result = ""
    try:
        if image_url:
            image_data = await Image.from_url(image_url)
            await runtime.send_message(
                ImageAnalysisMessage(
                    body=UserMessage(content=[user_message, image_data], source="user"),
                    conversation_id=conversation_id,
                ),
                AgentId("ia_agent", "default"),
            )
        else:
            await runtime.publish_message(
                GroupChatMessage(
                    body=UserMessage(content=user_message, source="User"),
                    conversation_id=conversation_id,
                ),
                DefaultTopicId(),
            )

        try:
            async with condition:
                # Block until an agent stores a result for this conversation.
                await asyncio.wait_for(
                    condition.wait_for(lambda: conversation_id in llm_results_dict),
                    timeout,
                )
                group_chat_result = llm_results_dict.pop(conversation_id).body.content
                print(f"conversation_id delete: {conversation_id}")
            print("\033[35m" + "-" * 20 + "\033[0m")
        except Exception as e:
            # Handle any exception that may occur during the wait for the response
            print(f"Error retrieving message from queue: {e}")
            group_chat_result = "An error occurred while waiting for the response."
    finally:
        # Stop the runtime even when dispatching the message raised.
        await runtime.stop()
    return group_chat_result