{
"cells": [
{
"cell_type": "markdown",
"id": "c26ab996",
"metadata": {},
"source": [
"# Self-RAG\n",
"---\n",
"\n",
"### What is Self-RAG?\n",
"\n",
"Self-RAG reflects on the retrieved documents and generated responses, and includes a self-evaluation process to improve the quality of the generated answers.\n",
"\n",
"Original paper says Self-RAG generates special tokens, termed \"reflection tokens,\" to determine if retrieval would enhance the response, allowing for on-demand retrieval integration. \n",
"But in practice, we can ignore reflection tokens and let LLM decides if each document is relevant or not.\n",
"\n",
"Corrective RAG (CRAG) is similar to Self-RAG, but Self-RAG focuses on self-reflection and self-evaluation, while CRAG focuses on refining the entire retrieval process including web search.\n",
"\n",
"- Self-RAG: Trains the LLM to be self-sufficient in managing retrieval and generation processes. By generating reflection tokens, the model controls its behavior during inference, deciding when to retrieve information and how to critique and improve its own responses, leading to more accurate and contextually appropriate outputs. \n",
"- CRAG: Focuses on refining the retrieval process by evaluating and correcting the retrieved documents before they are used in generation. It integrates additional retrievals, such as web searches, when initial retrievals are insufficient, ensuring that the generation is based on the most relevant and accurate information available.\n",
"\n",
"**Reference**\n",
"\n",
"- [Self-RAG paper](https://arxiv.org/abs/2310.11511) "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b6458235",
"metadata": {},
"outputs": [],
"source": [
"from dotenv import load_dotenv\n",
"import os\n",
"import json\n",
"from azure.core.credentials import AzureKeyCredential\n",
"from azure.identity import DefaultAzureCredential\n",
"from azure.search.documents import SearchClient\n",
"from azure.search.documents.models import VectorizableTextQuery\n",
"from azure.ai.evaluation import (\n",
" GroundednessEvaluator,\n",
" RelevanceEvaluator,\n",
" RetrievalEvaluator,\n",
")\n",
"from autogen_ext.models.openai import AzureOpenAIChatCompletionClient\n",
"from autogen_core.models import (\n",
" ChatCompletionClient,\n",
" SystemMessage,\n",
" UserMessage,\n",
" AssistantMessage,\n",
")\n",
"from autogen_core import (\n",
" MessageContext,\n",
" RoutedAgent,\n",
" SingleThreadedAgentRuntime,\n",
" TopicId,\n",
" message_handler,\n",
" type_subscription,\n",
")\n",
"from pydantic import BaseModel\n",
"from typing import List\n",
"from dataclasses import dataclass\n",
"\n",
"\n",
"load_dotenv(override=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "05d7dfed",
"metadata": {},
"outputs": [],
"source": [
"# Get the environment variables\n",
"azure_ai_search_endpoint = os.getenv(\"AZURE_SEARCH_SERVICE_ENDPOINT\")\n",
"search_credential = (\n",
" AzureKeyCredential(os.getenv(\"AZURE_SEARCH_ADMIN_KEY\", \"\"))\n",
" if len(os.getenv(\"AZURE_SEARCH_ADMIN_KEY\", \"\")) > 0\n",
" else DefaultAzureCredential()\n",
")\n",
"index_name = os.getenv(\"AZURE_SEARCH_INDEX_NAME\", \"hotels-sample-index\")\n",
"\n",
"azure_openai_endpoint = os.getenv(\"AZURE_OPENAI_ENDPOINT\")\n",
"azure_openai_key = (\n",
" os.getenv(\"AZURE_OPENAI_API_KEY\", \"\")\n",
" if len(os.getenv(\"AZURE_OPENAI_API_KEY\", \"\")) > 0\n",
" else None\n",
")\n",
"azure_openai_chat_deployment_name = os.getenv(\"AZURE_OPENAI_CHAT_DEPLOYMENT_NAME\")\n",
"azure_openai_embedding_deployment_name = os.getenv(\n",
" \"AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME\", \"text-embedding-ada-002\"\n",
")\n",
"azure_openai_api_version = os.getenv(\"AZURE_OPENAI_API_VERSION\", \"2024-06-01\")\n",
"\n",
"model_config = {\n",
" \"azure_endpoint\": azure_openai_endpoint,\n",
" \"api_key\": azure_openai_key,\n",
" \"azure_deployment\": azure_openai_chat_deployment_name,\n",
" \"api_version\": azure_openai_api_version,\n",
" \"type\": \"azure_openai\",\n",
"}"
]
},
{
"cell_type": "markdown",
"id": "5ec192a9",
"metadata": {},
"source": [
"## 🧪 Step 1. Test and Construct each module\n",
"---\n",
"\n",
"Before building the entire the graph pipeline, we will test and construct each module separately.\n",
"\n",
"- **Retrieval Grader**\n",
"- **Answer Generator**\n",
"- **Groundedness Evaluator**\n",
"- **Relevance Evaluator**\n",
"- **Question Re-writer**\n",
"\n",
"### Construct Retrieval Chain based on PDF\n",
"- We use the hotels-sample-index, which can be created in minutes and runs on any search service tier. This index is created by a wizard using built-in sample data."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4cde195c",
"metadata": {},
"outputs": [],
"source": [
"azure_ai_search_endpoint = os.getenv(\"AZURE_AI_SEARCH_ENDPOINT\")\n",
"azure_search_admin_key = os.getenv(\"AZURE_SEARCH_ADMIN_KEY\", \"\")\n",
"search_client = SearchClient(\n",
" endpoint=azure_ai_search_endpoint,\n",
" index_name=index_name,\n",
" credential=AzureKeyCredential(azure_search_admin_key),\n",
" semantic_configuration_name=\"my-semantic-config\",\n",
")\n",
"\n",
"# Query is the question being asked. It's sent to the search engine and the LLM.\n",
"query = \"Can you recommend a few hotels with complimentary breakfast?\"\n",
"\n",
"fields = \"descriptionVector\" # TODO: Check if this is the correct field name\n",
"# don't use exhaustive search for large indexes\n",
"vector_query = VectorizableTextQuery(\n",
" text=query, k_nearest_neighbors=2, fields=fields, exhaustive=True\n",
")\n",
"\n",
"# Search results are created by the search client.\n",
"# Search results are composed of the top 3 results and the fields selected from the search index.\n",
"# Search results include the top 3 matches to your query.\n",
"search_results = search_client.search(\n",
" search_text=query,\n",
" vector_queries=[vector_query],\n",
" select=\"Description,HotelName,Tags\",\n",
" top=3,\n",
")\n",
"sources_formatted = \"\\n\".join(\n",
" [\n",
" f'{document[\"HotelName\"]}:{document[\"Description\"]}:{document[\"Tags\"]}'\n",
" for document in search_results\n",
" ]\n",
")\n",
"\n",
"print(sources_formatted)"
]
},
{
"cell_type": "markdown",
"id": "083d39ad",
"metadata": {},
"source": [
"### Define your LLM\n",
"\n",
"This hands-on only uses the `gpt-4o-mini`, but you can utilize multiple models in the pipeline."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "18d3a8f8",
"metadata": {},
"outputs": [],
"source": [
"# aoai_client = AzureOpenAI(\n",
"# azure_endpoint=azure_openai_endpoint,\n",
"# api_key=azure_openai_key,\n",
"# api_version=openai_api_version,\n",
"# )\n",
"\n",
"# This is not the same object as the one above. This is the client that is used to interact with the Azure OpenAI Chat API.\n",
"autogen_aoai_client = AzureOpenAIChatCompletionClient(\n",
" azure_endpoint=azure_openai_endpoint,\n",
" model=azure_openai_chat_deployment_name,\n",
" api_version=azure_openai_api_version,\n",
" api_key=azure_openai_key,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "65ec10e5",
"metadata": {},
"source": [
"### Question-Retrieval Grader\n",
"\n",
"Construct a retrieval grader that evaluates the relevance of the retrieved documents to the input question. The retrieval grader should take the input question and the retrieved documents as input and output a relevance score for each document.<br>\n",
"Note that the retrieval grader should be able to handle **multiple documents** as input."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5c85e4c2",
"metadata": {},
"outputs": [],
"source": [
"retrieval_eval = RetrievalEvaluator(model_config)\n",
"\n",
"query_response = dict(query=query, context=sources_formatted)\n",
"\n",
"relevance_score = retrieval_eval(**query_response)\n",
"print(relevance_score)\n",
"relevance_score[\"retrieval\"]"
]
},
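{
"cell_type": "markdown",
"id": "a4e7c1d2",
"metadata": {},
"source": [
"`RetrievalEvaluator` returns one score for the combined context. If you want a score for **each document**, you can split the formatted results (one document per line) and grade them one at a time. The cell below is a minimal sketch of that idea and is not part of the SDK: `grade_documents` and `keyword_overlap_score` are hypothetical helpers, and the keyword-overlap scorer only stands in for an LLM-based grader so that the cell runs offline. To apply it to real results, you could pass `sources_formatted.splitlines()` together with a grader that wraps `retrieval_eval`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3f1a9c2e",
"metadata": {},
"outputs": [],
"source": [
"import re\n",
"from typing import Callable, List, Tuple\n",
"\n",
"\n",
"def keyword_overlap_score(query: str, document: str) -> float:\n",
"    \"\"\"Hypothetical stand-in grader: maps query-word overlap to a 1-5 score.\"\"\"\n",
"    query_words = set(re.findall(r\"[a-z]+\", query.lower()))\n",
"    doc_words = set(re.findall(r\"[a-z]+\", document.lower()))\n",
"    if not query_words:\n",
"        return 1.0\n",
"    overlap = len(query_words & doc_words) / len(query_words)\n",
"    return 1.0 + 4.0 * overlap  # no overlap -> 1.0, full overlap -> 5.0\n",
"\n",
"\n",
"def grade_documents(\n",
"    query: str,\n",
"    documents: List[str],\n",
"    score_fn: Callable[[str, str], float],\n",
"    threshold: float = 3.0,\n",
") -> Tuple[List[str], List[float]]:\n",
"    \"\"\"Score each document independently and keep those at or above the threshold.\"\"\"\n",
"    scores = [score_fn(query, doc) for doc in documents]\n",
"    relevant = [doc for doc, score in zip(documents, scores) if score >= threshold]\n",
"    return relevant, scores\n",
"\n",
"\n",
"# Hard-coded example documents, shaped like the lines of sources_formatted\n",
"example_docs = [\n",
"    \"Hotel A:Complimentary breakfast is served daily.\",\n",
"    \"Hotel B:Rooftop pool and nightclub.\",\n",
"]\n",
"relevant_docs, doc_scores = grade_documents(\n",
"    \"complimentary breakfast\", example_docs, keyword_overlap_score\n",
")\n",
"print(doc_scores)  # [5.0, 1.0]\n",
"print(relevant_docs)  # ['Hotel A:Complimentary breakfast is served daily.']"
]
},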
{
"cell_type": "markdown",
"id": "702b7f87",
"metadata": {},
"source": [
"### Answer Generator\n",
"\n",
"Construct a LLM Generation node. This is a Naive RAG chain that generates an answer based on the retrieved documents. \n",
"\n",
"We recommend you to use more advanced RAG chain for production"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2168105c",
"metadata": {},
"outputs": [],
"source": [
"from pydantic import BaseModel\n",
"from typing import List\n",
"\n",
"\n",
"class HotelInfo(BaseModel):\n",
" hotel_name: str\n",
" description: str\n",
"\n",
"\n",
"class RecommendationList(BaseModel):\n",
" recommendation: List[HotelInfo]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d080f695",
"metadata": {},
"outputs": [],
"source": [
"# This prompt provides instructions to the model\n",
"GROUNDED_PROMPT = \"\"\"\n",
"You are a friendly assistant that recommends hotels based on activities and amenities.\n",
"Answer the query using only the context provided below in a friendly and concise bulleted manner.\n",
"Answer ONLY with the facts listed in the list of context below.\n",
"If there isn't enough information below, say you don't know.\n",
"Generate a response that includes the top 3 results.\n",
"Do not generate answers that don't use the context below.\n",
"Query: {query}\n",
"Context:\\n{context}\n",
"\"\"\"\n",
"\n",
"# Send the search results and the query to the LLM to generate a response based on the prompt.\n",
"response = await autogen_aoai_client.create(\n",
" messages=[\n",
" UserMessage(\n",
" content=GROUNDED_PROMPT.format(query=query, context=sources_formatted),\n",
" source=\"user\",\n",
" ),\n",
" ],\n",
" extra_create_args={\"response_format\": RecommendationList},\n",
")\n",
"\n",
"response_content = json.loads(response.content)\n",
"for recommendation in response_content[\"recommendation\"]:\n",
" print(recommendation)"
]
},
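{
"cell_type": "markdown",
"id": "e0f3b8a6",
"metadata": {},
"source": [
"Because the request asks for a `RecommendationList`, you can also validate the reply against that schema instead of calling `json.loads`, so a reply with missing or mistyped fields raises an error instead of passing through silently. In the cell above, that would mean `RecommendationList.model_validate_json(response.content)`. The sketch below assumes Pydantic v2 (`model_validate_json`, `ValidationError.error_count`). The JSON strings are made-up examples, and the two models are repeated so the cell is self-contained."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7b4d0e15",
"metadata": {},
"outputs": [],
"source": [
"from typing import List\n",
"\n",
"from pydantic import BaseModel, ValidationError\n",
"\n",
"\n",
"class HotelInfo(BaseModel):\n",
"    hotel_name: str\n",
"    description: str\n",
"\n",
"\n",
"class RecommendationList(BaseModel):\n",
"    recommendation: List[HotelInfo]\n",
"\n",
"\n",
"# Made-up example of a well-formed model reply\n",
"raw_output = '{\"recommendation\": [{\"hotel_name\": \"Hotel A\", \"description\": \"Free breakfast.\"}]}'\n",
"parsed = RecommendationList.model_validate_json(raw_output)\n",
"for hotel in parsed.recommendation:\n",
"    print(f\"{hotel.hotel_name}: {hotel.description}\")\n",
"\n",
"# A reply that is missing a required field is rejected\n",
"try:\n",
"    RecommendationList.model_validate_json('{\"recommendation\": [{\"hotel_name\": \"Hotel B\"}]}')\n",
"except ValidationError as err:\n",
"    print(f\"Invalid model output: {err.error_count()} error(s)\")"
]
},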
{
"cell_type": "markdown",
"id": "41d4050d",
"metadata": {},
"source": [
"### Groundedness Evaluator\n",
"\n",
"Construct a `groundedness_grader` node to evaluate the **hallucination** of the generated answer based on the retrieved documents.<br>\n",
"\n",
"`yes` means the answer is relevant to the retrieved documents, and `no` means the answer is not relevant to the retrieved documents."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2f46b9ae",
"metadata": {},
"outputs": [],
"source": [
"groundedness_eval = GroundednessEvaluator(model_config)\n",
"\n",
"query_response = dict(query=query, context=sources_formatted, response=response_content)\n",
"\n",
"groundedness_score = groundedness_eval(**query_response)\n",
"print(groundedness_score)"
]
},
{
"cell_type": "markdown",
"id": "54c773d2",
"metadata": {},
"source": [
"### Relevance Evaluator\n",
"\n",
"Construct a `relevance_grader` node to evaluate the relevance of the generated answer to the question.<br>\n",
"`yes` means the answer is relevant to the question, and `no` means the answer is not relevant to the question."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "37e10188",
"metadata": {},
"outputs": [],
"source": [
"relevance_eval = RelevanceEvaluator(model_config)\n",
"\n",
"query_response = dict(query=query, response=response_content)\n",
"\n",
"relevance_score = relevance_eval(**query_response)\n",
"print(relevance_score)"
]
},
{
"cell_type": "markdown",
"id": "fd05275c",
"metadata": {},
"source": [
"### Question Re-writer\n",
"\n",
"Construct a `question_rewriter` node to rewrite the question based on the retrieved documents and the generated answer."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "806f6d04",
"metadata": {},
"outputs": [],
"source": [
"query = \"Can you recommend a few factories with complimentary breakfast?\"\n",
"\n",
"# This prompt provides instructions to the model\n",
"REWRITE_PROMPT = \"\"\"\n",
"You a question re-writer that converts an input question to a better version that is optimized\n",
"for vectorstore retrieval. Look at the input and try to reason about the underlying semantic intent / meaning.\n",
"Query: {query}\n",
"\"\"\"\n",
"\n",
"# Send the search results and the query to the LLM to generate a response based on the prompt.\n",
"response = await autogen_aoai_client.create(\n",
" messages=[\n",
" UserMessage(content=REWRITE_PROMPT.format(query=query), source=\"user\"),\n",
" ]\n",
")\n",
"\n",
"\n",
"# Here is the response from the chat model.\n",
"print(response.content)"
]
},
{
"cell_type": "markdown",
"id": "f08d720d",
"metadata": {},
"source": [
"<br>\n",
"\n",
"## 🧪 Step 2. Define the Agentic Architecture\n",
"---\n",
"- Before building the agentic pipeline, we need to design the message, topic, agent and message routing logic. \n",
"- You should define the terminate condition for the pipeline.\n",
"\n",
"### Message, Topic, Agent Definition\n",
"\n",
"```markdown\n",
"```python\n",
"\n",
"# Message Definition\n",
"@dataclass\n",
"class Message:\n",
" query: str = None\n",
" context: str = None\n",
" response: str = None\n",
" source: str = None\n",
"\n",
"\n",
"# Topic Definition\n",
"user_query_topic_type = \"UserQueryTopic\"\n",
"rewrite_topic_type = \"RewriteQueryTopic\"\n",
"generate_topic_type = \"GenerateTopic\"\n",
"eval_topic_type = \"EvalTopic\"\n",
"\n",
"# Agent Definition\n",
"class RetrievalGraderAgent(RoutedAgent):\n",
"class RewriteQueryAgent(RoutedAgent):\n",
"class GenerateAgent(RoutedAgent):\n",
"class EvalAgent(RoutedAgent):\n",
"class UserAgent(RoutedAgent):\n",
"\n",
"```\n",
"```"
]
},
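{
"cell_type": "markdown",
"id": "5c2d9e41",
"metadata": {},
"source": [
"The routing rules that the agents below implement can be written as a plain function of the current stage and the retrieval score. This makes the flow and the termination condition easy to check before wiring up the runtime. The cell below is a runtime-free sketch: `next_topic` and `trace` are hypothetical helpers, the topic strings mirror the topic types defined later in this notebook, and the 3.0 threshold matches the one the agents apply to the evaluators' 1-5 scores. The pipeline terminates once a message reaches `UserAgent`, which publishes nothing further, so `runtime.stop_when_idle()` can return."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c8e2a6f0",
"metadata": {},
"outputs": [],
"source": [
"from typing import List, Optional\n",
"\n",
"THRESHOLD = 3.0  # cut-off for 1-5 evaluator scores, as in the agents below\n",
"\n",
"\n",
"def next_topic(stage: str, retrieval_score: Optional[float] = None) -> Optional[str]:\n",
"    \"\"\"Return the topic type the given stage publishes to, or None when the flow ends.\"\"\"\n",
"    if stage == \"UserQueryTopic\":  # handled by RAGGraderAgent\n",
"        if retrieval_score is None:\n",
"            raise ValueError(\"the grading stage needs a retrieval score\")\n",
"        return \"GenerateAgent\" if retrieval_score >= THRESHOLD else \"QueryRewriteAgent\"\n",
"    if stage == \"QueryRewriteAgent\":\n",
"        return \"GenerateAgent\"\n",
"    if stage == \"GenerateAgent\":\n",
"        return \"EvalAgent\"\n",
"    if stage == \"EvalAgent\":\n",
"        return \"UserAgent\"  # the answer or the fallback message\n",
"    if stage == \"UserAgent\":\n",
"        return None  # terminal stage: nothing else is published\n",
"    raise ValueError(f\"unknown stage: {stage}\")\n",
"\n",
"\n",
"def trace(retrieval_score: float) -> List[str]:\n",
"    \"\"\"Follow the route from the user query until the flow terminates.\"\"\"\n",
"    path = [\"UserQueryTopic\"]\n",
"    stage = next_topic(\"UserQueryTopic\", retrieval_score)\n",
"    while stage is not None:\n",
"        path.append(stage)\n",
"        stage = next_topic(stage)\n",
"    return path\n",
"\n",
"\n",
"print(trace(4.0))  # relevant context skips the rewrite step\n",
"print(trace(2.0))  # irrelevant context goes through QueryRewriteAgent"
]
},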
{
"cell_type": "markdown",
"id": "687e97e4",
"metadata": {},
"source": [
"Visualizing the abstract architecture of the pipeline will help you understand the message flow and the agent's role in the pipeline."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "99640cec",
"metadata": {},
"outputs": [],
"source": [
"from azure_genai_utils.graphs import visualize_agents\n",
"\n",
"agents = [\n",
" \"Start\",\n",
" \"RAGGraderAgent\",\n",
" \"QueryRewriteAgent\",\n",
" \"GenerateAgent\",\n",
" \"EvalAgent\",\n",
" \"UserAgent\",\n",
"]\n",
"interactions = [\n",
" (\"Start\", \"RAGGraderAgent\"),\n",
" (\"RAGGraderAgent\", \"GenerateAgent\", \"Generates Response\"),\n",
" (\"RAGGraderAgent\", \"QueryRewriteAgent\", \"Rewrites Query\"),\n",
" (\"QueryRewriteAgent\", \"GenerateAgent\"),\n",
" (\"GenerateAgent\", \"EvalAgent\"),\n",
" (\"EvalAgent\", \"UserAgent\"),\n",
"]\n",
"\n",
"visualize_agents(agents, interactions)"
]
},
{
"cell_type": "markdown",
"id": "ae92e46c",
"metadata": {},
"source": [
"This is an example of visualized pipeline\n",
"\n",
""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "87d1f724",
"metadata": {},
"outputs": [],
"source": [
"@dataclass\n",
"class Message:\n",
" query: str = None\n",
" context: str = None\n",
" response: str = None\n",
" source: str = None\n",
"\n",
" def set_source(self, source: str) -> \"Message\":\n",
" self.source = source\n",
" return self\n",
"\n",
"\n",
"# Topic Definition\n",
"user_query_topic_type = \"UserQueryTopic\"\n",
"query_rewrite_topic_type = \"QueryRewriteAgent\"\n",
"generate_topic_type = \"GenerateAgent\"\n",
"eval_topic_type = \"EvalAgent\"\n",
"user_topic_type = \"UserAgent\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "883dde68",
"metadata": {},
"outputs": [],
"source": [
"@type_subscription(topic_type=user_query_topic_type)\n",
"class RAGGraderAgent(RoutedAgent):\n",
"\n",
" def __init__(\n",
" self,\n",
" azure_ai_search_endpoint: str,\n",
" azure_search_admin_key: str,\n",
" index_name: str,\n",
" retrieval_evaluator: RetrievalEvaluator,\n",
" ) -> None:\n",
"\n",
" super().__init__(\"RAG Grader Agent\")\n",
" self.index_name = index_name\n",
" self.azure_ai_search_endpoint = azure_ai_search_endpoint\n",
" self.azure_search_admin_key = azure_search_admin_key\n",
" self.retrieval_evaluator = retrieval_evaluator\n",
"\n",
" def config_search(self) -> SearchClient:\n",
" service_endpoint = self.azure_ai_search_endpoint\n",
" key = self.azure_search_admin_key\n",
" index_name = self.index_name\n",
" credential = AzureKeyCredential(key)\n",
" return SearchClient(\n",
" endpoint=service_endpoint, index_name=index_name, credential=credential\n",
" )\n",
"\n",
" async def do_search(self, query: str) -> str:\n",
" \"\"\"Search indexed data using Azure Cognitive Search with vector-based queries.\"\"\"\n",
" aia_search_client = self.config_search()\n",
"\n",
" fields = \"descriptionVector\" # TODO: Check if this is the correct field name\n",
" # don't use exhaustive search for large indexes\n",
" vector_query = VectorizableTextQuery(\n",
" text=query, k_nearest_neighbors=1, fields=fields, exhaustive=True\n",
" )\n",
"\n",
" search_results = aia_search_client.search(\n",
" search_text=query,\n",
" vector_queries=[vector_query],\n",
" select=[\n",
" \"Description,HotelName,Tags\"\n",
" ], # TODO: Check if these are the correct field names\n",
" top=3, # TODO: Check if this is the correct number of results\n",
" )\n",
" answer = \"\\n\".join(\n",
" [\n",
" f'{document[\"HotelName\"]}:{document[\"Description\"]}:{document[\"Tags\"]}'\n",
" for document in search_results\n",
" ]\n",
" )\n",
" return answer\n",
"\n",
" @message_handler\n",
" async def handle_message(self, message: Message, ctx: MessageContext) -> None:\n",
"\n",
" context_from_ai_search = await self.do_search(message.query)\n",
" print(context_from_ai_search)\n",
"\n",
" query_response = dict(query=query, context=context_from_ai_search)\n",
"\n",
" retrieval_score = self.retrieval_evaluator(**query_response)\n",
"\n",
" print(f\"retrieval_score: {retrieval_score['retrieval']}\")\n",
"\n",
" if retrieval_score[\"retrieval\"] >= 3.0:\n",
" await self.publish_message(\n",
" Message(\n",
" query=query, context=context_from_ai_search, source=message.source\n",
" ),\n",
" topic_id=TopicId(type=generate_topic_type, source=message.source),\n",
" )\n",
" else:\n",
" await self.publish_message(\n",
" Message(\n",
" query=query, context=context_from_ai_search, source=message.source\n",
" ),\n",
" topic_id=TopicId(type=query_rewrite_topic_type, source=message.source),\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "452355a4",
"metadata": {},
"outputs": [],
"source": [
"REWRITE_PROMPT = \"\"\"\n",
"You a question re-writer that converts an input question to a better version that is optimized\n",
"for vectorstore retrieval. Look at the input and try to rewrite about the underlying semantic intent / meaning.\n",
"Query: {query}\n",
"\"\"\"\n",
"\n",
"\n",
"@type_subscription(topic_type=query_rewrite_topic_type)\n",
"class QueryRewriteAgent(RoutedAgent):\n",
" def __init__(self, model_client: ChatCompletionClient) -> None:\n",
" super().__init__(\"Query Rewrite Agent\")\n",
" self._system_message = SystemMessage(\n",
" content=(\n",
" \"\"\"\n",
" You are an helper agent that can rewrite the query.\n",
" \"\"\"\n",
" )\n",
" )\n",
" self._model_client = model_client\n",
"\n",
" @message_handler\n",
" async def handle_message(self, message: Message, ctx: MessageContext) -> None:\n",
" print(message)\n",
" llm_result = await self._model_client.create(\n",
" messages=[\n",
" self._system_message,\n",
" UserMessage(\n",
" content=REWRITE_PROMPT.format(query=message.query),\n",
" source=message.source,\n",
" ),\n",
" ],\n",
" cancellation_token=ctx.cancellation_token,\n",
" )\n",
" response = llm_result.content\n",
" print(response)\n",
" assert isinstance(response, str)\n",
" print(f\"{'-'*80}\\n{self.id.type}:\\n{response}\")\n",
"\n",
" await self.publish_message(\n",
" Message(query=response, context=message.context, source=message.source),\n",
" topic_id=TopicId(type=generate_topic_type, source=message.source),\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "93e9a219",
"metadata": {},
"outputs": [],
"source": [
"# This prompt provides instructions to the model\n",
"GROUNDED_PROMPT = \"\"\"\n",
"Answer the query using only the context provided below in a friendly and concise bulleted manner.\n",
"Answer ONLY with the facts listed in the list of context below.\n",
"If there isn't enough information below, say you don't know.\n",
"Do not generate answers that don't use the context below.\n",
"Query: {query}\n",
"Context:\\n{context}\n",
"\"\"\"\n",
"\n",
"\n",
"@type_subscription(topic_type=generate_topic_type)\n",
"class GenerateAgent(RoutedAgent):\n",
" def __init__(self, model_client: ChatCompletionClient) -> None:\n",
" super().__init__(\"Generate Agent\")\n",
" self._system_message = SystemMessage(\n",
" content=(\n",
" \"\"\"\n",
" You are a friendly assistant that recommends hotels based on activities and amenities.\n",
" \"\"\"\n",
" )\n",
" )\n",
" self._model_client = model_client\n",
"\n",
" @message_handler\n",
" async def handle_message(self, message: Message, ctx: MessageContext) -> None:\n",
" llm_result = await self._model_client.create(\n",
" messages=[\n",
" self._system_message,\n",
" UserMessage(\n",
" content=GROUNDED_PROMPT.format(\n",
" query=message.query, context=message.context\n",
" ),\n",
" source=message.source,\n",
" ),\n",
" ],\n",
" extra_create_args={\"response_format\": RecommendationList},\n",
" cancellation_token=ctx.cancellation_token,\n",
" )\n",
" response = llm_result.content\n",
" assert isinstance(response, str)\n",
" print(f\"{'-'*80}\\n{self.id.type}:\\n{response}\")\n",
" await self.publish_message(\n",
" Message(\n",
" query=message.query,\n",
" context=message.context,\n",
" response=response,\n",
" source=message.source,\n",
" ),\n",
" topic_id=TopicId(type=eval_topic_type, source=self.id.key),\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "942babd6",
"metadata": {},
"outputs": [],
"source": [
"INCORRECT_ANSWER = \"\"\"\n",
"Hello, and thank you for bringing this to our attention! I may have provided an inaccurate or misleading response, and I sincerely apologize for the confusion.\n",
"As an AI, I aim to deliver helpful and accurate information, but sometimes I might misinterpret or generate an incorrect response. Your feedback is invaluable and helps me improve.\n",
"\n",
"If you'd like, feel free to share more details or clarify your question, and I’ll do my best to assist you further. Thank you for your understanding and patience! 😊\n",
"\"\"\"\n",
"\n",
"\n",
"@type_subscription(topic_type=eval_topic_type)\n",
"class EvalAgent(RoutedAgent):\n",
"\n",
" def __init__(\n",
" self,\n",
" groundedness_evaluator: GroundednessEvaluator,\n",
" relevance_evaluator: RelevanceEvaluator,\n",
" ) -> None:\n",
"\n",
" super().__init__(\"Eval Agent\")\n",
" self.groundedness_evaluator = groundedness_evaluator\n",
" self.relevance_evaluator = relevance_evaluator\n",
"\n",
" @message_handler\n",
" async def handle_message(self, message: Message, ctx: MessageContext) -> None:\n",
"\n",
" query_response = dict(\n",
" query=message.query, context=message.context, response=message.response\n",
" )\n",
"\n",
" groundedness_score = self.groundedness_evaluator(**query_response)\n",
" print(f\"groundness_score: {groundedness_score['groundedness']}\")\n",
" if groundedness_score[\"groundedness\"] < 3.0:\n",
" await self.publish_message(\n",
" AssistantMessage(content=INCORRECT_ANSWER, source=message.source),\n",
" topic_id=TopicId(type=user_topic_type, source=message.source),\n",
" )\n",
" relevance_score = self.relevance_evaluator(**query_response)\n",
" print(f\"relevance_score: {relevance_score['relevance']}\")\n",
" if relevance_score[\"relevance\"] < 3.0:\n",
" await self.publish_message(\n",
" AssistantMessage(content=INCORRECT_ANSWER, source=message.source),\n",
" topic_id=TopicId(type=user_topic_type, source=message.source),\n",
" )\n",
"\n",
" await self.publish_message(\n",
" AssistantMessage(content=message.response, source=message.source),\n",
" topic_id=TopicId(type=user_topic_type, source=message.source),\n",
" )"
]
},
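{
"cell_type": "markdown",
"id": "9a6f3b0c",
"metadata": {},
"source": [
"The evaluation step should send exactly one final message to `UserAgent`. If both checks pass, that message is the generated answer; otherwise it is the fallback. The decision itself can be isolated as a pure function, which is easy to test without calling the evaluators. The cell below is a minimal sketch that assumes 1-5 scores and the same 3.0 threshold. `select_final_answer` is a hypothetical helper, and the fallback text is a shortened placeholder for `INCORRECT_ANSWER`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1d9b5a73",
"metadata": {},
"outputs": [],
"source": [
"FALLBACK_ANSWER = \"Sorry, I could not find a reliable answer. Could you rephrase your question?\"\n",
"\n",
"\n",
"def select_final_answer(\n",
"    groundedness: float, relevance: float, response: str, threshold: float = 3.0\n",
") -> str:\n",
"    \"\"\"Return the response only if both scores meet the threshold, else the fallback.\"\"\"\n",
"    if groundedness < threshold or relevance < threshold:\n",
"        return FALLBACK_ANSWER\n",
"    return response\n",
"\n",
"\n",
"print(select_final_answer(5.0, 4.0, \"Hotel A offers free breakfast.\"))  # passes both checks\n",
"print(select_final_answer(2.0, 5.0, \"Hotel A offers free breakfast.\"))  # fails groundedness"
]
},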
{
"cell_type": "code",
"execution_count": null,
"id": "d39cebcb",
"metadata": {},
"outputs": [],
"source": [
"@type_subscription(topic_type=user_topic_type)\n",
"class UserAgent(RoutedAgent):\n",
" def __init__(self) -> None:\n",
" super().__init__(\"A user agent that outputs the final copy to the user.\")\n",
"\n",
" @message_handler\n",
" async def handle_final_copy(\n",
" self, message: AssistantMessage, ctx: MessageContext\n",
" ) -> None:\n",
" print(f\"\\n{'-'*80}\\n{self.id.type} received final copy:\\n{message.content}\")"
]
},
{
"cell_type": "markdown",
"id": "2a8fce6e",
"metadata": {},
"source": [
"<br>\n",
"\n",
"## 🧪 Step 3. Execute the Workflow\n",
"---\n",
"\n",
"### Execute the workflow"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8c329af2",
"metadata": {},
"outputs": [],
"source": [
"runtime = SingleThreadedAgentRuntime()\n",
"\n",
"await RAGGraderAgent.register(\n",
" runtime,\n",
" type=user_query_topic_type,\n",
" factory=lambda: RAGGraderAgent(\n",
" azure_ai_search_endpoint=azure_ai_search_endpoint,\n",
" azure_search_admin_key=azure_search_admin_key,\n",
" index_name=index_name,\n",
" retrieval_evaluator=RetrievalEvaluator(model_config),\n",
" ),\n",
")\n",
"\n",
"await QueryRewriteAgent.register(\n",
" runtime,\n",
" type=query_rewrite_topic_type,\n",
" factory=lambda: QueryRewriteAgent(model_client=autogen_aoai_client),\n",
")\n",
"\n",
"await GenerateAgent.register(\n",
" runtime,\n",
" type=generate_topic_type,\n",
" factory=lambda: GenerateAgent(model_client=autogen_aoai_client),\n",
")\n",
"\n",
"await EvalAgent.register(\n",
" runtime,\n",
" type=eval_topic_type,\n",
" factory=lambda: EvalAgent(\n",
" groundedness_evaluator=GroundednessEvaluator(model_config),\n",
" relevance_evaluator=RelevanceEvaluator(model_config),\n",
" ),\n",
")\n",
"\n",
"\n",
"await UserAgent.register(runtime, type=user_topic_type, factory=lambda: UserAgent())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "32562244",
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"\n",
"start_time = time.perf_counter()\n",
"\n",
"runtime.start()\n",
"await runtime.publish_message(\n",
" Message(\n",
" query=\"Can you recommend a few factory with complimentary breakfast?\",\n",
" source=\"User\",\n",
" ),\n",
" topic_id=TopicId(type=user_query_topic_type, source=\"user\"),\n",
")\n",
"await runtime.stop_when_idle()\n",
"\n",
"end_time = time.perf_counter()\n",
"print(f\"Elapsed time: {end_time - start_time} seconds\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "venv_agent",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}