# Adaptive RAG

----

Adaptive RAG predicts the **complexity of the input question** using an SLM/LLM and selects an appropriate processing workflow accordingly.

- **Very simple question (No Retrieval)**: Generates answers without RAG.
- **Simple question (Single-shot RAG)**: Efficiently generates answers through a single-step search and generation.
- **Complex question (Iterative RAG)**: Provides accurate answers to complex questions through repeated multi-step search and generation.
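
The three-way selection above can be sketched as a dispatch table. This is a minimal illustration, not the paper's classifier: `classify_complexity` is a hypothetical keyword heuristic standing in for the SLM/LLM, and the three handler functions are placeholders that only label which strategy would run.

```python
from typing import Callable, Dict, Literal

Complexity = Literal["no_retrieval", "single_step", "multi_step"]


def classify_complexity(question: str) -> Complexity:
    """Hypothetical stand-in for the SLM/LLM complexity classifier."""
    words = [w.strip("?.,!") for w in question.lower().split()]
    if len(words) <= 4:
        return "no_retrieval"
    if any(w in ("compare", "versus", "and") for w in words):
        return "multi_step"
    return "single_step"


# Placeholder handlers: a real workflow would call an LLM and a retriever here.
def answer_directly(question: str) -> str:
    return f"[no retrieval] {question}"


def single_shot_rag(question: str) -> str:
    return f"[single-shot RAG] {question}"


def iterative_rag(question: str) -> str:
    return f"[iterative RAG] {question}"


STRATEGIES: Dict[str, Callable[[str], str]] = {
    "no_retrieval": answer_directly,
    "single_step": single_shot_rag,
    "multi_step": iterative_rag,
}


def adaptive_answer(
    question: str,
    classifier: Callable[[str], str] = classify_complexity,
) -> str:
    """Pick a strategy from the predicted complexity and run it."""
    return STRATEGIES[classifier(question)](question)
```

Passing the classifier as an argument keeps the dispatch logic testable without a model; swapping in an LLM-backed classifier would not change `adaptive_answer`.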


Adaptive-RAG, Self-RAG, and Corrective RAG are similar approaches, but each has a different focus.

- **Adaptive-RAG**: Dynamically selects appropriate retrieval and generation strategies based on the complexity of the question.
- **Self-RAG**: The model determines the need for retrieval on its own, performs retrieval when necessary, and improves the quality through self-reflection on the generated answers.
- **Corrective RAG**: Evaluates the quality of retrieved documents, and performs additional retrievals such as web searches to supplement the information if the reliability is low.

**Reference**

- [Adaptive-RAG paper](https://arxiv.org/abs/2403.14403)  

In [None]:
import os
from dotenv import load_dotenv
from azure_genai_utils.tracer import get_langchain_api_key, set_langsmith

load_dotenv(override=True)

# To trace your RAG API calls, set tracing=True. This requires a valid LangChain API key.
langchain_key, has_langchain_key = get_langchain_api_key()
set_langsmith("[RAG Innv Lab] 1_Agentic-Design-Pattern", tracing=False)

azure_openai_chat_deployment_name = os.getenv("AZURE_OPENAI_CHAT_DEPLOYMENT_NAME")
azure_openai_embedding_deployment_name = os.getenv("AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME")

<br>

## 🧪 Step 1. Test and Construct each module
---

### Construct Retrieval Chain based on PDF

In [None]:
from azure_genai_utils.rag.pdf import PDFRetrievalChain

pdf_path = "../../../sample-docs/AutoGen-paper.pdf"

pdf = PDFRetrievalChain(
    source_uri=[pdf_path],
    loader_type="PDFPlumber",
    model_name=azure_openai_chat_deployment_name,
    embedding_name=azure_openai_embedding_deployment_name,
    chunk_size=500,
    chunk_overlap=50,
).create_chain()

pdf_retriever = pdf.retriever
pdf_chain = pdf.chain

question = "What are AutoGen's main features?"
docs = pdf_retriever.invoke(question)

# Non-streaming
# results = pdf_chain.invoke({"chat_history": "", "question": question, "context": docs})

# Streaming
for text in pdf_chain.stream(
    {"chat_history": "", "question": question, "context": docs}
):
    print(text, end="", flush=True)

### Query Routing and Document Evaluation

Adaptive RAG performs query routing and document evaluation so that each answer is built from an appropriate source and from relevant documents.

- **Query Routing**: Analyze the user query and route it to an appropriate information source. In this notebook, the router chooses between the vectorstore and web search.
- **Document Evaluation**: Evaluate the quality and relevance of retrieved documents to increase the accuracy of the final results.

In [None]:
from typing import Literal

from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel, Field
from langchain_openai import AzureChatOpenAI



class RouteQuery(BaseModel):
    """Route a user query to the most relevant datasource."""

    datasource: Literal["vectorstore", "web_search"] = Field(
        ...,
        description="Given a user question choose to route it to web search or a vectorstore.",
    )


llm = AzureChatOpenAI(model=azure_openai_chat_deployment_name, temperature=0)
structured_llm_router = llm.with_structured_output(RouteQuery)

TOPIC = "AutoGen"
system = f"""You are an expert at routing a user question to a vectorstore or web search.
The vectorstore contains documents related to {TOPIC}.
Use the vectorstore for questions on these topics. Otherwise, use web-search."""


route_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "{question}"),
    ]
)

question_router = route_prompt | structured_llm_router

Test whether a query is routed to web search or to the vectorstore.

In [None]:
question_router.invoke({"question": "Who is Satya Nadella?"})

In [None]:
question_router.invoke({"question": "What are the main features of AutoGen?"})

### Question-Retrieval Grader

In [None]:
from pydantic import BaseModel, Field
from langchain_core.prompts import ChatPromptTemplate


class GradeDocuments(BaseModel):
    """Binary score for relevance check on retrieved documents."""

    binary_score: str = Field(
        description="Documents are relevant to the question, 'yes' or 'no'"
    )


structured_llm_grader = llm.with_structured_output(GradeDocuments)

# Prompt template with system message and user question
system = """You are a grader assessing relevance of a retrieved document to a user question. \n 
    If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
    It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \n
    Give a binary score, 'yes' or 'no', to indicate whether the document is relevant to the question."""

grade_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "Retrieved document: \n\n {document} \n\n User question: {question}"),
    ]
)

retrieval_grader = grade_prompt | structured_llm_grader

In [None]:
question = "What are the main features of AutoGen?"
docs = pdf_retriever.invoke(question)

In [None]:
retrieved_doc = docs[0].page_content
print(f"[Retrieved Doc sample]\n{retrieved_doc}\n")
print(retrieval_grader.invoke({"question": question, "document": retrieved_doc}))

In [None]:
filtered_docs = []
for doc in docs:
    result = retrieval_grader.invoke(
        {
            "question": question,
            "document": doc.page_content,
        }
    )
    if result.binary_score == "yes":
        filtered_docs.append(doc)
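
The filtering loop above makes one grader call per document. As a rough sketch, the same logic can be wrapped in a function that takes the grader as a parameter, so the filter can be tested without an LLM. `Doc`, `filter_relevant`, and `keyword_grader` are hypothetical names introduced here for illustration; `keyword_grader` is a crude offline substitute and not the LLM grader defined above.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List


@dataclass
class Doc:
    """Minimal stand-in for langchain_core.documents.Document."""

    page_content: str
    metadata: Dict[str, object] = field(default_factory=dict)


def filter_relevant(
    question: str,
    docs: List[Doc],
    grade: Callable[[str, str], str],
) -> List[Doc]:
    """Keep only the documents the grader scores as 'yes'."""
    return [
        d for d in docs
        if grade(question, d.page_content).strip().lower() == "yes"
    ]


def keyword_grader(question: str, document: str) -> str:
    """Hypothetical offline grader: 'yes' if a question keyword (5+ letters) appears."""
    keywords = {
        w.strip("?.,!").lower()
        for w in question.split()
        if len(w.strip("?.,!")) > 4
    }
    text = document.lower()
    return "yes" if any(k in text for k in keywords) else "no"


docs = [
    Doc("AutoGen enables multi-agent conversations."),
    Doc("Bananas are rich in potassium."),
]
kept = filter_relevant("What are the main features of AutoGen?", docs, keyword_grader)
print([d.page_content for d in kept])
```

To use the LLM grader instead, a small adapter that calls `retrieval_grader.invoke(...)` and returns its `binary_score` could be passed as `grade`.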

### Answer Generator

Construct an LLM generation node. This is a naive RAG chain that generates an answer based on the retrieved documents.

For production, we recommend a more advanced RAG chain.

In [None]:
from langchain import hub
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import load_prompt

if has_langchain_key:
    print("Load prompt from LangChain Hub.")
    prompt = hub.pull("daekeun-ml/rag-baseline")
else:
    print("LANGCHAIN_API_KEY is not set. Load prompt from YAML file.")
    prompt = load_prompt("prompts/rag-baseline.yaml")


def format_docs(docs):
    return "\n\n".join(
        [
            f'<document><content>{doc.page_content}</content><source>{doc.metadata["source"]}</source><page>{doc.metadata["page"]+1}</page></document>'
            for doc in docs
        ]
    )


rag_chain = prompt | llm | StrOutputParser()
generation = rag_chain.invoke({"context": format_docs(docs), "question": question})
print(generation)
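
`format_docs` above reads `doc.metadata["page"]`, which the documents built by the web-search node later in this notebook do not carry (they set only `source`). The following is a sketch of a more tolerant variant, assuming that omitting the `<page>` tag is acceptable when no page number exists. `format_docs_safe` and `Doc` are hypothetical names, and `Doc` is only a stand-in for the LangChain `Document` class.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Doc:
    """Minimal stand-in for langchain_core.documents.Document."""

    page_content: str
    metadata: Dict[str, object] = field(default_factory=dict)


def format_docs_safe(docs: List[Doc]) -> str:
    """Format documents as XML-like blocks, skipping metadata that is missing."""
    parts = []
    for doc in docs:
        source = doc.metadata.get("source", "unknown")
        page = doc.metadata.get("page")
        # PDF loaders typically store 0-based page indices; show them 1-based.
        page_tag = f"<page>{page + 1}</page>" if isinstance(page, int) else ""
        parts.append(
            f"<document><content>{doc.page_content}</content>"
            f"<source>{source}</source>{page_tag}</document>"
        )
    return "\n\n".join(parts)


pdf_doc = Doc("From the PDF.", {"source": "AutoGen-paper.pdf", "page": 0})
web_doc = Doc("From the web.", {"source": "https://example.com"})
print(format_docs_safe([pdf_doc, web_doc]))
```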

### Groundedness Evaluator

Construct a `groundedness_grader` node that checks the generated answer for **hallucination** against the retrieved documents.<br>

`yes` means the answer is grounded in (supported by) the retrieved documents, and `no` means it is not.

In [None]:
class Groundedness(BaseModel):
    """Binary score for hallucination present in generation answer."""

    binary_score: str = Field(
        description="Answer is grounded in the facts, 'yes' or 'no'"
    )


structured_llm_grader = llm.with_structured_output(Groundedness)

system = """You are a grader assessing whether an LLM generation is grounded in / supported by a set of retrieved facts. \n 
    Give a binary score 'yes' or 'no'. 'Yes' means that the answer is grounded in / supported by the set of facts."""

groundedness_checking_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        (
            "human",
            "Set of facts: \n\n {documents} \n\n LLM Generated Answer: {generation}",
        ),
    ]
)
groundedness_grader = groundedness_checking_prompt | structured_llm_grader

In [None]:
groundedness_grader.invoke({"documents": docs, "generation": generation})

### Answer Grader

Construct an `answer_grader` node to evaluate whether the generated answer addresses the user question.

In [None]:
class GradeAnswer(BaseModel):
    """Binary scoring to evaluate the appropriateness of answers to questions"""

    binary_score: str = Field(
        description="Indicate 'yes' or 'no' whether the answer solves the question"
    )


structured_llm_grader = llm.with_structured_output(GradeAnswer)

system = """You are a grader assessing whether an answer addresses / resolves a question \n 
    Give a binary score 'yes' or 'no'. 'Yes' means that the answer resolves the question."""
answer_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "User question: \n\n {question} \n\n LLM generation: {generation}"),
    ]
)

answer_grader = answer_prompt | structured_llm_grader

In [None]:
answer_grader.invoke({"question": question, "generation": generation})
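
All three graders above declare `binary_score` as a plain `str`, so the model could in principle return `"Yes"` or `"yes "` rather than exactly `"yes"`. A minimal sketch of a normalizing helper follows; `is_yes` is a hypothetical name and not part of any library used in this notebook.

```python
def is_yes(score: str) -> bool:
    """Return True when a grader score is 'yes', ignoring case, spaces, and quotes."""
    return score.strip().strip("'\"").lower() == "yes"


for raw in ["yes", "Yes", " yes\n", "'yes'", "no", "yes, mostly"]:
    print(repr(raw), is_yes(raw))
```

An alternative, in the same style as `RouteQuery`, would be to type the field as `Literal["yes", "no"]` so the schema itself constrains the value.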

### Question Re-writer

Construct a `question_rewriter` node that rewrites the question into a form better suited to vectorstore retrieval.

In [None]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser

system = """You are a question re-writer that converts an input question to a better version that is optimized \n 
for vectorstore retrieval. Look at the input and try to reason about the underlying semantic intent / meaning."""

re_write_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        (
            "human",
            "Here is the initial question: \n\n {question} \n Formulate an improved question.",
        ),
    ]
)

question_rewriter = re_write_prompt | llm | StrOutputParser()

In [None]:
print(f"[Original question] {question}")
question_rewriter.invoke({"question": question})

### Web Search Tool

The web search tool supplies context for questions that the vectorstore does not cover.

In [None]:
from azure_genai_utils.tools import BingSearch

WEB_SEARCH_FORMAT_OUTPUT = False

web_search_tool = BingSearch(
    max_results=3,
    locale="en-US",
    include_news=True,
    include_entity=False,
    format_output=WEB_SEARCH_FORMAT_OUTPUT,
)

In [None]:
results = web_search_tool.invoke({"query": question})
print(results[0])

<br>

## 🧪 Step 2. Define the Graph
---

### State Definition

- `question`: Question from the user
- `generation`: Generated answer
- `documents`: Retrieved documents

In [None]:
from typing import List
from typing_extensions import TypedDict, Annotated


class GraphState(TypedDict):
    question: Annotated[str, "User question"]
    generation: Annotated[str, "LLM generated answer"]
    documents: Annotated[List[str], "List of documents"]
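
Each node defined below returns only the keys it changes. As I understand LangGraph's default behavior, a key without a reducer is overwritten by the returned value. The sketch below models that merge with plain dictionaries; `apply_update` is a hypothetical helper for illustration and does not call LangGraph.

```python
from typing import Any, Dict, List, TypedDict


class State(TypedDict, total=False):
    """Simplified copy of GraphState; total=False allows partial states."""

    question: str
    generation: str
    documents: List[str]


def apply_update(state: Dict[str, Any], update: Dict[str, Any]) -> Dict[str, Any]:
    """Return a new state where keys in `update` overwrite those in `state`."""
    merged = dict(state)
    merged.update(update)
    return merged


state: State = {"question": "What are AutoGen's main features?"}
after_retrieve = apply_update(state, {"documents": ["doc A", "doc B"]})
after_grading = apply_update(after_retrieve, {"documents": ["doc A"]})
print(after_grading)
```

This is why `grade_documents` can return the filtered list under the same `documents` key: the later value replaces the earlier one rather than being appended to it.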

### Define Nodes

We will define the following nodes in the graph:

- `retrieve`: Retrieve documents based on the user question.
- `generate`: Generate an answer based on the retrieved documents and user question.
- `grade_documents`: Grade documents based on their relevance to the user question.
- `rewrite_query`: Rewrite the user question to improve retrieval performance.
- `web_search`: Search the web for additional information.

In [None]:
from langchain_core.documents import Document


def retrieve(state: GraphState):
    """
    Retrieve documents based on the user question.
    """
    print("\n==== [RETRIEVE] ====\n")
    question = state["question"]

    documents = pdf_retriever.invoke(question)
    return {"documents": documents}


def generate(state: GraphState):
    """Generate an answer based on the retrieved documents and user question."""
    print("\n==== [GENERATE] ====\n")
    question = state["question"]
    documents = state["documents"]

    generation = rag_chain.invoke({"context": documents, "question": question})
    return {"generation": generation}


def grade_documents(state: GraphState):
    """Grade documents based on their relevance to the user question."""
    print("==== [CHECK DOCUMENT RELEVANCE TO QUESTION] ====")
    question = state["question"]
    documents = state["documents"]

    filtered_docs = []
    for d in documents:
        score = retrieval_grader.invoke(
            {"question": question, "document": d.page_content}
        )
        grade = score.binary_score
        if grade == "yes":
            print("---GRADE: DOCUMENT RELEVANT---")
            # Add related documents to filtered_docs
            filtered_docs.append(d)
        else:
            print("---GRADE: DOCUMENT NOT RELEVANT---")
            continue
    return {"documents": filtered_docs}


def rewrite_query(state: GraphState):
    """Rewrite the user question to improve retrieval results."""
    print("\n==== [REWRITE QUERY] ====\n")
    question = state["question"]

    better_question = question_rewriter.invoke({"question": question})
    return {"question": better_question}


def web_search(state: GraphState):
    """Search the web for additional information."""
    print("==== [WEB SEARCH] ====")
    question = state["question"]

    web_results = web_search_tool.invoke({"query": question})
    web_results_docs = [
        Document(
            page_content=web_result["content"],
            metadata={"source": web_result["url"]},
        )
        for web_result in web_results
    ]

    return {"documents": web_results_docs}

### Define Conditional Nodes

- `route_question`: Route the user question to the most relevant datasource such as vectorstore or web search.
- `decide_to_generate`: Decide whether to generate an answer or not.
- `hallucination_check`: Evaluate whether the generated answer is grounded in the retrieved documents.

In [None]:
def route_question(state: GraphState):
    """Route the user question to the most relevant datasource such as vectorstore or web search."""
    print("==== [ROUTE QUESTION] ====")
    question = state["question"]
    source = question_router.invoke({"question": question})
    if source.datasource == "web_search":
        print("==== [ROUTE QUESTION TO WEB SEARCH] ====")
        return "web_search"
    elif source.datasource == "vectorstore":
        print("==== [ROUTE QUESTION TO VECTORSTORE] ====")
        return "vectorstore"


def decide_to_generate(state: GraphState):
    """Return the decision to generate an answer or rewrite the question."""
    print("==== [DECISION TO GENERATE] ====")
    filtered_documents = state["documents"]

    if not filtered_documents:
        print(
            "==== [DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, REWRITE QUERY] ===="
        )
        return "rewrite_query"
    else:
        print("==== [DECISION: GENERATE] ====")
        return "generate"


def hallucination_check(state: GraphState):
    """Evaluate whether the generated answer is grounded in the retrieved documents."""
    print("\n==== [CHECK HALLUCINATIONS] ===")
    question = state["question"]
    documents = state["documents"]
    generation = state["generation"]

    score = groundedness_grader.invoke(
        {"documents": documents, "generation": generation}
    )
    grade = score.binary_score

    if grade == "yes":
        print("==== [DECISION: GENERATION IS GROUNDED IN DOCUMENTS] ====\n")
        print("==== [GRADE GENERATED ANSWER vs QUESTION] ====")
        score = answer_grader.invoke({"question": question, "generation": generation})
        grade = score.binary_score

        if grade == "yes":
            print("==== [DECISION: GENERATED ANSWER ADDRESSES QUESTION] ====")
            return "relevant"
        else:
            print("==== [DECISION: GENERATED ANSWER DOES NOT ADDRESS QUESTION] ====")
            return "not relevant"
    else:
        print("==== [DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY] ====")
        return "hallucination"

### Construct the Graph

In [None]:
from langgraph.graph import END, StateGraph, START
from langgraph.checkpoint.memory import MemorySaver

workflow = StateGraph(GraphState)

# Node definition
workflow.add_node("web_search", web_search)
workflow.add_node("retrieve", retrieve)
workflow.add_node("grade_documents", grade_documents)
workflow.add_node("generate", generate)
workflow.add_node("rewrite_query", rewrite_query)

# Edge connections
workflow.add_conditional_edges(
    START,
    route_question,
    {
        "web_search": "web_search",
        "vectorstore": "retrieve",
    },
)
workflow.add_edge("web_search", "generate")  # Answer generation from web search
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    {"rewrite_query": "rewrite_query", "generate": "generate"},
)

# Re-write the query and then retrieve the documents
workflow.add_edge("rewrite_query", "retrieve")
workflow.add_conditional_edges(
    "generate",
    hallucination_check,
    {
        "hallucination": "generate",  # Re-generate the answer if hallucination is detected
        "relevant": END,  # If the answer is relevant to the question, end the workflow
        "not relevant": "rewrite_query",  # Rewrite the query if the answer is not relevant
    },
)

# Compile the workflow
app = workflow.compile(checkpointer=MemorySaver())
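
To check the wiring without calling an LLM, the edges above can be mirrored in plain Python and walked with scripted decisions. This is a sketch only: `trace` and `max_steps` are hypothetical names, and `max_steps` is only loosely analogous to LangGraph's `recursion_limit`.

```python
from typing import Dict, List

END = "__end__"

# Unconditional edges, copied from the add_edge calls above.
EDGES: Dict[str, str] = {
    "web_search": "generate",
    "retrieve": "grade_documents",
    "rewrite_query": "retrieve",
}


def trace(
    route: str,
    doc_decisions: List[str],
    answer_decisions: List[str],
    max_steps: int = 10,
) -> List[str]:
    """Walk the graph using scripted outcomes for the two conditional edges."""
    doc_iter, answer_iter = iter(doc_decisions), iter(answer_decisions)
    node = {"web_search": "web_search", "vectorstore": "retrieve"}[route]
    path: List[str] = []
    while node != END:
        if len(path) >= max_steps:
            raise RuntimeError(f"step budget of {max_steps} exhausted: {path}")
        path.append(node)
        if node == "grade_documents":
            # Mirrors decide_to_generate.
            node = {"rewrite_query": "rewrite_query", "generate": "generate"}[
                next(doc_iter)
            ]
        elif node == "generate":
            # Mirrors hallucination_check.
            node = {
                "hallucination": "generate",
                "relevant": END,
                "not relevant": "rewrite_query",
            }[next(answer_iter)]
        else:
            node = EDGES[node]
    return path


print(trace("vectorstore", ["rewrite_query", "generate"], ["relevant"]))
```

Note that the `hallucination` branch loops back to `generate` with the same documents, so a run that keeps failing the groundedness check only stops when the step budget runs out.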

### Visualize the graph

In [None]:
from azure_genai_utils.graphs import visualize_langgraph

visualize_langgraph(app, xray=True)

<br>

## 🧪 Step 3. Execute the Graph
---

### Execute the graph

In [None]:
from azure_genai_utils.messages import stream_graph, invoke_graph, random_uuid
from langchain_core.runnables import RunnableConfig

config = RunnableConfig(recursion_limit=10, configurable={"thread_id": random_uuid()})

inputs = {
    "question": "What are AutoGen's main features?",
}

stream_graph(app, inputs, config, ["grade_documents", "rewrite_query", "generate"])

In [None]:
inputs = {
    "question": "Who is Satya Nadella?",
}

stream_graph(app, inputs, config, ["web_search", "rewrite_query", "generate"])