2-notebooks/1-chat_completion/3-basic-rag.ipynb (454 lines of code) (raw):

{ "cells": [ { "cell_type": "markdown", "id": "3cf62fbc", "metadata": {}, "source": [ "# 🍏 Basic Retrieval-Augmented Generation (RAG) with AIProjectClient 🍎\n", "\n", "In this notebook, we'll demonstrate a **basic RAG** flow using:\n", "- **`azure-ai-projects`** (AIProjectClient)\n", "- **`azure-ai-inference`** (Embeddings, ChatCompletions)\n", "- **`azure-search-documents`** (Azure AI Search, for vector or hybrid search)\n", "\n", "Our theme is **Health & Fitness** 🍏, so we'll create a small set of health tips, embed them, store them in a search index, run a query that retrieves the most relevant tips, and pass those tips to an LLM to produce a final answer.\n", "\n", "> **Disclaimer**: This is not medical advice. For real health questions, consult a professional.\n", "\n", "## What is RAG?\n", "Retrieval-Augmented Generation (RAG) is a technique in which an LLM (Large Language Model) is given relevant text chunks retrieved from your data and uses them to craft its final answer. This grounds the model's response in real data and helps reduce hallucinations.\n" ] }, { "cell_type": "markdown", "id": "c26cbaa3", "metadata": {}, "source": [ "<img src=\"./seq-diagrams/3-basic-rag.png\" width=\"30%\"/>" ] }, { "cell_type": "markdown", "id": "99dfbd81", "metadata": {}, "source": [ "## 1. 
Setup\n", "We'll import libraries, load environment variables, and create an `AIProjectClient`.\n", "\n", "> #### Complete the [2-embeddings.ipynb](2-embeddings.ipynb) notebook before starting this one\n" ] }, { "cell_type": "code", "execution_count": null, "id": "e7395b2e", "metadata": { "executionInfo": {} }, "outputs": [], "source": [ "import os\n", "import time\n", "import json\n", "from dotenv import load_dotenv\n", "\n", "# azure-ai-projects\n", "from azure.ai.projects import AIProjectClient\n", "from azure.identity import DefaultAzureCredential\n", "\n", "# azure-ai-inference (embeddings and chat completions)\n", "from azure.ai.inference import EmbeddingsClient, ChatCompletionsClient\n", "from azure.ai.inference.models import UserMessage, SystemMessage\n", "\n", "# For vector search or hybrid search\n", "from azure.search.documents import SearchClient\n", "from azure.search.documents.indexes import SearchIndexClient\n", "from azure.core.credentials import AzureKeyCredential\n", "from pathlib import Path\n", "\n", "# Load environment variables from the .env file in the parent directory\n", "notebook_path = Path().absolute()\n", "parent_dir = notebook_path.parent\n", "load_dotenv(parent_dir / '.env')\n", "\n", "conn_string = os.environ.get(\"PROJECT_CONNECTION_STRING\")\n", "chat_model = os.environ.get(\"MODEL_DEPLOYMENT_NAME\", \"gpt-4o-mini\")\n", "embedding_model = os.environ.get(\"EMBEDDING_MODEL_DEPLOYMENT_NAME\", \"text-embedding-3-small\")\n", "search_index_name = os.environ.get(\"SEARCH_INDEX_NAME\", \"healthtips-index\")\n", "\n", "try:\n", "    project_client = AIProjectClient.from_connection_string(\n", "        credential=DefaultAzureCredential(),\n", "        conn_str=conn_string,\n", "    )\n", "    print(\"✅ AIProjectClient created successfully!\")\n", "except Exception as e:\n", "    print(\"❌ Error creating AIProjectClient:\", e)" ] }, { "cell_type": "markdown", "id": "94b84bb8", "metadata": {}, "source": [ "## 2. Create Sample Health Data\n", "We'll create a few short doc chunks. 
In a real scenario, you might read from CSV or PDFs, chunk them up, embed them, and store them in your search index.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "2eab53d6", "metadata": {}, "outputs": [], "source": [ "health_tips = [\n", "    {\n", "        \"id\": \"doc1\",\n", "        \"content\": \"Daily 30-minute walks help maintain a healthy weight and reduce stress.\",\n", "        \"source\": \"General Fitness\"\n", "    },\n", "    {\n", "        \"id\": \"doc2\",\n", "        \"content\": \"Stay hydrated by drinking 8-10 cups of water per day.\",\n", "        \"source\": \"General Fitness\"\n", "    },\n", "    {\n", "        \"id\": \"doc3\",\n", "        \"content\": \"Consistent sleep patterns (7-9 hours) improve muscle recovery.\",\n", "        \"source\": \"General Fitness\"\n", "    },\n", "    {\n", "        \"id\": \"doc4\",\n", "        \"content\": \"For cardio endurance, try interval training like HIIT.\",\n", "        \"source\": \"Workout Advice\"\n", "    },\n", "    {\n", "        \"id\": \"doc5\",\n", "        \"content\": \"Warm up with dynamic stretches before running to reduce injury risk.\",\n", "        \"source\": \"Workout Advice\"\n", "    },\n", "    {\n", "        \"id\": \"doc6\",\n", "        \"content\": \"Balanced diets typically include protein, whole grains, fruits, vegetables, and healthy fats.\",\n", "        \"source\": \"Nutrition\"\n", "    },\n", "]\n", "print(\"Created a small list of health tips.\")" ] }, { "cell_type": "markdown", "id": "d55b8d39", "metadata": {}, "source": [ "## 3.0. 
Create or Reset the Index\n", "When creating a vector field in Azure AI Search, the **field definition** must include a `vector_search_profile_name` property that points to a matching profile name in your vector search settings.\n", "\n", "We'll define a helper function that creates (or resets) a vector index with an HNSW algorithm configuration.\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "aabd84b8", "metadata": {}, "outputs": [], "source": [ "from azure.search.documents.indexes.models import (\n", "    SearchIndex,\n", "    SearchField,\n", "    SearchFieldDataType,\n", "    SimpleField,\n", "    SearchableField,\n", "    VectorSearch,\n", "    HnswAlgorithmConfiguration,\n", "    HnswParameters,\n", "    VectorSearchAlgorithmKind,\n", "    VectorSearchAlgorithmMetric,\n", "    VectorSearchProfile,\n", ")\n", "\n", "def create_healthtips_index(\n", "    endpoint: str, api_key: str, index_name: str,\n", "    dimension: int = 1536  # if using text-embedding-3-small\n", "):\n", "    \"\"\"Create (or reset) a search index for health tips with vector search capability.\"\"\"\n", "\n", "    index_client = SearchIndexClient(endpoint=endpoint, credential=AzureKeyCredential(api_key))\n", "\n", "    # Try to delete existing index\n", "    try:\n", "        index_client.delete_index(index_name)\n", "        print(f\"Deleted existing index: {index_name}\")\n", "    except Exception:\n", "        pass  # Index doesn't exist yet\n", "\n", "    # Define vector search configuration\n", "    vector_search = VectorSearch(\n", "        algorithms=[\n", "            HnswAlgorithmConfiguration(\n", "                name=\"myHnsw\",\n", "                kind=VectorSearchAlgorithmKind.HNSW,\n", "                parameters=HnswParameters(\n", "                    m=4,\n", "                    ef_construction=400,\n", "                    ef_search=500,\n", "                    metric=VectorSearchAlgorithmMetric.COSINE\n", "                )\n", "            )\n", "        ],\n", "        profiles=[\n", "            VectorSearchProfile(\n", "                name=\"myHnswProfile\",\n", "                algorithm_configuration_name=\"myHnsw\"\n", "            )\n", "        ]\n", "    )\n", "\n", "    # Define fields\n", "    fields = [\n", "        SimpleField(name=\"id\", type=SearchFieldDataType.String, 
key=True),\n", "        SearchableField(name=\"content\", type=SearchFieldDataType.String),\n", "        SimpleField(name=\"source\", type=SearchFieldDataType.String),\n", "        SearchField(\n", "            name=\"embedding\",\n", "            type=SearchFieldDataType.Collection(SearchFieldDataType.Single),\n", "            vector_search_dimensions=dimension,\n", "            vector_search_profile_name=\"myHnswProfile\"\n", "        ),\n", "    ]\n", "\n", "    # Create index definition\n", "    index_def = SearchIndex(\n", "        name=index_name,\n", "        fields=fields,\n", "        vector_search=vector_search\n", "    )\n", "\n", "    # Create the index\n", "    index_client.create_index(index_def)\n", "    print(f\"✅ Created or reset index: {index_name}\")" ] }, { "cell_type": "markdown", "id": "21b4d51a", "metadata": {}, "source": [ "## 3.1. Create Index & Upload Health Tips 🏋️\n", "\n", "Now we'll put our health tips into action by:\n", "1. **Creating a search connection** to Azure AI Search\n", "2. **Building our index** with vector search capability\n", "3. **Generating embeddings** for each health tip\n", "4. **Uploading** the tips with their embeddings\n", "\n", "This creates our knowledge base that we'll search through later. Think of it as building our 'fitness library' that our AI assistant can reference! 
📚💪" ] }, { "cell_type": "code", "execution_count": null, "id": "33f2f690", "metadata": {}, "outputs": [], "source": [ "from azure.ai.projects.models import ConnectionType\n", "\n", "# Step 1: Get search connection\n", "search_conn = project_client.connections.get_default(\n", "    connection_type=ConnectionType.AZURE_AI_SEARCH,\n", "    include_credentials=True\n", ")\n", "if not search_conn:\n", "    raise RuntimeError(\"❌ No default Azure AI Search connection found!\")\n", "print(\"✅ Got search connection\")\n", "\n", "# Step 2: Get embeddings client and check embedding length\n", "embeddings_client = project_client.inference.get_embeddings_client()\n", "print(\"✅ Created embeddings client\")\n", "\n", "sample_doc = health_tips[0]\n", "emb_response = embeddings_client.embed(\n", "    model=embedding_model,\n", "    input=[sample_doc[\"content\"]]\n", ")\n", "embedding_length = len(emb_response.data[0].embedding)\n", "print(f\"✅ Got embedding length: {embedding_length}\")\n", "\n", "# Step 3: Create the index\n", "create_healthtips_index(\n", "    endpoint=search_conn.endpoint_url,\n", "    api_key=search_conn.key,\n", "    index_name=search_index_name,\n", "    dimension=embedding_length  # measured from the sample embedding above\n", ")\n", "\n", "# Step 4: Create search client for uploading documents\n", "search_client = SearchClient(\n", "    endpoint=search_conn.endpoint_url,\n", "    index_name=search_index_name,\n", "    credential=AzureKeyCredential(search_conn.key)\n", ")\n", "print(\"✅ Created search client\")\n", "\n", "\n", "# Step 5: Embed and upload documents\n", "search_docs = []\n", "for doc in health_tips:\n", "    # Get embedding for document content\n", "    emb_response = embeddings_client.embed(\n", "        model=embedding_model,\n", "        input=[doc[\"content\"]]\n", "    )\n", "    emb_vec = emb_response.data[0].embedding\n", "\n", "    # Create document with embedding\n", "    search_docs.append({\n", "        \"id\": doc[\"id\"],\n", "        \"content\": doc[\"content\"],\n", "        \"source\": doc[\"source\"],\n", "        \"embedding\": 
emb_vec,\n", "    })\n", "\n", "# Upload documents to index\n", "result = search_client.upload_documents(documents=search_docs)\n", "print(f\"✅ Uploaded {len(search_docs)} documents to search index '{search_index_name}'\")" ] }, { "cell_type": "markdown", "id": "d05e5468", "metadata": {}, "source": [ "## 4. Basic RAG Flow\n", "### 4.1. Retrieve\n", "When a user asks a question, we:\n", "1. Embed the user's question.\n", "2. Search the vector index with that embedding to get the top matching docs.\n", "\n", "### 4.2. Generate answer\n", "We then pass the retrieved docs to the chat model.\n", "\n", "> In a real scenario, you'd have a more advanced approach to chunking & summarizing. We'll keep it simple.\n" ] }, { "cell_type": "code", "execution_count": 6, "id": "c15c3aab", "metadata": {}, "outputs": [], "source": [ "from azure.search.documents.models import VectorizedQuery\n", "\n", "def rag_chat(query: str, top_k: int = 3) -> str:\n", "    # 1) Embed user query\n", "    user_vec = embeddings_client.embed(\n", "        model=embedding_model,\n", "        input=[query]).data[0].embedding\n", "\n", "    # 2) Vector search using VectorizedQuery\n", "    vector_query = VectorizedQuery(\n", "        vector=user_vec,\n", "        k_nearest_neighbors=top_k,\n", "        fields=\"embedding\"\n", "    )\n", "\n", "    results = search_client.search(\n", "        search_text=None,  # Pure vector search; pass text here for hybrid search\n", "        vector_queries=[vector_query],\n", "        select=[\"content\", \"source\"],  # Only retrieve fields we need\n", "        top=top_k  # Return at most top_k results\n", "    )\n", "\n", "    # Gather the top docs\n", "    top_docs_content = []\n", "    for r in results:\n", "        c = r[\"content\"]\n", "        s = r[\"source\"]\n", "        top_docs_content.append(f\"Source: {s} => {c}\")\n", "\n", "    # 3) Chat with retrieved docs\n", "    system_text = (\n", "        \"You are a health & fitness assistant.\\n\"\n", "        \"Answer user questions using ONLY the text from these docs.\\n\"\n", "        \"Docs:\\n\"\n", "        + \"\\n\".join(top_docs_content)\n", "        + \"\\nIf unsure, say 'I'm not sure'.\\n\"\n", "    )\n", "\n", "    with project_client.inference.get_chat_completions_client() 
as chat_client:\n", "        response = chat_client.complete(\n", "            model=chat_model,\n", "            messages=[\n", "                SystemMessage(content=system_text),\n", "                UserMessage(content=query)\n", "            ]\n", "        )\n", "    return response.choices[0].message.content" ] }, { "cell_type": "markdown", "id": "dfecfb3c", "metadata": {}, "source": [ "## 5. Try a Query 🎉\n", "Let's ask a question about cardio for busy people.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "3937fdfc", "metadata": {}, "outputs": [], "source": [ "user_query = \"What's a good short cardio routine for me if I'm busy?\"\n", "answer = rag_chat(user_query)\n", "print(\"🗣️ User Query:\", user_query)\n", "print(\"🤖 RAG Answer:\", answer)" ] }, { "cell_type": "markdown", "id": "a6e562e6", "metadata": {}, "source": [ "## 6. Conclusion\n", "We've demonstrated a **basic RAG** pipeline with:\n", "- **Embedding** docs & storing them in **Azure AI Search**.\n", "- **Retrieving** the top docs for a user question.\n", "- **Chatting** with the model, grounded in the retrieved docs.\n", "\n", "🔎 You can expand this by adding advanced chunking, more robust retrieval, and quality checks. Enjoy your healthy coding! 🍎\n", "\n", "\n", "🚀 Want to optimize this further with a small language model? Check out the next notebook [4-phi-4.ipynb](4-phi-4.ipynb) to see how to use Phi-4 with the same Azure AI Foundry SDKs!" ] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.11" } }, "nbformat": 4, "nbformat_minor": 5 }