# Self-RAG with Azure AI Evaluation SDK
---

In this notebook, we build a Self-RAG pipeline with LangGraph and use the Azure AI Evaluation SDK to evaluate the groundedness and relevance of the generated answers.


> ✨ ***Note*** <br>
> 1. Please read the reference documentation before you get started: https://learn.microsoft.com/en-us/azure/ai-studio/how-to/develop/evaluate-sdk <br>
> 2. Check the region support for the Azure AI Evaluation SDK: https://learn.microsoft.com/en-us/azure/ai-studio/concepts/evaluation-metrics-built-in?tabs=warning#region-support

### What is Self-RAG?

Self-RAG reflects on the retrieved documents and generated responses, and includes a self-evaluation process to improve the quality of the generated answers.

The original paper describes how Self-RAG generates special tokens, termed "reflection tokens," to decide whether retrieval would enhance the response, allowing retrieval to be integrated on demand.
In practice, however, we can skip reflection tokens and instead let the LLM decide whether each document is relevant.

Corrective RAG (CRAG) is similar to Self-RAG, but Self-RAG focuses on self-reflection and self-evaluation, while CRAG focuses on refining the entire retrieval process, including web search.

- Self-RAG: Trains the LLM to be self-sufficient in managing retrieval and generation processes. By generating reflection tokens, the model controls its behavior during inference, deciding when to retrieve information and how to critique and improve its own responses, leading to more accurate and contextually appropriate outputs. 
- CRAG: Focuses on refining the retrieval process by evaluating and correcting the retrieved documents before they are used in generation. It integrates additional retrievals, such as web searches, when initial retrievals are insufficient, ensuring that the generation is based on the most relevant and accurate information available.

**Reference**

- [Self-RAG paper](https://arxiv.org/abs/2310.11511)  

In [None]:
import os
from dotenv import load_dotenv
from azure_genai_utils.tracer import get_langchain_api_key, set_langsmith

load_dotenv(override=True)

# To trace your RAG API calls with LangSmith, set tracing=True. This requires a valid LangChain API key.
langchain_key, has_langchain_key = get_langchain_api_key()
set_langsmith("[RAG Innv Lab] 1_Agentic-Design-Pattern", tracing=False)

azure_openai_chat_deployment_name = os.getenv("AZURE_OPENAI_CHAT_DEPLOYMENT_NAME")
azure_openai_embedding_deployment_name = os.getenv("AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME")
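If a variable is missing, `os.getenv` silently returns `None`, and the problem only surfaces later as a confusing error inside the retrieval chain or an evaluator. A minimal sketch of a fail-fast check follows; `require_env` is a hypothetical helper (not part of `azure_genai_utils`), and the variable list in the comment simply mirrors the names this notebook reads.

```python
import os


def require_env(names):
    """Return the requested environment variables as a dict, failing fast if any are unset or empty.

    Hypothetical helper for illustration only.
    """
    missing = [name for name in names if not os.getenv(name)]
    if missing:
        raise EnvironmentError(f"Missing environment variables: {', '.join(missing)}")
    return {name: os.environ[name] for name in names}


# Example usage with the variables this notebook reads:
# settings = require_env([
#     "AZURE_OPENAI_ENDPOINT",
#     "AZURE_OPENAI_API_KEY",
#     "AZURE_OPENAI_API_VERSION",
#     "AZURE_OPENAI_CHAT_DEPLOYMENT_NAME",
#     "AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME",
# ])
```

Raising immediately after `load_dotenv` keeps configuration errors separate from model or retrieval errors.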

In [None]:
import os
import pprint
from azure.identity import DefaultAzureCredential
from azure.ai.evaluation import evaluate
from azure.ai.evaluation import (
    ContentSafetyEvaluator,
    RelevanceEvaluator,
    CoherenceEvaluator,
    GroundednessEvaluator,
    FluencyEvaluator,
    SimilarityEvaluator,
    F1ScoreEvaluator,
    RetrievalEvaluator,
)

credential = DefaultAzureCredential()

# Initialize the Azure OpenAI connection with your environment variables

model_config = {
    "azure_endpoint": os.environ.get("AZURE_OPENAI_ENDPOINT"),
    "api_key": os.environ.get("AZURE_OPENAI_API_KEY"),
    "azure_deployment": os.environ.get("AZURE_OPENAI_DEPLOYMENT_NAME"),
    "api_version": os.environ.get("AZURE_OPENAI_API_VERSION"),
    "type": "azure_openai",
}

pprint.pprint(model_config)
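`pprint.pprint(model_config)` prints the API key in plain text into the notebook output, which is easy to leak when the notebook is shared. A minimal sketch of a redaction step follows; `redact` is a hypothetical helper and the example values are placeholders. In the notebook you would call `pprint.pprint(redact(model_config))` instead.

```python
import pprint


def redact(config, secret_keys=("api_key",)):
    """Return a copy of config with non-empty secret values replaced by '***'."""
    return {
        key: ("***" if key in secret_keys and value else value)
        for key, value in config.items()
    }


example_config = {
    "azure_endpoint": "https://example.openai.azure.com/",  # placeholder endpoint
    "api_key": "not-a-real-key",  # placeholder secret
    "type": "azure_openai",
}
pprint.pprint(redact(example_config))
```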

<br>

## 🧪 Step 1. Test and Construct each module
---

Before building the entire graph pipeline, we test and construct each module separately.

- **Retrieval Grader**
- **Answer Generator**
- **Groundedness Evaluator**
- **Relevance Evaluator**
- **Question Re-writer**

### Construct Retrieval Chain based on PDF

In [None]:
from azure_genai_utils.rag.pdf import PDFRetrievalChain

pdf_path = "../../../sample-docs/AutoGen-paper.pdf"

pdf = PDFRetrievalChain(
    source_uri=[pdf_path],
    loader_type="PDFPlumber",
    model_name=azure_openai_chat_deployment_name,
    embedding_name=azure_openai_embedding_deployment_name,
    chunk_size=500,
    chunk_overlap=50,
).create_chain()

pdf_retriever = pdf.retriever
pdf_chain = pdf.chain

question = "What are AutoGen's main features?"
docs = pdf_retriever.invoke(question)

# Non-streaming
# results = pdf_chain.invoke({"chat_history": "", "question": question, "context": docs})

# Streaming
for text in pdf_chain.stream(
    {"chat_history": "", "question": question, "context": docs}
):
    print(text, end="", flush=True)
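`chunk_size=500` and `chunk_overlap=50` control how the PDF text is cut into pieces before embedding. The splitter used inside `PDFRetrievalChain` is not shown in this notebook, so the sketch below is only an illustration, assuming lengths are measured in characters: windows of at most `chunk_size` characters, each starting `chunk_size - chunk_overlap` characters after the previous one, so neighboring chunks share `chunk_overlap` characters. `sliding_window_chunks` is hypothetical; production splitters such as LangChain's `RecursiveCharacterTextSplitter` also prefer to break on separators like paragraphs and sentences.

```python
def sliding_window_chunks(text, chunk_size=500, chunk_overlap=50):
    """Split text into character windows of at most chunk_size, overlapping by chunk_overlap."""
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError("chunk_overlap must be >= 0 and smaller than chunk_size")
    step = chunk_size - chunk_overlap  # each window starts this many characters after the previous one
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= len(text):
            break  # this window already reaches the end of the text
    return chunks


sample = "".join(str(i % 10) for i in range(1200))  # 1,200 characters of dummy text
chunks = sliding_window_chunks(sample, chunk_size=500, chunk_overlap=50)
print([len(c) for c in chunks])  # [500, 500, 300]
```

The overlap means a sentence cut at a chunk boundary still appears whole in at least one neighbor, at the cost of some duplicated text in the index.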

### Define your LLM

This hands-on lab uses only `gpt-4o-mini`, but you can use multiple models in the pipeline.

In [None]:
from langchain_openai import AzureChatOpenAI

llm = AzureChatOpenAI(model=azure_openai_chat_deployment_name, temperature=0)

### Question-Retrieval Grader

Construct a retrieval grader that evaluates the relevance of a retrieved document to the input question. The grader takes the question and a single document as input and returns a binary relevance grade (`'yes'` or `'no'`).<br>
Note that the grader scores **one document per call**; to handle multiple retrieved documents, call it once for each document (as the `grade_documents` node does later).

In [None]:
from pydantic import BaseModel, Field
from langchain_core.prompts import ChatPromptTemplate


class GradeDocuments(BaseModel):
    """Binary score for relevance check on retrieved documents."""

    binary_score: str = Field(
        description="Documents are relevant to the question, 'yes' or 'no'"
    )


# Custom class-based evaluator for grading documents
class GradeDocumentsEvaluator:
    system = """You are a grader assessing relevance of a retrieved document to a user question. \n 
        If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
        It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \n
        Give a binary score, 'yes' or 'no', to indicate whether the document is relevant to the question."""

    def __init__(self, llm_client):
        self.llm_client = llm_client
        # Build the prompt -> structured-output chain once and reuse it for every document
        grade_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", self.system),
                (
                    "human",
                    "Retrieved document: \n\n {document} \n\n User question: {question}",
                ),
            ]
        )
        structured_llm_grader = llm_client.with_structured_output(GradeDocuments)
        self.retrieval_grader = grade_prompt | structured_llm_grader

    def __call__(self, *, question: str, document: str, **kwargs):
        # Returns a GradeDocuments instance whose binary_score is 'yes' or 'no'
        return self.retrieval_grader.invoke({"question": question, "document": document})

document_evaluator = GradeDocumentsEvaluator(llm)
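`GradeDocuments.binary_score` is a plain `str`, and the `grade_documents` node later compares it with `== "yes"` exactly, so a reply such as `"Yes"` would be treated as not relevant. A minimal sketch of a tolerant comparison follows; `is_relevant` is a hypothetical helper. An alternative is to declare the field as `typing.Literal["yes", "no"]` in the Pydantic model so that Pydantic rejects any other value.

```python
def is_relevant(binary_score) -> bool:
    """Treat the grade as relevant only if it is 'yes', ignoring case and surrounding whitespace."""
    return str(binary_score).strip().lower() == "yes"


for raw in ["yes", "Yes", " YES ", "no", "No", ""]:
    print(repr(raw), is_relevant(raw))
```

In `grade_documents`, the check would then read `if is_relevant(score.binary_score):`.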

Test the retrieval grader. For brevity, we show the result for a single document only, not the entire document set.

In [None]:
question = "What are AutoGen's main features?"
docs = pdf_retriever.invoke(question)

retrieved_doc = docs[0].page_content
print(f"[Retrieved Doc sample]\n{retrieved_doc}\n")
print(document_evaluator(question=question, document=retrieved_doc))

### Answer Generator

Construct an LLM generation node. This is a naive RAG chain that generates an answer based on the retrieved documents.

For production, we recommend a more advanced RAG chain.

In [None]:
from langchain import hub
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import load_prompt

if has_langchain_key:
    print("Load prompt from LangChain Hub.")
    prompt = hub.pull("daekeun-ml/rag-baseline")
else:
    print("LANGCHAIN_API_KEY is not set. Load prompt from YAML file.")
    prompt = load_prompt("prompts/rag-baseline.yaml")


def format_docs(docs):
    return "\n\n".join(
        [
            f'<document><content>{doc.page_content}</content><source>{doc.metadata["source"]}</source><page>{doc.metadata["page"]+1}</page></document>'
            for doc in docs
        ]
    )


rag_chain = prompt | llm | StrOutputParser()
generation = rag_chain.invoke({"context": format_docs(docs), "question": question})
print(generation)
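To see exactly what the prompt's `{context}` variable receives, the sketch below applies the same formatting logic as `format_docs` to a hand-made document. `FakeDocument` and `format_docs_sketch` are hypothetical names used only so the example runs without LangChain; `FakeDocument` mimics the `page_content` and `metadata` attributes of a LangChain `Document`. The `+ 1` assumes the loader stores 0-based page indices, as the notebook's own code implies.

```python
from dataclasses import dataclass, field


@dataclass
class FakeDocument:
    """Stand-in exposing the two attributes format_docs reads."""
    page_content: str
    metadata: dict = field(default_factory=dict)


def format_docs_sketch(docs):
    # Wrap each document in tags and show the 0-based page index as a 1-based page number
    return "\n\n".join(
        f"<document><content>{doc.page_content}</content>"
        f"<source>{doc.metadata['source']}</source>"
        f"<page>{doc.metadata['page'] + 1}</page></document>"
        for doc in docs
    )


docs_example = [
    FakeDocument(
        "AutoGen enables multi-agent conversations.",
        {"source": "AutoGen-paper.pdf", "page": 0},
    )
]
print(format_docs_sketch(docs_example))
```

Keeping the source and page next to each chunk lets the model cite where an answer came from.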

### Groundedness Evaluator

We can use the `GroundednessEvaluator` from the Azure AI Evaluation SDK to evaluate whether the answer is grounded in the retrieved context.

In [None]:
def get_groundedness_score(context, response):

    groundedness_eval = GroundednessEvaluator(model_config)
    eval_inputs = dict(
        context=context,
        response=response,
    )

    # Run the Groundedness Evaluator on a (context, response) pair
    groundedness_score = groundedness_eval(**eval_inputs)
    return groundedness_score

In [None]:
display(get_groundedness_score(context=format_docs(docs), response=generation))

### Relevance Evaluator

We can use the `RelevanceEvaluator` from the Azure AI Evaluation SDK to evaluate whether the answer addresses the question.

In [None]:
def get_answer_relevace_score(query, response):

    relevance_eval = RelevanceEvaluator(model_config)
    query_response = dict(
        query=query,
        response=response,
    )
    relevance_score = relevance_eval(**query_response)
    return relevance_score

In [None]:
display(get_answer_relevace_score(query=question, response=generation))

### Question Re-writer

Construct a `question_rewriter` node that rewrites the input question into a version better suited to vector store retrieval.

In [None]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser

system = """You are a question re-writer that converts an input question to a better version that is optimized
for vectorstore retrieval. Look at the input and try to reason about the underlying semantic intent / meaning."""

re_write_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        (
            "human",
            "Here is the initial question: \n\n {question} \n Formulate an improved question.",
        ),
    ]
)

question_rewriter = re_write_prompt | llm | StrOutputParser()

In [None]:
print(f"[Original question] {question}")
question_rewriter.invoke({"question": question})

<br>

## 🧪 Step 2. Define the Graph
---

### State Definition

- `question`: Question from the user
- `generation`: Generated answer
- `documents`: Retrieved documents

In [None]:
from typing import List
from typing_extensions import TypedDict, Annotated
from langchain_core.documents import Document


class GraphState(TypedDict):
    question: Annotated[str, "Question"]
    generation: Annotated[str, "LLM Generation"]
    documents: Annotated[List[Document], "Retrieved Documents"]

### Define Nodes

We will define the following nodes in the graph:

- `retrieve`: Retrieve documents based on the user question.
- `grade_documents`: Grade documents based on their relevance to the user question.
- `generate`: Generate an answer based on the relevant documents and user question.
- `rewrite_query`: Rewrite the user question to improve retrieval performance.


In [None]:
def retrieve(state: GraphState):
    """
    Retrieve documents based on the user question.
    """
    print("==== [RETRIEVE] ====")
    question = state["question"]

    documents = pdf_retriever.invoke(question)
    return {"documents": documents}


def generate(state: GraphState):
    """Generate an answer based on the retrieved documents and user question."""
    print("==== [GENERATE] ====")
    question = state["question"]
    documents = state["documents"]

    # Format the Document objects into the same context string used when testing rag_chain above
    generation = rag_chain.invoke({"context": format_docs(documents), "question": question})
    return {"generation": generation}


def grade_documents(state: GraphState):
    """Grade documents based on their relevance to the user question."""
    print("==== [GRADE DOCUMENTS] ====")
    question = state["question"]
    documents = state["documents"]

    filtered_docs = []
    relevant_doc_count = 0

    for d in documents:
        score = document_evaluator(question=question, document=d.page_content)
        grade = score.binary_score
        if grade == "yes":
            # Add related documents to filtered_docs
            print("==== GRADE: DOCUMENT RELEVANT ====")
            filtered_docs.append(d)
            relevant_doc_count += 1
        else:
            print("==== GRADE: DOCUMENT NOT RELEVANT ====")
            continue
    return {"documents": filtered_docs}


def rewrite_query(state: GraphState):
    """Rewrite the user question to improve retrieval performance."""
    print("\n==== [REWRITE QUERY] ====\n")
    question = state["question"]

    better_question = question_rewriter.invoke({"question": question})
    return {"question": better_question}

### Define Conditional Edges

- `decide_to_generate`: Decide whether to generate an answer or rewrite the query, based on whether any retrieved document was graded relevant.
- `grade_generation_v_documents_and_question`: Check that the generated answer is grounded in the retrieved documents and, if so, that it addresses the user question.

In [None]:
def decide_to_generate(state):
    """
    Assess whether to generate an answer based on the relevance of the retrieved documents to the user question
    """
    print("==== [ASSESS GRADED DOCUMENTS] ====")
    filtered_documents = state["documents"]

    if not filtered_documents:
        # If no retrieved document is relevant to the question, rewrite the query
        print(
            "==== [DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, REWRITE QUERY] ===="
        )
        return "rewrite_query"
    else:
        # If there are relevant documents, generate an answer
        print("==== [DECISION: GENERATE] ====")
        return "generate"


def grade_generation_v_documents_and_question(state):
    """
    Grade the relevance of the generated answer to the user question and retrieved documents.
    """
    print("==== [CHECK HALLUCINATIONS] ====")
    question = state["question"]
    documents = state["documents"]
    generation = state["generation"]

    # The evaluator expects a context string, so format the Document objects first
    groundedness_score = get_groundedness_score(context=format_docs(documents), response=generation)
    grade = groundedness_score["groundedness"]
    print(f"Groundedness score (1-5; higher is better): {grade}\n")

    # Groundedness check
    if grade >= 4:
        print("==== [DECISION: GENERATION IS GROUNDED IN DOCUMENTS] ====")
        print("==== [GRADE GENERATION vs QUESTION] ====")
        relevance_score = get_answer_relevace_score(query=question, response=generation)
        grade = relevance_score["relevance"]
        if grade >= 4:
            print(
                f"==== [DECISION: GENERATED ANSWER ADDRESSES QUESTION, Relevance Score {grade}] ===="
            )
            return "relevant"
        else:
            print("==== [DECISION: GENERATED ANSWER DOES NOT ADDRESS QUESTION] ====")
            return "not relevant"
    else:
        print("==== [DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY] ====")
        return "hallucination"
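The routing function above mixes two evaluator calls with its branching logic. Pulling the branching into a pure function makes it easy to unit-test without calling Azure. The sketch below is a hypothetical refactoring (`route_generation` is not part of the notebook) that keeps the notebook's order of checks and its threshold of 4 on the evaluators' 1-5 scale. Relevance may be left as `None` when the answer is not grounded, so the graph can still skip the relevance call in that case.

```python
def route_generation(groundedness, relevance=None, threshold=4):
    """Map groundedness and relevance scores (1-5) to the graph's routing labels."""
    if groundedness < threshold:
        return "hallucination"  # regenerate; relevance is not consulted
    if relevance is None:
        raise ValueError("relevance score is required once the answer is grounded")
    return "relevant" if relevance >= threshold else "not relevant"


print(route_generation(3))       # hallucination
print(route_generation(5, 4.0))  # relevant
print(route_generation(5, 2))    # not relevant
```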

### Construct the Graph

In [None]:
from langgraph.graph import END, StateGraph, START
from langgraph.checkpoint.memory import MemorySaver

workflow = StateGraph(GraphState)

# Node definition
workflow.add_node("retrieve", retrieve)
workflow.add_node("grade_documents", grade_documents)
workflow.add_node("generate", generate)
workflow.add_node("rewrite_query", rewrite_query)

# Edge connections
workflow.add_edge(START, "retrieve")
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    {
        "rewrite_query": "rewrite_query",
        "generate": "generate",
    },
)
workflow.add_edge("rewrite_query", "retrieve")
workflow.add_conditional_edges(
    "generate",
    grade_generation_v_documents_and_question,
    {
        "hallucination": "generate",
        "relevant": END,
        "not relevant": "rewrite_query",
    },
)

# Compile the workflow
app = workflow.compile(checkpointer=MemorySaver())
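To see why the failure case at the end of this notebook ends in a `GraphRecursionError` rather than an answer, the sketch below walks the same edges with the grader outcomes fixed to booleans instead of LLM calls. `trace_self_rag` is hypothetical, and counting one step per node visit only approximates how LangGraph counts steps against `recursion_limit`.

```python
def trace_self_rag(docs_relevant, grounded, answer_relevant, recursion_limit=10):
    """Return the node sequence visited, or raise RuntimeError when the step budget runs out."""
    path, node = [], "retrieve"
    while node != "END":
        if len(path) >= recursion_limit:
            raise RuntimeError(f"step limit {recursion_limit} reached: {path}")
        path.append(node)
        if node == "retrieve":
            node = "grade_documents"
        elif node == "grade_documents":
            node = "generate" if docs_relevant else "rewrite_query"
        elif node == "rewrite_query":
            node = "retrieve"
        elif node == "generate":
            if not grounded:
                node = "generate"  # the "hallucination" edge loops back to generate
            elif answer_relevant:
                node = "END"
            else:
                node = "rewrite_query"
    return path


print(trace_self_rag(True, True, True))  # ['retrieve', 'grade_documents', 'generate']
```

With `docs_relevant=False`, the trace cycles through `retrieve`, `grade_documents`, and `rewrite_query` until the budget runs out, because no edge leads to `END` without a relevant document.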

### Visualize the graph

In [None]:
from azure_genai_utils.graphs import visualize_langgraph

visualize_langgraph(app, xray=True)

<br>

## 🧪 Step 3. Execute the Graph
---

### Execute the graph

In [None]:
from langchain_core.runnables import RunnableConfig
from langgraph.errors import GraphRecursionError
from azure_genai_utils.messages import stream_graph, invoke_graph, random_uuid

config = RunnableConfig(recursion_limit=10, configurable={"thread_id": random_uuid()})

inputs = {
    "question": "What are AutoGen's main features?",
}

try:
    stream_graph(
        app,
        inputs,
        config,
        ["retrieve", "rewrite_query", "grade_documents", "generate"],
    )
except GraphRecursionError as recursion_error:
    print(f"GraphRecursionError: {recursion_error}")

### Define the Failure Condition

The execution below shows a failure case: for a question that the PDF cannot answer, the graph keeps looping (for example, rewriting the query and retrieving again) without providing a satisfactory response, until it stops at the recursion limit (`recursion_limit=10`) with a `GraphRecursionError`.<br>
To prevent this, you can define a web search node that searches for related questions and provides the user with a list of them.

Corrective-RAG (CRAG) is a similar approach that focuses on refining the entire retrieval process, including web search, to ensure that the generation is based on the most relevant and accurate information available.
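One way to guarantee termination is to cap the number of query rewrites and route to a fallback instead. The sketch below is a hypothetical variant of `decide_to_generate`: it assumes an extra `rewrite_count` key that `rewrite_query` would increment (for example by returning `{"question": better_question, "rewrite_count": state.get("rewrite_count", 0) + 1}`), and a `"fallback"` route, such as the web search node described above, added to the conditional-edge mapping. Neither exists in the notebook's current `GraphState` or graph.

```python
def decide_to_generate_bounded(state, max_rewrites=2):
    """Like decide_to_generate, but gives up after max_rewrites query rewrites.

    Assumes a hypothetical 'rewrite_count' key in the graph state.
    """
    if state["documents"]:
        return "generate"
    if state.get("rewrite_count", 0) >= max_rewrites:
        return "fallback"  # e.g. a web search node or a "no answer found" reply
    return "rewrite_query"


print(decide_to_generate_bounded({"documents": [], "rewrite_count": 2}))  # fallback
```

A similar counter could cap how many times the `"hallucination"` edge regenerates an answer.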

In [None]:
from langchain_core.runnables import RunnableConfig
from langgraph.errors import GraphRecursionError
from azure_genai_utils.messages import stream_graph, invoke_graph, random_uuid

config = RunnableConfig(recursion_limit=10, configurable={"thread_id": random_uuid()})

inputs = {
    "question": "Who is Daekeun?",
}

try:
    stream_graph(
        app,
        inputs,
        config,
        ["retrieve", "rewrite_query", "grade_documents", "generate"],
    )
except GraphRecursionError as recursion_error:
    print(f"GraphRecursionError: {recursion_error}")