{
"cells": [
{
"cell_type": "markdown",
"id": "c26ab996",
"metadata": {},
"source": [
"# Adaptive RAG\n",
"\n",
"----\n",
"\n",
"Adaptive RAG uses a small or large language model (SLM/LLM) to predict the **complexity of the input question** and then selects an appropriate processing workflow.\n",
"\n",
"- **Very simple question (No Retrieval)**: Generates answers without RAG.\n",
"- **Simple question (Single-shot RAG)**: Efficiently generates answers through a single-step search and generation.\n",
"- **Complex question (Iterative RAG)**: Provides accurate answers to complex questions through repeated multi-step search and generation.\n",
"\n",
"\n",
"Adaptive-RAG, Self-RAG, and Corrective RAG are similar approaches, but they have different focuses:\n",
"\n",
"- **Adaptive-RAG**: Dynamically selects appropriate retrieval and generation strategies based on the complexity of the question.\n",
"- **Self-RAG**: The model determines the need for retrieval on its own, performs retrieval when necessary, and improves the quality through self-reflection on the generated answers.\n",
"- **Corrective RAG**: Evaluates the quality of retrieved documents, and performs additional retrievals such as web searches to supplement the information if the reliability is low.\n",
"\n",
"**Reference**\n",
"\n",
"- [Adaptive-RAG paper](https://arxiv.org/abs/2403.14403) "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b6458235",
"metadata": {},
"outputs": [],
"source": [
"from dotenv import load_dotenv\n",
"import os\n",
"import json\n",
"from azure.core.credentials import AzureKeyCredential\n",
"from azure.identity import DefaultAzureCredential\n",
"from azure.search.documents import SearchClient\n",
"from azure.search.documents.models import VectorizableTextQuery\n",
"from azure.ai.evaluation import (\n",
" GroundednessEvaluator,\n",
" RelevanceEvaluator,\n",
" RetrievalEvaluator,\n",
")\n",
"from autogen_ext.models.openai import AzureOpenAIChatCompletionClient\n",
"from autogen_core.models import (\n",
" ChatCompletionClient,\n",
" SystemMessage,\n",
" UserMessage,\n",
" AssistantMessage,\n",
")\n",
"from autogen_core import (\n",
" MessageContext,\n",
" RoutedAgent,\n",
" SingleThreadedAgentRuntime,\n",
" TopicId,\n",
" message_handler,\n",
" type_subscription,\n",
")\n",
"from pydantic import BaseModel\n",
"from typing import List\n",
"from dataclasses import dataclass\n",
"\n",
"\n",
"load_dotenv(override=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "05d7dfed",
"metadata": {},
"outputs": [],
"source": [
"# Get the environment variables\n",
"azure_ai_search_endpoint = os.getenv(\"AZURE_AI_SEARCH_ENDPOINT\")\n",
"search_credential = (\n",
" AzureKeyCredential(os.getenv(\"AZURE_AI_SEARCH_API_KEY\", \"\"))\n",
" if len(os.getenv(\"AZURE_AI_SEARCH_API_KEY\", \"\")) > 0\n",
" else DefaultAzureCredential()\n",
")\n",
"index_name = os.getenv(\"AZURE_SEARCH_INDEX_NAME\", \"hotels-sample-index\")\n",
"\n",
"azure_openai_endpoint = os.getenv(\"AZURE_OPENAI_ENDPOINT\")\n",
"azure_openai_key = (\n",
" os.getenv(\"AZURE_OPENAI_API_KEY\", \"\")\n",
" if len(os.getenv(\"AZURE_OPENAI_API_KEY\", \"\")) > 0\n",
" else None\n",
")\n",
"azure_openai_chat_deployment_name = os.getenv(\"AZURE_OPENAI_CHAT_DEPLOYMENT_NAME\")\n",
"azure_openai_embedding_deployment_name = os.getenv(\n",
" \"AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME\", \"text-embedding-ada-002\"\n",
")\n",
"azure_penai_api_version = os.getenv(\"OPENAI_API_VERSION\", \"2024-06-01\")\n",
"\n",
"bing_subscription_key = (\n",
" os.getenv(\"BING_SUBSCRIPTION_KEY\", \"\")\n",
" if len(os.getenv(\"BING_SUBSCRIPTION_KEY\", \"\")) > 0\n",
" else None\n",
")\n",
"\n",
"model_config = {\n",
" \"azure_endpoint\": azure_openai_endpoint,\n",
" \"api_key\": azure_openai_key,\n",
" \"azure_deployment\": azure_openai_chat_deployment_name,\n",
" \"api_version\": azure_penai_api_version,\n",
" \"type\": \"azure_openai\",\n",
"}"
]
},
{
"cell_type": "markdown",
"id": "5ec192a9",
"metadata": {},
"source": [
"## 🧪 Step 1. Test and Construct each module\n",
"---\n",
"\n",
"Before building the entire pipeline, we test and construct each module separately.\n",
"\n",
"- **IntentRouter**\n",
"- **SearchClient(Retrieval)**\n",
"- **Retrieval Grader**\n",
"- **Question Re-writer**\n",
"- **Answer Generator**\n",
"- **Groundedness Evaluator**\n",
"- **Relevance Evaluator**\n",
"- **Keyword Re-writer**\n",
"- **Web Search Tool**\n"
]
},
{
"cell_type": "markdown",
"id": "083d39ad",
"metadata": {},
"source": [
"### Define your LLM\n",
"\n",
"This hands-on lab uses only `gpt-4o-mini`, but you can use multiple models in the pipeline."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "18d3a8f8",
"metadata": {},
"outputs": [],
"source": [
"# AutoGen's Azure OpenAI chat client (autogen_ext); every agent in this notebook uses it.\n",
"autogen_aoai_client = AzureOpenAIChatCompletionClient(\n",
" azure_endpoint=azure_openai_endpoint,\n",
" model=azure_openai_chat_deployment_name,\n",
" api_version=azure_penai_api_version,\n",
" api_key=azure_openai_key,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "3e4c24d3",
"metadata": {},
"source": [
"### Intent Router\n",
"\n",
"Construct an `intent_router` that classifies the intent of the user's query and routes it to the appropriate module."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4deb3a06",
"metadata": {},
"outputs": [],
"source": [
"from pydantic import BaseModel, Field\n",
"from typing import List\n",
"from enum import Enum\n",
"\n",
"\n",
"class IntentType(str, Enum):\n",
" LLM = \"LLM\"\n",
" RAG = \"RAG\"\n",
" websearch = \"websearch\"\n",
"\n",
"\n",
"class IntentResponse(BaseModel):\n",
" intent_type: IntentType = Field(..., description=\"Target route for the query: LLM, RAG, or websearch\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6b10654c",
"metadata": {},
"outputs": [],
"source": [
"# query=\"how are you today?\"\n",
"query = \"Can you recommend a few hotels with complimentary breakfast?\"\n",
"# query=\"Can you recommend the newest Openings Hotels in Manhattan Midtown 2025?\"\n",
"\n",
"# This prompt provides instructions to the model\n",
"INTENT_ROUTER_PROMPT = \"\"\"\n",
"You are an expert at routing a user question to LLM, RAG, or websearch.\n",
" The LLM covers casual topics such as greetings and small talk.\n",
" Use the LLM for questions on casual topics.\n",
" The RAG contains documents related to hotel information in New York until Aug 2024.\n",
" Use the RAG for questions on hotel-related topics. For all else, use websearch.\n",
" Respond with an intent_type of LLM, RAG, or websearch.\n",
"Query: {query}\n",
"\"\"\"\n",
"\n",
"# Ask the LLM to classify the query's intent as structured output (IntentResponse).\n",
"response = await autogen_aoai_client.create(\n",
" messages=[\n",
" UserMessage(content=INTENT_ROUTER_PROMPT.format(query=query), source=\"user\"),\n",
" ],\n",
" extra_create_args={\"response_format\": IntentResponse},\n",
")\n",
"\n",
"\n",
"# Here is the response from the chat model.\n",
"print(response.content)"
]
},
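{
"cell_type": "markdown",
"id": "e1a7c3d0",
"metadata": {},
"source": [
"In Step 2, the router parses this structured output with `json.loads` and assumes a known `intent_type`. As a minimal sketch (the `parse_intent` helper and its `websearch` default are assumptions, not part of the original pipeline), the output can be parsed defensively so that a malformed or unexpected response still maps to a known route:\n",
"\n",
"```python\n",
"import json\n",
"from enum import Enum\n",
"\n",
"\n",
"class IntentType(str, Enum):\n",
"    LLM = 'LLM'\n",
"    RAG = 'RAG'\n",
"    websearch = 'websearch'\n",
"\n",
"\n",
"def parse_intent(raw: str, default: IntentType = IntentType.websearch) -> IntentType:\n",
"    \"\"\"Hypothetical helper: parse the router's JSON output, falling back to `default`.\"\"\"\n",
"    try:\n",
"        return IntentType(json.loads(raw)['intent_type'])\n",
"    except (ValueError, KeyError, TypeError):\n",
"        # ValueError also covers json.JSONDecodeError and unknown enum values.\n",
"        return default\n",
"\n",
"\n",
"print(parse_intent('{\"intent_type\": \"RAG\"}').value)  # RAG\n",
"print(parse_intent('not json').value)  # websearch\n",
"```\n",
"\n",
"Defaulting to `websearch` mirrors the router prompt's \"for all else, websearch\" rule; a cheaper default such as `LLM` is equally reasonable."
]
},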
{
"cell_type": "markdown",
"id": "cdf92a96",
"metadata": {},
"source": [
"### Construct Retrieval Chain based on Azure AI Search\n",
"- We use the `hotels-sample-index`, which can be created in minutes and runs on any search service tier. The index is created by a wizard using built-in sample data."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4121f7b6",
"metadata": {},
"outputs": [],
"source": [
"azure_ai_search_endpoint = os.getenv(\"AZURE_AI_SEARCH_ENDPOINT\")\n",
"azure_search_admin_key = os.getenv(\"AZURE_AI_SEARCH_API_KEY\", \"\")\n",
"search_client = SearchClient(\n",
" endpoint=azure_ai_search_endpoint,\n",
" index_name=index_name,\n",
" credential=AzureKeyCredential(azure_search_admin_key),\n",
" semantic_configuration_name=\"my-semantic-config\",\n",
")\n",
"\n",
"# Query is the question being asked. It's sent to the search engine and the LLM.\n",
"\n",
"fields = \"descriptionVector\" # TODO: Check if this is the correct field name\n",
"# don't use exhaustive search for large indexes\n",
"vector_query = VectorizableTextQuery(\n",
" text=query, k_nearest_neighbors=2, fields=fields, exhaustive=True\n",
")\n",
"\n",
"# Hybrid search: the text query and the vector query are combined, and the top 3 matches\n",
"# are returned with only the selected fields.\n",
"search_results = search_client.search(\n",
" search_text=query,\n",
" vector_queries=[vector_query],\n",
" select=\"Description,HotelName,Tags\",\n",
" top=3,\n",
")\n",
"sources_formatted = \"\\n\".join(\n",
" [\n",
" f'{document[\"HotelName\"]}:{document[\"Description\"]}:{document[\"Tags\"]}'\n",
" for document in search_results\n",
" ]\n",
")\n",
"\n",
"print(sources_formatted)"
]
},
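{
"cell_type": "markdown",
"id": "f2b8d4e1",
"metadata": {},
"source": [
"In the hotels sample index, `Tags` is a string collection, so the f-string above renders it as a Python list literal such as `['pool', 'view']`. A minimal sketch of a formatter that joins the tags and tolerates missing fields (the `format_hotel_docs` name is hypothetical; it accepts any iterable of dicts, such as the results of `search_client.search`):\n",
"\n",
"```python\n",
"from typing import Iterable, Mapping\n",
"\n",
"\n",
"def format_hotel_docs(documents: Iterable[Mapping]) -> str:\n",
"    \"\"\"Hypothetical helper: one line per hotel, 'Name: Description [tag1, tag2]'.\"\"\"\n",
"    lines = []\n",
"    for doc in documents:\n",
"        tags = doc.get('Tags') or []\n",
"        # Tags is a string collection in the sample index; accept a plain string too.\n",
"        tag_text = ', '.join(tags) if isinstance(tags, (list, tuple)) else str(tags)\n",
"        name = doc.get('HotelName', '')\n",
"        description = doc.get('Description', '')\n",
"        lines.append(f'{name}: {description} [{tag_text}]')\n",
"    return '\\n'.join(lines)\n",
"\n",
"\n",
"docs = [{'HotelName': 'A', 'Description': 'Nice', 'Tags': ['pool', 'free breakfast']}]\n",
"print(format_hotel_docs(docs))  # A: Nice [pool, free breakfast]\n",
"```\n",
"\n",
"With this helper, `sources_formatted = format_hotel_docs(search_results)` could replace the list comprehension above."
]
},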
{
"cell_type": "markdown",
"id": "65ec10e5",
"metadata": {},
"source": [
"### Question-Retrieval Grader\n",
"\n",
"Construct a retrieval grader that evaluates how relevant the retrieved documents are to the input question. This notebook uses `RetrievalEvaluator` from the Azure AI Evaluation SDK, which takes the query and the retrieved context and returns a single retrieval score on a 1-to-5 scale for the whole context, not one score per document.<br>\n",
"Because all retrieved documents are concatenated into one context string, the grader handles **multiple documents** at once."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5c85e4c2",
"metadata": {},
"outputs": [],
"source": [
"retrieval_eval = RetrievalEvaluator(model_config)\n",
"\n",
"query_response = dict(query=query, context=sources_formatted)\n",
"\n",
"retrieval_score = retrieval_eval(**query_response)\n",
"print(retrieval_score)\n",
"retrieval_score[\"retrieval\"]"
]
},
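{
"cell_type": "markdown",
"id": "0c9e5f12",
"metadata": {},
"source": [
"The Step 2 pipeline turns this score into a routing decision with a fixed threshold of 3.0. As a hypothetical refactoring (not part of the original notebook), that rule can live in a small pure function, which makes it easy to unit-test without calling Azure:\n",
"\n",
"```python\n",
"GENERATE_TOPIC = 'GenerateAgent'\n",
"QUERY_REWRITE_TOPIC = 'QueryRewriteAgent'\n",
"\n",
"\n",
"def next_topic_after_retrieval(result: dict, threshold: float = 3.0) -> str:\n",
"    \"\"\"Return the topic type the RAG grader should publish to next.\"\"\"\n",
"    score = float(result.get('retrieval', 0.0))  # a missing score counts as a failure\n",
"    return GENERATE_TOPIC if score >= threshold else QUERY_REWRITE_TOPIC\n",
"\n",
"\n",
"print(next_topic_after_retrieval({'retrieval': 4.0}))  # GenerateAgent\n",
"print(next_topic_after_retrieval({'retrieval': 2.0}))  # QueryRewriteAgent\n",
"```"
]
},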
{
"cell_type": "markdown",
"id": "8209a547",
"metadata": {},
"source": [
"### Question Re-writer\n",
"\n",
"Construct a `question_rewriter` node that rewrites the question into a form better suited to vector-store retrieval. In the pipeline, it is used when the retrieved documents score below the relevance threshold."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b6f4bafb",
"metadata": {},
"outputs": [],
"source": [
"query = \"Can you recommend a few hoels with complimentary breakfast?\"\n",
"\n",
"# This prompt provides instructions to the model\n",
"REWRITE_PROMPT = \"\"\"\n",
"You are a question re-writer that converts an input question to a better version that is optimized\n",
"for vectorstore retrieval. Look at the input and try to reason about the underlying semantic intent / meaning.\n",
"Query: {query}\n",
"\"\"\"\n",
"\n",
"# Ask the LLM to rewrite the query for vector-store retrieval.\n",
"response = await autogen_aoai_client.create(\n",
" messages=[\n",
" UserMessage(content=REWRITE_PROMPT.format(query=query), source=\"user\"),\n",
" ]\n",
")\n",
"\n",
"\n",
"# Here is the response from the chat model.\n",
"print(response.content)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8f07b8c0",
"metadata": {},
"outputs": [],
"source": [
"query = response.content"
]
},
{
"cell_type": "markdown",
"id": "702b7f87",
"metadata": {},
"source": [
"### Answer Generator\n",
"\n",
"Construct an LLM generation node. This is a naive RAG chain that generates an answer based on the retrieved documents.\n",
"\n",
"For production, we recommend a more advanced RAG chain."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2168105c",
"metadata": {},
"outputs": [],
"source": [
"class HotelInfo(BaseModel):\n",
" hotel_name: str\n",
" description: str\n",
"\n",
"\n",
"class RecommendationList(BaseModel):\n",
" recommendation: List[HotelInfo]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d080f695",
"metadata": {},
"outputs": [],
"source": [
"# This prompt provides instructions to the model\n",
"GROUNDED_PROMPT = \"\"\"\n",
"You are a friendly assistant that recommends hotels based on activities and amenities.\n",
"Answer the query using only the context provided below in a friendly and concise bulleted manner.\n",
"Answer ONLY with the facts listed in the list of context below.\n",
"If there isn't enough information below, say you don't know.\n",
"Generate a response that includes the top 3 results.\n",
"Do not generate answers that don't use the context below.\n",
"Query: {query}\n",
"Context:\\n{context}\n",
"\"\"\"\n",
"\n",
"# Send the search results and the query to the LLM to generate a response based on the prompt.\n",
"response = await autogen_aoai_client.create(\n",
" messages=[\n",
" UserMessage(\n",
" content=GROUNDED_PROMPT.format(query=query, context=sources_formatted),\n",
" source=\"user\",\n",
" ),\n",
" ],\n",
" extra_create_args={\"response_format\": RecommendationList},\n",
")\n",
"\n",
"response_content = json.loads(response.content)\n",
"for recommendation in response_content[\"recommendation\"]:\n",
" print(recommendation)"
]
},
{
"cell_type": "markdown",
"id": "b167e23e",
"metadata": {},
"source": [
"### Groundedness Evaluator\n",
"\n",
"Construct a `groundedness_grader` node that checks the generated answer for **hallucination** against the retrieved documents.<br>\n",
"\n",
"`GroundednessEvaluator` returns a score from 1 to 5: a high score means the answer is supported by the retrieved documents, and a low score means it contains claims the documents do not support. The pipeline below treats 3 or higher as grounded."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "63d55c28",
"metadata": {},
"outputs": [],
"source": [
"groundedness_eval = GroundednessEvaluator(model_config)\n",
"\n",
"query_response = dict(query=query, context=sources_formatted, response=json.dumps(response_content))\n",
"\n",
"groundedness_score = groundedness_eval(**query_response)\n",
"print(groundedness_score)"
]
},
{
"cell_type": "markdown",
"id": "b8a6c80a",
"metadata": {},
"source": [
"### Relevance Evaluator\n",
"\n",
"Construct a `relevance_grader` node that evaluates how relevant the generated answer is to the question.<br>\n",
"`RelevanceEvaluator` returns a score from 1 to 5: a high score means the answer addresses the question, and a low score means it does not. The pipeline below treats 3 or higher as relevant."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c5b80ab5",
"metadata": {},
"outputs": [],
"source": [
"relevance_eval = RelevanceEvaluator(model_config)\n",
"\n",
"query_response = dict(query=query, response=json.dumps(response_content))\n",
"\n",
"relevance_score = relevance_eval(**query_response)\n",
"print(relevance_score)"
]
},
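{
"cell_type": "markdown",
"id": "1dafe623",
"metadata": {},
"source": [
"In the Step 2 pipeline, these two scores decide whether the user sees the generated answer or a fallback message. A minimal sketch of that gate, assuming the same 1-to-5 scores and 3.0 threshold (the `gate_answer` name is hypothetical), returns as soon as one check fails so that exactly one message is produced:\n",
"\n",
"```python\n",
"def gate_answer(answer: str, groundedness: float, relevance: float,\n",
"                fallback: str, threshold: float = 3.0) -> str:\n",
"    \"\"\"Return `answer` only if it is both grounded and relevant, else `fallback`.\"\"\"\n",
"    if groundedness < threshold:\n",
"        return fallback  # not supported by the retrieved context (likely hallucinated)\n",
"    if relevance < threshold:\n",
"        return fallback  # grounded, but does not address the question\n",
"    return answer\n",
"\n",
"\n",
"print(gate_answer('Hotel A', groundedness=4, relevance=5, fallback='Sorry'))  # Hotel A\n",
"print(gate_answer('Hotel A', groundedness=2, relevance=5, fallback='Sorry'))  # Sorry\n",
"```"
]
},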
{
"cell_type": "markdown",
"id": "fd05275c",
"metadata": {},
"source": [
"### Keyword Re-writer\n",
"\n",
"Construct a `keyword_rewriter` agent that rewrites the question as web search keywords."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "806f6d04",
"metadata": {},
"outputs": [],
"source": [
"# This prompt provides instructions to the model\n",
"KEYWORD_REWRITE_PROMPT = \"\"\"\n",
"You are a keyword re-writer that converts an input question to a better version that is optimized for search.\n",
"Generate search keywords from the user query that are more specific, detailed, and likely to retrieve relevant information,\n",
"allowing for a more accurate response through web search.\n",
"Don't include the additional context from the user question.\n",
"\n",
"Query: {query}\n",
"Revised web search query:\n",
"\"\"\"\n",
"\n",
"# Ask the LLM to rewrite the query as web search keywords.\n",
"response = await autogen_aoai_client.create(\n",
" messages=[\n",
" UserMessage(content=KEYWORD_REWRITE_PROMPT.format(query=query), source=\"user\"),\n",
" ]\n",
")\n",
"\n",
"\n",
"# Here is the response from the chat model.\n",
"print(response.content)\n",
"query = response.content"
]
},
{
"cell_type": "markdown",
"id": "00a7d8b2",
"metadata": {},
"source": [
"### Web Search Tool\n",
"\n",
"The web search tool (Bing) supplies additional context from the web. <br>\n",
"\n",
"In this pipeline, it is used when the intent router classifies the query as `websearch`, that is, when the question falls outside the hotel documents in the index."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e5444bcf",
"metadata": {},
"outputs": [],
"source": [
"from azure_genai_utils.tools import BingSearch\n",
"\n",
"WEB_SEARCH_FORMAT_OUTPUT = False\n",
"\n",
"web_search_tool = BingSearch(\n",
" max_results=3,\n",
" locale=\"en-US\",\n",
" include_news=False,\n",
" include_entity=False,\n",
" format_output=WEB_SEARCH_FORMAT_OUTPUT,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "595b71f8",
"metadata": {},
"outputs": [],
"source": [
"query = \"Newest Openings Hotels in NYC 2024 2025?\"\n",
"results = web_search_tool.invoke({\"query\": query})\n",
"print(results[0].get(\"content\", \"No content\"))"
]
},
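{
"cell_type": "markdown",
"id": "2eb0f734",
"metadata": {},
"source": [
"Before web results can be passed to the generator as context, they have to be flattened into one string. A minimal sketch, assuming (as the cell above does) that each result is a dict with a `content` key; `format_web_results` and its `max_chars` cap are hypothetical:\n",
"\n",
"```python\n",
"from typing import Any, Iterable\n",
"\n",
"\n",
"def format_web_results(results: Iterable[Any], max_chars: int = 4000) -> str:\n",
"    \"\"\"Join the non-empty `content` fields of the results, capped at max_chars.\"\"\"\n",
"    snippets = [r.get('content', '') for r in results if isinstance(r, dict)]\n",
"    context = '\\n\\n'.join(s for s in snippets if s)\n",
"    return context[:max_chars]  # crude character cap to keep the prompt small\n",
"\n",
"\n",
"print(format_web_results([{'content': 'a'}, {'title': 'x'}, {'content': 'b'}]))\n",
"```"
]
},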
{
"cell_type": "markdown",
"id": "f08d720d",
"metadata": {},
"source": [
"<br>\n",
"\n",
"## 🧪 Step 2. Define the Agentic Architecture\n",
"- Before building the agentic pipeline, we need to design the messages, topics, agents, and message routing logic.\n",
"- You should define the termination condition for the pipeline.\n",
"\n",
"### Message, Topic, Agent Definition\n",
"\n",
"```python\n",
"# Message Definition\n",
"@dataclass\n",
"class Message:\n",
" intent: str = None\n",
" query: str = None\n",
" context: str = None\n",
" response: str = None\n",
" source: str = None\n",
"\n",
"\n",
"# Topic Definition\n",
"user_query_topic_type = \"UserQuery\"\n",
"rag_grader_topic_type = \"RagGraderAgent\"\n",
"query_rewrite_topic_type = \"QueryRewriteAgent\"\n",
"generate_topic_type = \"GenerateAgent\"\n",
"eval_topic_type = \"EvalAgent\"\n",
"keyword_rewrite_topic_type = \"KeywordRewriteAgent\"\n",
"web_search_topic_type = \"WebSearchAgent\"\n",
"user_topic_type = \"UserAgent\"\n",
"\n",
"\n",
"# Agent Definition (bodies omitted in this overview)\n",
"class IntentRouterAgent(RoutedAgent): ...\n",
"class RAGGraderAgent(RoutedAgent): ...\n",
"class QueryRewriteAgent(RoutedAgent): ...\n",
"class GenerateAgent(RoutedAgent): ...\n",
"class EvalAgent(RoutedAgent): ...\n",
"class KeywordRewriteAgent(RoutedAgent): ...\n",
"class WebSearchAgent(RoutedAgent): ...\n",
"class UserAgent(RoutedAgent): ...\n",
"```"
]
},
{
"cell_type": "markdown",
"id": "e378d9fd",
"metadata": {},
"source": [
"Visualizing the abstract architecture of the pipeline helps you understand the message flow and each agent's role in the pipeline."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "52aa07a7",
"metadata": {},
"outputs": [],
"source": [
"from azure_genai_utils.graphs import visualize_agents\n",
"\n",
"agents = [\n",
" \"Start\",\n",
" \"IntentRouterAgent\",\n",
" \"RAGGraderAgent\",\n",
" \"QueryRewriteAgent\",\n",
" \"GenerateAgent\",\n",
" \"EvalAgent\",\n",
" \"KeywordRewriteAgent\",\n",
" \"WebSearchAgent\",\n",
" \"UserAgent\",\n",
"]\n",
"interactions = [\n",
" (\"Start\", \"IntentRouterAgent\"),\n",
" (\"IntentRouterAgent\", \"GenerateAgent\", \"Generates Response\"),\n",
" (\"IntentRouterAgent\", \"RAGGraderAgent\", \"Retrieval Context\"),\n",
" (\"IntentRouterAgent\", \"KeywordRewriteAgent\", \"Rewrites as keyword for bing search\"),\n",
" (\"GenerateAgent\", \"EvalAgent\"),\n",
" (\"EvalAgent\", \"UserAgent\"),\n",
" (\"RAGGraderAgent\", \"GenerateAgent\", \"Generates Response\"),\n",
" (\"RAGGraderAgent\", \"QueryRewriteAgent\", \"Rewrites Query\"),\n",
" (\"QueryRewriteAgent\", \"GenerateAgent\"),\n",
" (\"KeywordRewriteAgent\", \"WebSearchAgent\"),\n",
" (\"WebSearchAgent\", \"GenerateAgent\"),\n",
" # (\"EvalAgent\", \"IntentRouterAgent\"),\n",
"]\n",
"\n",
"visualize_agents(agents, interactions)"
]
},
{
"cell_type": "markdown",
"id": "2c98b3eb",
"metadata": {},
"source": [
"This is an example of the visualized pipeline:\n",
"\n",
""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "87d1f724",
"metadata": {},
"outputs": [],
"source": [
"from typing import Optional\n",
"\n",
"\n",
"@dataclass\n",
"class Message:\n",
" intent: Optional[str] = None\n",
" query: Optional[str] = None\n",
" context: Optional[str] = None\n",
" response: Optional[str] = None\n",
" source: Optional[str] = None\n",
"\n",
" def set_source(self, source: str) -> \"Message\":\n",
" self.source = source\n",
" return self\n",
"\n",
"\n",
"# Topic Definition\n",
"user_query_topic_type = \"UserQuery\"\n",
"rag_grader_topic_type = \"RagGraderAgent\"\n",
"query_rewrite_topic_type = \"QueryRewriteAgent\"\n",
"generate_topic_type = \"GenerateAgent\"\n",
"eval_topic_type = \"EvalAgent\"\n",
"keyword_rewrite_topic_type = \"KeywordRewriteAgent\"\n",
"web_search_topic_type = \"WebSearchAgent\"\n",
"user_topic_type = \"UserAgent\"\n",
"\n",
"query = \"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6ee78e07",
"metadata": {},
"outputs": [],
"source": [
"# This prompt provides instructions to the model\n",
"INTENT_ROUTER_PROMPT = \"\"\"\n",
" The LLM covers casual topics such as greetings, small talk, and basic general knowledge.\n",
" Use the LLM for general IT questions such as: What is the history of Microsoft? What is deep learning? How can I learn about Gen AI?\n",
" The RAG contains documents related to hotel information in New York until Aug 2024.\n",
" Use the RAG for questions on hotel-related topics. For all else, use websearch.\n",
" Respond with an intent_type of LLM, RAG, or websearch.\n",
"Query: {query}\n",
"\"\"\"\n",
"\n",
"\n",
"@type_subscription(topic_type=user_query_topic_type)\n",
"class IntentRouterAgent(RoutedAgent):\n",
" def __init__(self, model_client: ChatCompletionClient) -> None:\n",
" super().__init__(\"Intent Router Agent\")\n",
" self._system_message = SystemMessage(\n",
" content=(\n",
" \"\"\"\n",
" You are an expert at routing a user question to LLM or RAG or websearch.\n",
" \"\"\"\n",
" )\n",
" )\n",
" self._model_client = model_client\n",
"\n",
" @message_handler\n",
" async def handle_message(self, message: Message, ctx: MessageContext) -> None:\n",
" print(f\"\\n{'-'*80}\\n{self.id.type} received a message:\\n\")\n",
"\n",
" llm_result = await self._model_client.create(\n",
" messages=[\n",
" self._system_message,\n",
" UserMessage(\n",
" content=INTENT_ROUTER_PROMPT.format(query=message.query),\n",
" source=message.source,\n",
" ),\n",
" ],\n",
" extra_create_args={\"response_format\": IntentResponse},\n",
" cancellation_token=ctx.cancellation_token,\n",
" )\n",
" response_content = json.loads(llm_result.content)\n",
" print(response_content)\n",
"\n",
" if response_content[\"intent_type\"] == \"LLM\":\n",
" await self.publish_message(\n",
" Message(\n",
" intent=response_content[\"intent_type\"],\n",
" query=message.query,\n",
" context=\"no context\",\n",
" source=message.source,\n",
" ),\n",
" topic_id=TopicId(type=generate_topic_type, source=message.source),\n",
" )\n",
" elif response_content[\"intent_type\"] == \"RAG\":\n",
" await self.publish_message(\n",
" Message(\n",
" intent=response_content[\"intent_type\"],\n",
" query=message.query,\n",
" source=message.source,\n",
" ),\n",
" topic_id=TopicId(type=rag_grader_topic_type, source=message.source),\n",
" )\n",
" elif response_content[\"intent_type\"] == \"websearch\":\n",
" await self.publish_message(\n",
" Message(\n",
" intent=response_content[\"intent_type\"],\n",
" query=message.query,\n",
" source=message.source,\n",
" ),\n",
" topic_id=TopicId(\n",
" type=keyword_rewrite_topic_type, source=message.source\n",
" ),\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "883dde68",
"metadata": {},
"outputs": [],
"source": [
"@type_subscription(topic_type=rag_grader_topic_type)\n",
"class RAGGraderAgent(RoutedAgent):\n",
"\n",
" def __init__(\n",
" self,\n",
" azure_ai_search_endpoint: str,\n",
" azure_search_admin_key: str,\n",
" index_name: str,\n",
" retrieval_evaluator: RetrievalEvaluator,\n",
" ) -> None:\n",
"\n",
" super().__init__(\"RAG Grader Agent\")\n",
" self.index_name = index_name\n",
" self.azure_ai_search_endpoint = azure_ai_search_endpoint\n",
" self.azure_search_admin_key = azure_search_admin_key\n",
" self.retrieval_evaluator = retrieval_evaluator\n",
"\n",
" def config_search(self) -> SearchClient:\n",
" service_endpoint = self.azure_ai_search_endpoint\n",
" key = self.azure_search_admin_key\n",
" index_name = self.index_name\n",
" credential = AzureKeyCredential(key)\n",
" return SearchClient(\n",
" endpoint=service_endpoint, index_name=index_name, credential=credential\n",
" )\n",
"\n",
" async def do_search(self, query: str) -> str:\n",
" \"\"\"Search the index in Azure AI Search with a hybrid (text + vector) query.\"\"\"\n",
" aia_search_client = self.config_search()\n",
"\n",
" fields = \"descriptionVector\" # TODO: Check if this is the correct field name\n",
" # don't use exhaustive search for large indexes\n",
" vector_query = VectorizableTextQuery(\n",
" text=query, k_nearest_neighbors=1, fields=fields, exhaustive=True\n",
" )\n",
"\n",
" search_results = aia_search_client.search(\n",
" search_text=query,\n",
" vector_queries=[vector_query],\n",
" select=[\"Description\", \"HotelName\", \"Tags\"], # TODO: Check if these are the correct field names\n",
" top=5, # TODO: Check if this is the correct number of results\n",
" )\n",
" answer = \"\\n\".join(\n",
" [\n",
" f'{document[\"HotelName\"]}:{document[\"Description\"]}:{document[\"Tags\"]}'\n",
" for document in search_results\n",
" ]\n",
" )\n",
" return answer\n",
"\n",
" @message_handler\n",
" async def handle_message(self, message: Message, ctx: MessageContext) -> None:\n",
" print(f\"\\n{'-'*80}\\n{self.id.type} received a message:\\n\")\n",
"\n",
" context_from_ai_search = await self.do_search(message.query)\n",
" print(context_from_ai_search)\n",
"\n",
" query_response = dict(query=message.query, context=context_from_ai_search)\n",
"\n",
" retrieval_score = self.retrieval_evaluator(**query_response)\n",
"\n",
" print(f\"retrieval_score: {retrieval_score['retrieval']}\")\n",
"\n",
" if retrieval_score[\"retrieval\"] >= 3.0:\n",
" await self.publish_message(\n",
" Message(\n",
" intent=message.intent,\n",
" query=message.query,\n",
" context=context_from_ai_search,\n",
" source=message.source,\n",
" ),\n",
" topic_id=TopicId(type=generate_topic_type, source=message.source),\n",
" )\n",
" elif retrieval_score[\"retrieval\"] < 3.0:\n",
" await self.publish_message(\n",
" Message(\n",
" intent=message.intent,\n",
" query=message.query,\n",
" context=context_from_ai_search,\n",
" source=message.source,\n",
" ),\n",
" topic_id=TopicId(type=query_rewrite_topic_type, source=message.source),\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1e1edfc9",
"metadata": {},
"outputs": [],
"source": [
"REWRITE_PROMPT = \"\"\"\n",
"You are a question re-writer that converts an input question to a better version that is optimized\n",
"for vectorstore retrieval. Look at the input, reason about its underlying semantic intent / meaning, and rewrite it for a hotel information service.\n",
"Stay within the scope of the query and don't add any additional conditions.\n",
"Query: {query}\n",
"\"\"\"\n",
"\n",
"\n",
"@type_subscription(topic_type=query_rewrite_topic_type)\n",
"class QueryRewriteAgent(RoutedAgent):\n",
" def __init__(self, model_client: ChatCompletionClient) -> None:\n",
" super().__init__(\"Query Rewrite Agent\")\n",
" self._system_message = SystemMessage(\n",
" content=(\n",
" \"\"\"\n",
" You are a helper agent that can rewrite the query.\n",
" \"\"\"\n",
" )\n",
" )\n",
" self._model_client = model_client\n",
"\n",
" @message_handler\n",
" async def handle_message(self, message: Message, ctx: MessageContext) -> None:\n",
" print(f\"\\n{'-'*80}\\n{self.id.type} received a message:\\n\")\n",
" print(f\"Received query:\\n{message.query}\")\n",
" llm_result = await self._model_client.create(\n",
" messages=[\n",
" self._system_message,\n",
" UserMessage(\n",
" content=REWRITE_PROMPT.format(query=message.query),\n",
" source=message.source,\n",
" ),\n",
" ],\n",
" cancellation_token=ctx.cancellation_token,\n",
" )\n",
" response = llm_result.content\n",
" print(f\"Rewrite query:\\n{response}\")\n",
" await self.publish_message(\n",
" Message(\n",
" intent=message.intent,\n",
" query=response,\n",
" context=message.context,\n",
" source=message.source,\n",
" ),\n",
" topic_id=TopicId(type=generate_topic_type, source=message.source),\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "93e9a219",
"metadata": {},
"outputs": [],
"source": [
"# This prompt provides instructions to the model\n",
"GROUNDED_PROMPT = \"\"\"\n",
"Read the context carefully and answer the query in a friendly, concise, bulleted format.\n",
"\n",
"Follow these rules based on the Intent:\n",
"\n",
"LLM:\n",
"Provide the answer in a casual, human-like chatbot style.\n",
"Conclude with the statement: “I am saying this from my basic knowledge.”\n",
"websearch or RAG:\n",
"Use only the facts explicitly stated in the context.\n",
"Conclude with the statement: “I am saying this based on the {intent}”\n",
"If the context does not contain enough information to answer, respond with:\n",
"\n",
"“I don’t know.”\n",
"Do not include any information outside the provided context.\n",
"\n",
"Intent: {intent}\n",
"Query: {query}\n",
"Context:\\n{context}\n",
"\"\"\"\n",
"\n",
"\n",
"@type_subscription(topic_type=generate_topic_type)\n",
"class GenerateAgent(RoutedAgent):\n",
" def __init__(self, model_client: ChatCompletionClient) -> None:\n",
" super().__init__(\"Generate Agent\")\n",
" self._system_message = SystemMessage(\n",
" content=(\n",
" \"\"\"\n",
" You are a friendly assistant that recommends hotels based on activities and amenities.\n",
" \"\"\"\n",
" )\n",
" )\n",
" self._model_client = model_client\n",
"\n",
" @message_handler\n",
" async def handle_message(self, message: Message, ctx: MessageContext) -> None:\n",
" print(f\"\\n{'-'*80}\\n{self.id.type} received a message:\\n\")\n",
" llm_result = await self._model_client.create(\n",
" messages=[\n",
" self._system_message,\n",
" UserMessage(\n",
" content=GROUNDED_PROMPT.format(\n",
" intent=message.intent,\n",
" query=message.query,\n",
" context=message.context,\n",
" ),\n",
" source=message.source,\n",
" ),\n",
" ],\n",
" extra_create_args=(\n",
" {\"response_format\": RecommendationList}\n",
" if message.intent == \"RAG\"\n",
" else {}\n",
" ),\n",
" cancellation_token=ctx.cancellation_token,\n",
" )\n",
"\n",
" response_content = llm_result.content\n",
"\n",
" print(f\"Generated response:\\n{response_content}\")\n",
" if message.intent == \"LLM\":\n",
" await self.publish_message(\n",
" AssistantMessage(content=response_content, source=message.source),\n",
" topic_id=TopicId(type=user_topic_type, source=message.source),\n",
" )\n",
" else:\n",
" await self.publish_message(\n",
" Message(\n",
" intent=message.intent,\n",
" query=message.query,\n",
" context=message.context,\n",
" response=response_content,\n",
" source=message.source,\n",
" ),\n",
" topic_id=TopicId(type=eval_topic_type, source=message.source),\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "48289751",
"metadata": {},
"outputs": [],
"source": [
"INCORRECT_ANSWER = \"\"\"\n",
"Hello, and thank you for bringing this to our attention! I may have provided an inaccurate or misleading response, and I sincerely apologize for the confusion.\n",
"As an AI, I aim to deliver helpful and accurate information, but sometimes I might misinterpret or generate an incorrect response. Your feedback is invaluable and helps me improve.\n",
"\n",
"If you'd like, feel free to share more details or clarify your question, and I’ll do my best to assist you further. Thank you for your understanding and patience!\n",
"\"\"\"\n",
"\n",
"\n",
"@type_subscription(topic_type=eval_topic_type)\n",
"class EvalAgent(RoutedAgent):\n",
"\n",
" def __init__(\n",
" self,\n",
" groundedness_evaluator: GroundednessEvaluator,\n",
" relevance_evaluator: RelevanceEvaluator,\n",
" ) -> None:\n",
"\n",
" super().__init__(\"Eval Agent\")\n",
" self.groundedness_evaluator = groundedness_evaluator\n",
" self.relevance_evaluator = relevance_evaluator\n",
"\n",
" @message_handler\n",
" async def handle_message(self, message: Message, ctx: MessageContext) -> None:\n",
" print(f\"\\n{'-'*80}\\n{self.id.type} received a message:\\n\")\n",
"        # Stage 1: is the response supported by the retrieved context?\n",
"        groundedness_score = self.groundedness_evaluator(\n",
"            query=message.query, context=message.context, response=message.response\n",
"        )\n",
"        print(f\"groundedness_score: {groundedness_score['groundedness']}\")\n",
"        if groundedness_score[\"groundedness\"] < 3.0:\n",
"            await self.publish_message(\n",
"                AssistantMessage(content=INCORRECT_ANSWER, source=message.source),\n",
"                topic_id=TopicId(type=user_topic_type, source=message.source),\n",
"            )\n",
"            return  # Stop here so the user receives only one message.\n",
"\n",
"        # Stage 2: does the response address the query?\n",
"        relevance_score = self.relevance_evaluator(\n",
"            query=message.query, response=message.response\n",
"        )\n",
"        print(f\"relevance_score: {relevance_score['relevance']}\")\n",
"        # Relevant answers go to the user; otherwise send the fallback message.\n",
"        content = (\n",
"            message.response\n",
"            if relevance_score[\"relevance\"] >= 3.0\n",
"            else INCORRECT_ANSWER\n",
"        )\n",
"        await self.publish_message(\n",
"            AssistantMessage(content=content, source=message.source),\n",
"            topic_id=TopicId(type=user_topic_type, source=message.source),\n",
"        )"
]
},
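{
"cell_type": "markdown",
"id": "7e4b1c92",
"metadata": {},
"source": [
"`EvalAgent` acts as a two-stage gate: a response must pass the groundedness check and then the relevance check before it reaches the user. The sketch below isolates that decision as a plain function so it can be tested without Azure credentials. It assumes the same `3.0` threshold used in `EvalAgent`; the name `route_after_eval` and the labels `\"answer\"` / `\"fallback\"` are hypothetical.\n",
"\n",
"```python\n",
"def route_after_eval(groundedness: float, relevance: float, threshold: float = 3.0) -> str:\n",
"    \"\"\"Decide the single outcome for one query (hypothetical labels).\"\"\"\n",
"    # Stage 1: an ungrounded answer is rejected without checking relevance.\n",
"    if groundedness < threshold:\n",
"        return \"fallback\"\n",
"    # Stage 2: a grounded answer must also be relevant to the query.\n",
"    if relevance < threshold:\n",
"        return \"fallback\"\n",
"    return \"answer\"\n",
"\n",
"\n",
"print(route_after_eval(4.0, 5.0))  # answer\n",
"print(route_after_eval(2.0, 5.0))  # fallback\n",
"```\n",
"\n",
"Because the sketch returns as soon as groundedness fails, each query maps to exactly one outcome, and the relevance check runs only on grounded answers."
]
},
```python
def route_after_eval(groundedness: float, relevance: float, threshold: float = 3.0) -> str:
    """Decide the single outcome for one query (hypothetical labels)."""
    # Stage 1: an ungrounded answer is rejected without checking relevance.
    if groundedness < threshold:
        return "fallback"
    # Stage 2: a grounded answer must also be relevant to the query.
    if relevance < threshold:
        return "fallback"
    return "answer"


print(route_after_eval(4.0, 5.0))  # answer
print(route_after_eval(2.0, 5.0))  # fallback
```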
{
"cell_type": "code",
"execution_count": null,
"id": "1078e59c",
"metadata": {},
"outputs": [],
"source": [
"KEYWORD_REWRITE_PROMPT = \"\"\"\n",
"You are a keyword rewriter that converts an input question into a version optimized for web search.\n",
"Generate search keywords from the user query that are more specific and detailed,\n",
"so they are more likely to retrieve relevant information and support a more accurate answer.\n",
"Return only the revised query; do not add context beyond what the user question contains.\n",
"\n",
"Query: {query}\n",
"Revised web search query:\n",
"\"\"\"\n",
"\n",
"\n",
"@type_subscription(topic_type=keyword_rewrite_topic_type)\n",
"class KeywordRewriteAgent(RoutedAgent):\n",
" def __init__(self, model_client: ChatCompletionClient) -> None:\n",
"        super().__init__(\"Keyword Rewrite Agent\")\n",
"        self._system_message = SystemMessage(\n",
"            content=\"You are a helper agent that rewrites user queries into web search keywords.\"\n",
"        )\n",
" self._model_client = model_client\n",
"\n",
" @message_handler\n",
" async def handle_message(self, message: Message, ctx: MessageContext) -> None:\n",
" print(f\"\\n{'-'*80}\\n{self.id.type} received a message:\\n\")\n",
" llm_result = await self._model_client.create(\n",
" messages=[\n",
" self._system_message,\n",
" UserMessage(\n",
" content=KEYWORD_REWRITE_PROMPT.format(query=message.query),\n",
" source=message.source,\n",
" ),\n",
" ],\n",
" cancellation_token=ctx.cancellation_token,\n",
" )\n",
" response = llm_result.content\n",
"\n",
" await self.publish_message(\n",
" Message(\n",
" intent=message.intent,\n",
" query=response,\n",
" context=message.context,\n",
" source=message.source,\n",
" ),\n",
" topic_id=TopicId(type=web_search_topic_type, source=message.source),\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f38deeb1",
"metadata": {},
"outputs": [],
"source": [
"@type_subscription(topic_type=web_search_topic_type)\n",
"class WebSearchAgent(RoutedAgent):\n",
"\n",
" def __init__(\n",
" self,\n",
" web_search_tool: BingSearch,\n",
" retrieval_evaluator: RetrievalEvaluator,\n",
" ) -> None:\n",
"\n",
" super().__init__(\"WebSearch Agent\")\n",
" self.web_search_tool = web_search_tool\n",
" self.retrieval_evaluator = retrieval_evaluator\n",
"\n",
" @message_handler\n",
" async def handle_message(self, message: Message, ctx: MessageContext) -> None:\n",
" print(f\"\\n{'-'*80}\\n{self.id.type} received a message:\\n\")\n",
"        # Use the tool injected through the constructor, not a global variable.\n",
"        search_results = self.web_search_tool.invoke({\"query\": message.query})\n",
"        print(search_results)\n",
"        try:\n",
"            # Concatenate the content of at most the top 5 results.\n",
"            contents = [\n",
"                doc.get(\"content\", \"No content\") for doc in list(search_results)[:5]\n",
"            ]\n",
"            content = \"\\n\".join(contents)\n",
"        except Exception as e:\n",
"            print(f\"Error: {e}\")\n",
"            content = \"No content\"\n",
"\n",
" print(\n",
" \"================================ search results ================================\"\n",
" )\n",
" print(content)\n",
"\n",
" search_response = dict(query=message.query, context=content)\n",
"\n",
" retrieval_score = self.retrieval_evaluator(**search_response)\n",
" print(f\"retrieval_score: {retrieval_score['retrieval']}\")\n",
"\n",
" if retrieval_score[\"retrieval\"] >= 3.0:\n",
" await self.publish_message(\n",
" Message(\n",
" intent=message.intent,\n",
" query=message.query,\n",
" context=content,\n",
" source=message.source,\n",
" ),\n",
" topic_id=TopicId(type=generate_topic_type, source=message.source),\n",
" )\n",
"        else:\n",
"            # Results below the threshold are not used; send the fallback message.\n",
"            await self.publish_message(\n",
"                AssistantMessage(content=INCORRECT_ANSWER, source=message.source),\n",
"                topic_id=TopicId(type=user_topic_type, source=message.source),\n",
"            )"
]
},
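{
"cell_type": "markdown",
"id": "3c8d5f17",
"metadata": {},
"source": [
"The web search branch does two things: it flattens up to five results into one context string, and it uses the retrieval score to decide whether that context is good enough to generate from. The sketch below reproduces both steps without Bing or Azure, assuming each result is a dict with an optional `\"content\"` key, as `WebSearchAgent` assumes. `join_top_contents` and `route_after_retrieval` are hypothetical names, and the `3.0` threshold mirrors the one in the agent.\n",
"\n",
"```python\n",
"from typing import Any, Iterable\n",
"\n",
"\n",
"def join_top_contents(results: Iterable[Any], k: int = 5) -> str:\n",
"    \"\"\"Join the content field of up to k search results (hypothetical helper).\"\"\"\n",
"    contents = []\n",
"    for doc in list(results)[:k]:\n",
"        # Non-dict results are treated as having no content.\n",
"        if isinstance(doc, dict):\n",
"            contents.append(doc.get(\"content\", \"No content\"))\n",
"        else:\n",
"            contents.append(\"No content\")\n",
"    return \"\\n\".join(contents)\n",
"\n",
"\n",
"def route_after_retrieval(score: float, threshold: float = 3.0) -> str:\n",
"    \"\"\"Send good context to generation; every lower score gets the fallback.\"\"\"\n",
"    return \"generate\" if score >= threshold else \"fallback\"\n",
"\n",
"\n",
"results = [{\"content\": \"Hotel A offers free Wi-Fi.\"}, {\"title\": \"no content key\"}]\n",
"print(join_top_contents(results))\n",
"print(route_after_retrieval(2.5))  # fallback\n",
"```\n",
"\n",
"In this sketch every score below the threshold maps to the fallback, so no query is left without a reply."
]
},
```python
from typing import Any, Iterable


def join_top_contents(results: Iterable[Any], k: int = 5) -> str:
    """Join the content field of up to k search results (hypothetical helper)."""
    contents = []
    for doc in list(results)[:k]:
        # Non-dict results are treated as having no content.
        if isinstance(doc, dict):
            contents.append(doc.get("content", "No content"))
        else:
            contents.append("No content")
    return "\n".join(contents)


def route_after_retrieval(score: float, threshold: float = 3.0) -> str:
    """Send good context to generation; every lower score gets the fallback."""
    return "generate" if score >= threshold else "fallback"


results = [{"content": "Hotel A offers free Wi-Fi."}, {"title": "no content key"}]
print(join_top_contents(results))
print(route_after_retrieval(2.5))  # fallback
```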
{
"cell_type": "code",
"execution_count": null,
"id": "d39cebcb",
"metadata": {},
"outputs": [],
"source": [
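"def is_valid_json(data: str) -> bool:\n",
"    \"\"\"Return True if `data` is a string containing a JSON object or array.\"\"\"\n",
"    if not isinstance(data, str):\n",
"        return False\n",
"    data = data.strip()\n",
"    if not (data.startswith(\"{\") or data.startswith(\"[\")):\n",
"        return False\n",
"    try:\n",
"        json.loads(data)\n",
"    except json.JSONDecodeError:\n",
"        return False\n",
"    return True\n",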
"\n",
"\n",
"@type_subscription(topic_type=user_topic_type)\n",
"class UserAgent(RoutedAgent):\n",
" def __init__(self) -> None:\n",
" super().__init__(\"A user agent that outputs the final copy to the user.\")\n",
"\n",
" @message_handler\n",
" async def handle_final_copy(\n",
" self, message: AssistantMessage, ctx: MessageContext\n",
" ) -> None:\n",
" print(f\"\\n{'-'*80}\\n{self.id.type} received a message:\\n\")\n",
"\n",
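"        if is_valid_json(message.content):\n",
"            response_content = json.loads(message.content)\n",
"            # Print each recommendation when the expected key is present.\n",
"            if isinstance(response_content, dict) and \"recommendation\" in response_content:\n",
"                for recommendation in response_content[\"recommendation\"]:\n",
"                    print(recommendation)\n",
"            else:\n",
"                print(message.content)\n",
"        else:\n",
"            print(message.content)"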
]
},
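{
"cell_type": "markdown",
"id": "b51a9e04",
"metadata": {},
"source": [
"`UserAgent` expects a structured answer, when present, to be a JSON object with a `\"recommendation\"` list (the key read in `handle_final_copy`); anything else is printed as-is. A minimal sketch of that parse-or-fallback logic as a pure function, using `json.loads` itself as the validity check; the name `final_answer_lines` is hypothetical.\n",
"\n",
"```python\n",
"import json\n",
"\n",
"\n",
"def final_answer_lines(content: str) -> list[str]:\n",
"    \"\"\"Split a final answer into printable lines (hypothetical helper).\"\"\"\n",
"    try:\n",
"        parsed = json.loads(content)\n",
"    except (json.JSONDecodeError, TypeError):\n",
"        # Not JSON: show the text unchanged.\n",
"        return [content]\n",
"    if isinstance(parsed, dict) and isinstance(parsed.get(\"recommendation\"), list):\n",
"        return [str(item) for item in parsed[\"recommendation\"]]\n",
"    # Valid JSON without the expected key: show the raw text.\n",
"    return [content]\n",
"\n",
"\n",
"print(final_answer_lines('{\"recommendation\": [\"Hotel A\", \"Hotel B\"]}'))  # ['Hotel A', 'Hotel B']\n",
"print(final_answer_lines(\"Plain text answer\"))  # ['Plain text answer']\n",
"```"
]
},
```python
import json


def final_answer_lines(content: str) -> list[str]:
    """Split a final answer into printable lines (hypothetical helper)."""
    try:
        parsed = json.loads(content)
    except (json.JSONDecodeError, TypeError):
        # Not JSON: show the text unchanged.
        return [content]
    if isinstance(parsed, dict) and isinstance(parsed.get("recommendation"), list):
        return [str(item) for item in parsed["recommendation"]]
    # Valid JSON without the expected key: show the raw text.
    return [content]


print(final_answer_lines('{"recommendation": ["Hotel A", "Hotel B"]}'))  # ['Hotel A', 'Hotel B']
print(final_answer_lines("Plain text answer"))  # ['Plain text answer']
```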
{
"cell_type": "markdown",
"id": "2a8fce6e",
"metadata": {},
"source": [
"<br>\n",
"\n",
"## 🧪 Step 3. Execute the Workflow\n",
"---\n",
"\n",
"### Execute the workflow"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8c329af2",
"metadata": {},
"outputs": [],
"source": [
"runtime = SingleThreadedAgentRuntime()\n",
"\n",
"await IntentRouterAgent.register(\n",
" runtime,\n",
" type=user_query_topic_type,\n",
" factory=lambda: IntentRouterAgent(model_client=autogen_aoai_client),\n",
")\n",
"\n",
"await RAGGraderAgent.register(\n",
" runtime,\n",
" type=rag_grader_topic_type,\n",
" factory=lambda: RAGGraderAgent(\n",
" azure_ai_search_endpoint=azure_ai_search_endpoint,\n",
" azure_search_admin_key=azure_search_admin_key,\n",
" index_name=index_name,\n",
" retrieval_evaluator=RetrievalEvaluator(model_config),\n",
" ),\n",
")\n",
"\n",
"await QueryRewriteAgent.register(\n",
" runtime,\n",
" type=query_rewrite_topic_type,\n",
" factory=lambda: QueryRewriteAgent(model_client=autogen_aoai_client),\n",
")\n",
"\n",
"await GenerateAgent.register(\n",
" runtime,\n",
" type=generate_topic_type,\n",
" factory=lambda: GenerateAgent(model_client=autogen_aoai_client),\n",
")\n",
"\n",
"await EvalAgent.register(\n",
" runtime,\n",
" type=eval_topic_type,\n",
" factory=lambda: EvalAgent(\n",
" groundedness_evaluator=GroundednessEvaluator(model_config),\n",
" relevance_evaluator=RelevanceEvaluator(model_config),\n",
" ),\n",
")\n",
"\n",
"WEB_SEARCH_FORMAT_OUTPUT = False\n",
"\n",
"await KeywordRewriteAgent.register(\n",
" runtime,\n",
" type=keyword_rewrite_topic_type,\n",
" factory=lambda: KeywordRewriteAgent(model_client=autogen_aoai_client),\n",
")\n",
"\n",
"await WebSearchAgent.register(\n",
" runtime,\n",
" type=web_search_topic_type,\n",
" factory=lambda: WebSearchAgent(\n",
" web_search_tool=BingSearch(\n",
" max_results=3,\n",
" locale=\"en-US\",\n",
" include_news=False,\n",
" include_entity=False,\n",
" format_output=WEB_SEARCH_FORMAT_OUTPUT,\n",
" ),\n",
" retrieval_evaluator=RetrievalEvaluator(model_config),\n",
" ),\n",
")\n",
"\n",
"\n",
"await UserAgent.register(runtime, type=user_topic_type, factory=lambda: UserAgent())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "32562244",
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"\n",
"start_time = time.perf_counter()\n",
"\n",
"runtime.start()\n",
"\n",
"# await runtime.publish_message(Message(query=\"what is the history of Microsoft?\", source=\"User\"), topic_id=TopicId(type=user_query_topic_type, source=\"user\"))\n",
"# await runtime.publish_message(Message(query=\"can you recommend a hoels\", source=\"User\"), topic_id=TopicId(type=user_query_topic_type, source=\"user\"))\n",
"await runtime.publish_message(\n",
" Message(query=\"can you recommend some hotels with free wifi?\", source=\"User\"),\n",
" topic_id=TopicId(type=user_query_topic_type, source=\"user\"),\n",
")\n",
"# await runtime.publish_message(Message(query=\"can you tell me openning hotels in NY 2025?\", source=\"User\"), topic_id=TopicId(type=user_query_topic_type, source=\"user\"))\n",
"\n",
"await runtime.stop_when_idle()\n",
"\n",
"\n",
"end_time = time.perf_counter()\n",
"print(f\"Elapsed time: {end_time - start_time} seconds\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "venv_agent",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}