{
"cells": [
{
"cell_type": "markdown",
"id": "635d8ebb",
"metadata": {},
"source": [
"# Adaptive RAG with Azure AI Evaluation SDK\n",
"\n",
"----\n",
"\n",
"In this notebook, we will demonstrate how to use the Azure AI Evaluation SDK.\n",
"\n",
"> ✨ ***Note*** <br>\n",
"> 1. Please check the reference document before you get started - https://learn.microsoft.com/en-us/azure/ai-studio/how-to/develop/evaluate-sdk <br>\n",
"> 2. Check the region support for the Azure AI Evaluation SDK. https://learn.microsoft.com/en-us/azure/ai-studio/concepts/evaluation-metrics-built-in?tabs=warning#region-support\n",
"\n",
"\n",
"Adaptive RAG predicts the **complexity of the input question** using a SLM/LLM and selects an appropriate processing workflow accordingly.\n",
"\n",
"- **Very simple question (No Retrieval)**: Generates answers without RAG.\n",
"- **Simple question (Single-shot RAG)**: Efficiently generates answers through a single-step search and generation.\n",
"- **Complex question (Iterative RAG)**: Provides accurate answers to complex questions through repeated multi-step search and generation.\n",
"\n",
"\n",
"Adaptive-RAG, Self-RAG, and Corrective RAG are similar approach, but they have different focuses.\n",
"\n",
"- **Adaptive-RAG**: Dynamically selects appropriate retrieval and generation strategies based on the complexity of the question.\n",
"- **Self-RAG**: The model determines the need for retrieval on its own, performs retrieval when necessary, and improves the quality through self-reflection on the generated answers.\n",
"- **Corrective RAG**: Evaluates the quality of retrieved documents, and performs additional retrievals such as web searches to supplement the information if the reliability is low.\n",
"\n",
"**Reference**\n",
"\n",
"- [Adaptive-RAG paper](https://arxiv.org/abs/2403.14403) "
]
},
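{
"cell_type": "markdown",
"id": "b3f1c2a7",
"metadata": {},
"source": [
"The graph built in this notebook routes between a vectorstore and web search rather than implementing the paper's full three-tier classifier. As a rough, illustrative sketch (not part of the original workflow), the snippet below shows how the question-complexity prediction could be expressed with a structured-output classifier. The `QueryComplexity` model, tier names, and prompt are assumptions, and the snippet reuses the `azure_openai_chat_deployment_name` variable defined in the next cell.\n",
"\n",
"```python\n",
"# Illustrative sketch only: three-tier complexity routing in the spirit of Adaptive-RAG.\n",
"# The class name, tier labels, and prompt are assumptions, not the notebook's workflow.\n",
"from typing import Literal\n",
"from pydantic import BaseModel, Field\n",
"from langchain_core.prompts import ChatPromptTemplate\n",
"from langchain_openai import AzureChatOpenAI\n",
"\n",
"\n",
"class QueryComplexity(BaseModel):\n",
"    \"\"\"Predict how much retrieval a question needs.\"\"\"\n",
"\n",
"    tier: Literal[\"no_retrieval\", \"single_shot\", \"iterative\"] = Field(\n",
"        description=\"no_retrieval for trivial questions, single_shot for simple factual \"\n",
"        \"questions, iterative for multi-hop or ambiguous questions.\"\n",
"    )\n",
"\n",
"\n",
"classifier_llm = AzureChatOpenAI(model=azure_openai_chat_deployment_name, temperature=0)\n",
"complexity_classifier = ChatPromptTemplate.from_messages(\n",
"    [(\"system\", \"Classify the complexity of the user question.\"), (\"human\", \"{question}\")]\n",
") | classifier_llm.with_structured_output(QueryComplexity)\n",
"\n",
"# complexity_classifier.invoke({\"question\": \"What is AutoGen?\"})\n",
"```"
]
},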
{
"cell_type": "code",
"execution_count": null,
"id": "f25ec196",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from dotenv import load_dotenv\n",
"from azure_genai_utils.tracer import get_langchain_api_key, set_langsmith\n",
"\n",
"load_dotenv(override=True)\n",
"\n",
"# If you want to trace your RAG API calls, please set the tracing=True. You need to have a valid Langchain API key.\n",
"langchain_key, has_langchain_key = get_langchain_api_key()\n",
"set_langsmith(\"[RAG Innv Lab] 1_Agentic-Design-Pattern\", tracing=False)\n",
"\n",
"azure_openai_chat_deployment_name = os.getenv(\"AZURE_OPENAI_CHAT_DEPLOYMENT_NAME\")\n",
"azure_openai_embedding_deployment_name = os.getenv(\"AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "45f25be1",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import pprint\n",
"from azure.identity import DefaultAzureCredential\n",
"from azure.ai.evaluation import evaluate\n",
"from azure.ai.evaluation import (\n",
" ContentSafetyEvaluator,\n",
" RelevanceEvaluator,\n",
" CoherenceEvaluator,\n",
" GroundednessEvaluator,\n",
" FluencyEvaluator,\n",
" SimilarityEvaluator,\n",
" F1ScoreEvaluator,\n",
" RetrievalEvaluator,\n",
")\n",
"\n",
"credential = DefaultAzureCredential()\n",
"\n",
"# Initialize Azure OpenAI conncetion with your environment variables\n",
"\n",
"model_config = {\n",
" \"azure_endpoint\": os.environ.get(\"AZURE_OPENAI_ENDPOINT\"),\n",
" \"api_key\": os.environ.get(\"AZURE_OPENAI_API_KEY\"),\n",
" \"azure_deployment\": os.environ.get(\"AZURE_OPENAI_DEPLOYMENT_NAME\"),\n",
" \"api_version\": os.environ.get(\"AZURE_OPENAI_API_VERSION\"),\n",
" \"type\": \"azure_openai\",\n",
"}\n",
"\n",
"pprint.pprint(model_config)"
]
},
{
"cell_type": "markdown",
"id": "aa00c3f4",
"metadata": {},
"source": [
"<br>\n",
"\n",
"## 🧪 Step 1. Test and Construct each module\n",
"---\n",
"\n",
"### Construct Retrieval Chain based on PDF"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "69cb77da",
"metadata": {},
"outputs": [],
"source": [
"from azure_genai_utils.rag.pdf import PDFRetrievalChain\n",
"\n",
"pdf_path = \"../../../sample-docs/AutoGen-paper.pdf\"\n",
"\n",
"pdf = PDFRetrievalChain(\n",
" source_uri=[pdf_path],\n",
" loader_type=\"PDFPlumber\",\n",
" model_name=azure_openai_chat_deployment_name,\n",
" embedding_name=azure_openai_embedding_deployment_name,\n",
" chunk_size=500,\n",
" chunk_overlap=50,\n",
").create_chain()\n",
"\n",
"pdf_retriever = pdf.retriever\n",
"pdf_chain = pdf.chain\n",
"\n",
"question = \"What is AutoGen's main features?\"\n",
"docs = pdf_retriever.invoke(question)\n",
"\n",
"# Non-streaming\n",
"# results = pdf_chain.invoke({\"chat_history\": \"\", \"question\": question, \"context\": docs})\n",
"\n",
"# Streaming\n",
"for text in pdf_chain.stream(\n",
" {\"chat_history\": \"\", \"question\": question, \"context\": docs}\n",
"):\n",
" print(text, end=\"\", flush=True)"
]
},
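{
"cell_type": "markdown",
"id": "7c4d9e21",
"metadata": {},
"source": [
"As a quick optional check (added for illustration), inspect what the retriever returned. The `source` and `page` metadata keys below are the same ones used later by `format_docs`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f0a2b6c3",
"metadata": {},
"outputs": [],
"source": [
"# Optional check: how many chunks were retrieved and where they came from.\n",
"print(f\"Retrieved {len(docs)} chunks\")\n",
"for doc in docs:\n",
"    print(doc.metadata[\"source\"], \"- page\", doc.metadata[\"page\"] + 1)"
]
},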
{
"cell_type": "markdown",
"id": "2b2fc536",
"metadata": {},
"source": [
"### Query Routing and Document Evaluation\n",
"\n",
"Adaptive RAG performs query routing and document evaluation to provide accurate and reliable information. This process is essential for maximizing the performance of LLMs.\n",
"\n",
"- **Query Routing**: Analyze user queries to route them to appropriate information sources. This allows you to set the optimal search path for the purpose of the query.\n",
"- **Document Evaluation**: Evaluate the quality and relevance of retrieved documents to increase the accuracy of the final results."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1b78d33f",
"metadata": {},
"outputs": [],
"source": [
"from typing import Literal\n",
"\n",
"from langchain_core.prompts import ChatPromptTemplate\n",
"from pydantic import BaseModel, Field\n",
"from langchain_openai import AzureChatOpenAI\n",
"\n",
"\n",
"class RouteQuery(BaseModel):\n",
" \"\"\"Route a user query to the most relevant datasource.\"\"\"\n",
"\n",
" datasource: Literal[\"vectorstore\", \"web_search\"] = Field(\n",
" ...,\n",
" description=\"Given a user question choose to route it to web search or a vectorstore.\",\n",
" )\n",
"\n",
"\n",
"llm = AzureChatOpenAI(model=azure_openai_chat_deployment_name, temperature=0)\n",
"structured_llm_router = llm.with_structured_output(RouteQuery)\n",
"\n",
"TOPIC = \"AutoGen\"\n",
"system = f\"\"\"You are an expert at routing a user question to a vectorstore or web search.\n",
"The vectorstore contains documents related to {TOPIC}.\n",
"Use the vectorstore for questions on these topics. Otherwise, use web-search.\"\"\"\n",
"\n",
"\n",
"route_prompt = ChatPromptTemplate.from_messages(\n",
" [\n",
" (\"system\", system),\n",
" (\"human\", \"{question}\"),\n",
" ]\n",
")\n",
"\n",
"question_router = route_prompt | structured_llm_router"
]
},
{
"cell_type": "markdown",
"id": "c9e4d831",
"metadata": {},
"source": [
"Test Query if it is routed to Web Search or VectorStore"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "595a16db",
"metadata": {},
"outputs": [],
"source": [
"question_router.invoke({\"question\": \"What is the main features of AutoGen?\"})"
]
},
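{
"cell_type": "markdown",
"id": "1e5a7d90",
"metadata": {},
"source": [
"For comparison, a question unrelated to the vectorstore topic should be routed to web search. The example question below is only an illustration; the routing decision ultimately depends on the model."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2f8b4c61",
"metadata": {},
"outputs": [],
"source": [
"# A question outside the AutoGen topic should typically be routed to web search.\n",
"question_router.invoke({\"question\": \"Who won the 2024 Nobel Prize in Physics?\"})"
]
},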
{
"cell_type": "markdown",
"id": "5fc43b99",
"metadata": {},
"source": [
"### Question-Retrieval Grader"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "814af453",
"metadata": {},
"outputs": [],
"source": [
"class GradeDocuments(BaseModel):\n",
" \"\"\"Binary score for relevance check on retrieved documents.\"\"\"\n",
"\n",
" binary_score: str = Field(\n",
" description=\"Documents are relevant to the question, 'yes' or 'no'\"\n",
" )\n",
"\n",
"\n",
"# Custom class based evaluator for grading documents\n",
"class GradeDocumentsEvaluator:\n",
" system = \"\"\"You are a grader assessing relevance of a retrieved document to a user question. \\n \n",
" If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \\n\n",
" It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \\n\n",
" Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question.\"\"\"\n",
"\n",
" def __init__(self, llm_client):\n",
" self.llm_client = llm_client\n",
"\n",
" def __call__(self, *, question: str, document: str, **kwargs):\n",
" grade_prompt = ChatPromptTemplate.from_messages(\n",
" [\n",
" (\"system\", self.system),\n",
" (\n",
" \"human\",\n",
" \"Retrieved document: \\n\\n {document} \\n\\n User question: {question}\",\n",
" ),\n",
" ]\n",
" )\n",
" structured_llm_grader = self.llm_client.with_structured_output(GradeDocuments)\n",
" retrieval_grader = grade_prompt | structured_llm_grader\n",
" results = retrieval_grader.invoke({\"question\": question, \"document\": document})\n",
" return results\n",
"\n",
"\n",
"document_evaluator = GradeDocumentsEvaluator(llm)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2fa5e0d7",
"metadata": {},
"outputs": [],
"source": [
"question = \"What is the main features of AutoGen?\"\n",
"docs = pdf_retriever.invoke(question)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ef397b71",
"metadata": {},
"outputs": [],
"source": [
"retrieved_doc = docs[0].page_content\n",
"print(f\"[Retrieved Doc sample]\\n{retrieved_doc}\\n\")\n",
"print(document_evaluator(question=question, document=retrieved_doc))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dce41bfd",
"metadata": {},
"outputs": [],
"source": [
"filtered_docs = []\n",
"for doc in docs:\n",
" result = document_evaluator(question=question, document=doc.page_content)\n",
" if result.binary_score == \"yes\":\n",
" filtered_docs.append(doc)"
]
},
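{
"cell_type": "markdown",
"id": "3a9d0e72",
"metadata": {},
"source": [
"A quick sanity check (added for illustration) on how many retrieved documents survived the relevance grading:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4b6c1f83",
"metadata": {},
"outputs": [],
"source": [
"# How many of the retrieved documents were graded as relevant?\n",
"print(f\"{len(filtered_docs)} of {len(docs)} retrieved documents were graded as relevant.\")"
]
},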
{
"cell_type": "markdown",
"id": "54dce7a1",
"metadata": {},
"source": [
"### Answer Generator\n",
"\n",
"Construct a LLM Generation node. This is a Naive RAG chain that generates an answer based on the retrieved documents. \n",
"\n",
"We recommend you to use more advanced RAG chain for production"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "992ef15a",
"metadata": {},
"outputs": [],
"source": [
"from langchain import hub\n",
"from langchain_core.output_parsers import StrOutputParser\n",
"from langchain_core.prompts import load_prompt\n",
"\n",
"if has_langchain_key:\n",
" print(f\"Load prompt from LangChain Hub.\")\n",
" prompt = hub.pull(\"daekeun-ml/rag-baseline\")\n",
"else:\n",
" print(\"LANGCHAIN_API_KEY is not set. Load prompt from YAML file.\")\n",
" prompt = load_prompt(\"prompts/rag-baseline.yaml\")\n",
"\n",
"\n",
"def format_docs(docs):\n",
" return \"\\n\\n\".join(\n",
" [\n",
" f'<document><content>{doc.page_content}</content><source>{doc.metadata[\"source\"]}</source><page>{doc.metadata[\"page\"]+1}</page></document>'\n",
" for doc in docs\n",
" ]\n",
" )\n",
"\n",
"\n",
"rag_chain = prompt | llm | StrOutputParser()\n",
"generation = rag_chain.invoke({\"context\": format_docs(docs), \"question\": question})\n",
"print(generation)"
]
},
{
"cell_type": "markdown",
"id": "a0e9f601",
"metadata": {},
"source": [
"### Groundedness Evaluator\n",
"\n",
"We can utilize Azure AI Evaluation API to evaluate the groundedness of the answer."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d9340878",
"metadata": {},
"outputs": [],
"source": [
"def get_groundedness_score(context, response):\n",
"\n",
" groundedness_eval = GroundednessEvaluator(model_config)\n",
" query_response = dict(\n",
" context=context,\n",
" response=response,\n",
" )\n",
"\n",
" # Running Groundedness Evaluator on a query and response pair\n",
" groundedness_score = groundedness_eval(**query_response)\n",
" return groundedness_score"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4566fcc0",
"metadata": {},
"outputs": [],
"source": [
"display(get_groundedness_score(context=docs, response=generation))"
]
},
{
"cell_type": "markdown",
"id": "395f56cd",
"metadata": {},
"source": [
"### Answer Grader\n",
"\n",
"We can utilize Azure AI Evaluation API to evaluate the relevance of the answer."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a112fc4d",
"metadata": {},
"outputs": [],
"source": [
"def get_answer_relevace_score(query, response):\n",
"\n",
" relevance_eval = RelevanceEvaluator(model_config)\n",
" query_response = dict(\n",
" query=query,\n",
" response=response,\n",
" )\n",
" relevance_score = relevance_eval(**query_response)\n",
" return relevance_score"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a50e8c9f",
"metadata": {},
"outputs": [],
"source": [
"display(get_answer_relevace_score(query=question, response=generation))"
]
},
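{
"cell_type": "markdown",
"id": "5d2e8a94",
"metadata": {},
"source": [
"The imports at the top of the notebook also include other built-in quality evaluators (coherence, fluency, similarity, retrieval, etc.) that the graph below does not use. As a hedged sketch, the cell below runs two of them on the same question/answer pair; the expected input fields differ per evaluator and SDK version, so check the evaluation SDK reference if a call fails."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6e3f9b05",
"metadata": {},
"outputs": [],
"source": [
"# Hedged sketch: exercise two more of the built-in evaluators imported earlier.\n",
"# The input fields below follow the SDK docs but may vary across SDK versions.\n",
"coherence_eval = CoherenceEvaluator(model_config)\n",
"display(coherence_eval(query=question, response=generation))\n",
"\n",
"retrieval_eval = RetrievalEvaluator(model_config)\n",
"display(retrieval_eval(query=question, context=format_docs(docs)))"
]
},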
{
"cell_type": "markdown",
"id": "a9fc11dd",
"metadata": {},
"source": [
"### Question Re-writer\n",
"\n",
"Construct a `question_rewriter` node to rewrite the question based on the retrieved documents and the generated answer."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e9df325a",
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.prompts import ChatPromptTemplate\n",
"from langchain_core.output_parsers import StrOutputParser\n",
"\n",
"system = \"\"\"You a question re-writer that converts an input question to a better version that is optimized \\n \n",
"for vectorstore retrieval. Look at the input and try to reason about the underlying semantic intent / meaning.\"\"\"\n",
"\n",
"re_write_prompt = ChatPromptTemplate.from_messages(\n",
" [\n",
" (\"system\", system),\n",
" (\n",
" \"human\",\n",
" \"Here is the initial question: \\n\\n {question} \\n Formulate an improved question.\",\n",
" ),\n",
" ]\n",
")\n",
"\n",
"question_rewriter = re_write_prompt | llm | StrOutputParser()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c6eb92e7",
"metadata": {},
"outputs": [],
"source": [
"print(f\"[Original question] {question}\")\n",
"question_rewriter.invoke({\"question\": question})"
]
},
{
"cell_type": "markdown",
"id": "d8d5ee42",
"metadata": {},
"source": [
"### Web Search Tool\n",
"\n",
"Web search tool is used to enhance the context."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3a383ae6",
"metadata": {},
"outputs": [],
"source": [
"from azure_genai_utils.tools import BingSearch\n",
"\n",
"WEB_SEARCH_FORMAT_OUTPUT = False\n",
"\n",
"web_search_tool = BingSearch(\n",
" max_results=3,\n",
" locale=\"en-US\",\n",
" include_news=True,\n",
" include_entity=False,\n",
" format_output=WEB_SEARCH_FORMAT_OUTPUT,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8851c837",
"metadata": {},
"outputs": [],
"source": [
"results = web_search_tool.invoke({\"query\": question})\n",
"print(results[0])"
]
},
{
"cell_type": "markdown",
"id": "1ac37855",
"metadata": {},
"source": [
"<br>\n",
"\n",
"## 🧪 Step 2. Define the Graph\n",
"---\n",
"\n",
"### State Definition\n",
"\n",
"- `question`: Question from the user\n",
"- `generation`: Generated answer\n",
"- `documents`: Retrieved documents"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6d23ab6f",
"metadata": {},
"outputs": [],
"source": [
"from typing import List\n",
"from typing_extensions import TypedDict, Annotated\n",
"\n",
"\n",
"class GraphState(TypedDict):\n",
" question: Annotated[str, \"User question\"]\n",
" generation: Annotated[str, \"LLM generated answer\"]\n",
" documents: Annotated[List[str], \"List of documents\"]"
]
},
{
"cell_type": "markdown",
"id": "f266cc42",
"metadata": {},
"source": [
"### Define Nodes\n",
"\n",
"We will define the following nodes in the graph:\n",
"\n",
"- `retrieve`: Retrieve documents based on the user question.\n",
"- `generate`: Grade documents based on their relevance to the user question.\n",
"- `grade_documents`: Generate an answer based on the retrieved documents and user question.\n",
"- `rewrite_query`: Rewrite the user question to improve retrieval performance.\n",
"- `web_search`: Search the web for additional information."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ee6f34d0",
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.documents import Document\n",
"\n",
"\n",
"def retrieve(state: GraphState):\n",
" \"\"\"\n",
" Retrieve documents based on the user question.\n",
" \"\"\"\n",
" print(\"\\n==== [RETRIEVE] ====\\n\")\n",
" question = state[\"question\"]\n",
"\n",
" documents = pdf_retriever.invoke(question)\n",
" return {\"documents\": documents}\n",
"\n",
"\n",
"def generate(state: GraphState):\n",
" \"\"\"Generate an answer based on the retrieved documents and user question.\"\"\"\n",
" print(\"\\n==== [GENERATE] ====\\n\")\n",
" question = state[\"question\"]\n",
" documents = state[\"documents\"]\n",
"\n",
" generation = rag_chain.invoke({\"context\": documents, \"question\": question})\n",
" return {\"generation\": generation}\n",
"\n",
"\n",
"def grade_documents(state: GraphState):\n",
" \"\"\"Grade documents based on their relevance to the user question.\"\"\"\n",
" print(\"==== [CHECK DOCUMENT RELEVANCE TO QUESTION] ====\")\n",
" question = state[\"question\"]\n",
" documents = state[\"documents\"]\n",
"\n",
" filtered_docs = []\n",
" for d in documents:\n",
" score = document_evaluator(question=question, document=d.page_content)\n",
" grade = score.binary_score\n",
" if grade == \"yes\":\n",
" print(\"---GRADE: DOCUMENT RELEVANT---\")\n",
" # Add related documents to filtered_docs\n",
" filtered_docs.append(d)\n",
" else:\n",
" print(\"---GRADE: DOCUMENT NOT RELEVANT---\")\n",
" continue\n",
" return {\"documents\": filtered_docs}\n",
"\n",
"\n",
"def rewrite_query(state: GraphState):\n",
" \"\"\"Rewrite the user question to improve web search results\"\"\"\n",
" print(\"\\n==== [REWRITE QUERY] ====\\n\")\n",
" question = state[\"question\"]\n",
"\n",
" better_question = question_rewriter.invoke({\"question\": question})\n",
" return {\"question\": better_question}\n",
"\n",
"\n",
"def web_search(state: GraphState):\n",
" \"\"\"Search the web for additional information.\"\"\"\n",
" print(\"==== [WEB SEARCH] ====\")\n",
" question = state[\"question\"]\n",
"\n",
" web_results = web_search_tool.invoke({\"query\": question})\n",
" web_results_docs = [\n",
" Document(\n",
" page_content=web_result[\"content\"],\n",
" metadata={\"source\": web_result[\"url\"]},\n",
" )\n",
" for web_result in web_results\n",
" ]\n",
"\n",
" return {\"documents\": web_results_docs}"
]
},
{
"cell_type": "markdown",
"id": "069ac4e8",
"metadata": {},
"source": [
"### Define Conditional Nodes\n",
"\n",
"- `route_query`: Route the user question to the most relevant datasource such as vectorstore or web search.\n",
"- `decide_to_generate`: Decide whether to generate an answer or not.\n",
"- `hallucination_check`: Evaluate whether the generated answer is grounded in the retrieved documents."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d33976b9",
"metadata": {},
"outputs": [],
"source": [
"def route_question(state: GraphState):\n",
" \"\"\"Route the user question to the most relevant datasource such as vectorstore or web search.\"\"\"\n",
" print(\"==== [ROUTE QUESTION] ====\")\n",
" question = state[\"question\"]\n",
" source = question_router.invoke({\"question\": question})\n",
" if source.datasource == \"web_search\":\n",
" print(\"==== [ROUTE QUESTION TO WEB SEARCH] ====\")\n",
" return \"web_search\"\n",
" elif source.datasource == \"vectorstore\":\n",
" print(\"==== [ROUTE QUESTION TO VECTORSTORE] ====\")\n",
" return \"vectorstore\"\n",
"\n",
"\n",
"def decide_to_generate(state: GraphState):\n",
" \"\"\"Return the decision to generate an answer or rewrite the question.\"\"\"\n",
" print(\"==== [DECISION TO GENERATE] ====\")\n",
" filtered_documents = state[\"documents\"]\n",
"\n",
" if not filtered_documents:\n",
" print(\n",
" \"==== [DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, REWRITE QUERY] ====\"\n",
" )\n",
" return \"rewrite_query\"\n",
" else:\n",
" print(\"==== [DECISION: GENERATE] ====\")\n",
" return \"generate\"\n",
"\n",
"\n",
"def hallucination_check(state: GraphState):\n",
" \"\"\"Evaluate whether the generated answer is grounded in the retrieved documents.\"\"\"\n",
" print(\"\\n==== [CHECK HALLUCINATIONS] ===\")\n",
" question = state[\"question\"]\n",
" documents = state[\"documents\"]\n",
" generation = state[\"generation\"]\n",
"\n",
" groundedness_score = get_groundedness_score(context=documents, response=generation)\n",
" grade = groundedness_score[\"groundedness\"]\n",
" print(f\"Groundness_score (1-5; higher is better): {grade}\\n\")\n",
"\n",
" if grade >= 4:\n",
" print(\"==== [DECISION: GENERATION IS GROUNDED IN DOCUMENTS] ====\\n\")\n",
" print(\"==== [GRADE GENERATED ANSWER vs QUESTION] ====\")\n",
" relevance_score = get_answer_relevace_score(query=question, response=generation)\n",
" grade = relevance_score[\"relevance\"]\n",
" if grade >= 4:\n",
" print(\n",
" f\"==== [DECISION: GENERATED ANSWER ADDRESSES QUESTION, Relevance Score {grade}] ====\"\n",
" )\n",
" return \"relevant\"\n",
" else:\n",
" print(\"==== [DECISION: GENERATED ANSWER DOES NOT ADDRESS QUESTION] ====\")\n",
" return \"not relevant\"\n",
" else:\n",
" print(\"==== [DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY] ====\")\n",
" return \"hallucination\""
]
},
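{
"cell_type": "markdown",
"id": "8a1c2d36",
"metadata": {},
"source": [
"Before wiring the graph together, the routing decision can be exercised directly on a toy state (added here as a quick check): an empty document list should return `rewrite_query`, while the previously filtered documents should return `generate`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9b4e5f47",
"metadata": {},
"outputs": [],
"source": [
"# Quick check of the routing decision with and without relevant documents.\n",
"print(decide_to_generate({\"question\": question, \"documents\": []}))\n",
"print(decide_to_generate({\"question\": question, \"documents\": filtered_docs}))"
]
},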
{
"cell_type": "markdown",
"id": "2412119d",
"metadata": {},
"source": [
"### Construct the Graph"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c106a028",
"metadata": {},
"outputs": [],
"source": [
"from langgraph.graph import END, StateGraph, START\n",
"from langgraph.checkpoint.memory import MemorySaver\n",
"\n",
"workflow = StateGraph(GraphState)\n",
"\n",
"# Node definition\n",
"workflow.add_node(\"web_search\", web_search)\n",
"workflow.add_node(\"retrieve\", retrieve)\n",
"workflow.add_node(\"grade_documents\", grade_documents)\n",
"workflow.add_node(\"generate\", generate)\n",
"workflow.add_node(\"rewrite_query\", rewrite_query)\n",
"\n",
"# Edge connections\n",
"workflow.add_conditional_edges(\n",
" START,\n",
" route_question,\n",
" {\n",
" \"web_search\": \"web_search\",\n",
" \"vectorstore\": \"retrieve\",\n",
" },\n",
")\n",
"workflow.add_edge(\"web_search\", \"generate\") # Answer generation from web search\n",
"workflow.add_edge(\"retrieve\", \"grade_documents\")\n",
"workflow.add_conditional_edges(\n",
" \"grade_documents\",\n",
" decide_to_generate,\n",
" {\"rewrite_query\": \"rewrite_query\", \"generate\": \"generate\"},\n",
")\n",
"\n",
"# Re-write the query and then retrieve the documents\n",
"workflow.add_edge(\"rewrite_query\", \"retrieve\")\n",
"workflow.add_conditional_edges(\n",
" \"generate\",\n",
" hallucination_check,\n",
" {\n",
" \"hallucination\": \"generate\", # Re-generate the answer if hallucination is detected\n",
" \"relevant\": END, # If the answer is relevant to the question, end the workflow\n",
" \"not relevant\": \"rewrite_query\", # Rewrite the query if the answer is not relevant\n",
" },\n",
")\n",
"\n",
"# Compile the workflow\n",
"app = workflow.compile(checkpointer=MemorySaver())"
]
},
{
"cell_type": "markdown",
"id": "748f4505",
"metadata": {},
"source": [
"### Visualize the graph"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "46ce79fe",
"metadata": {},
"outputs": [],
"source": [
"from azure_genai_utils.graphs import visualize_langgraph\n",
"\n",
"visualize_langgraph(app, xray=True)"
]
},
{
"cell_type": "markdown",
"id": "3fd2739b",
"metadata": {},
"source": [
"<br>\n",
"\n",
"## 🧪 Step 3. Execute the Graph\n",
"---\n",
"\n",
"### Execute the graph"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b020b140",
"metadata": {},
"outputs": [],
"source": [
"from azure_genai_utils.messages import stream_graph, invoke_graph, random_uuid\n",
"from langchain_core.runnables import RunnableConfig\n",
"\n",
"config = RunnableConfig(recursion_limit=10, configurable={\"thread_id\": random_uuid()})\n",
"\n",
"inputs = {\n",
" \"question\": \"What is AutoGen's main features?\",\n",
"}\n",
"\n",
"stream_graph(app, inputs, config, [\"grade_documents\", \"rewrite\", \"generate\"])"
]
},
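{
"cell_type": "markdown",
"id": "0c7d8e58",
"metadata": {},
"source": [
"Because the graph was compiled with a `MemorySaver` checkpointer, the final state for this thread can be read back after streaming. The cell below (added for illustration) retrieves the final answer and re-scores its groundedness with the evaluator defined earlier."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a2b3c4d5",
"metadata": {},
"outputs": [],
"source": [
"# Read the final state for this thread back from the checkpointer and\n",
"# re-score the generated answer with the groundedness evaluator.\n",
"final_state = app.get_state(config).values\n",
"print(final_state[\"generation\"])\n",
"display(\n",
"    get_groundedness_score(\n",
"        context=final_state[\"documents\"], response=final_state[\"generation\"]\n",
"    )\n",
")"
]
},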
{
"cell_type": "code",
"execution_count": null,
"id": "e25d23b6",
"metadata": {},
"outputs": [],
"source": [
"inputs = {\n",
" \"question\": \"Who is Satya Nadella?\",\n",
"}\n",
"\n",
"stream_graph(app, inputs, config, [\"agent\", \"rewrite\", \"generate\"])"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "venv_agent",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}