gemini/rag-engine/rag_engine_evaluation.ipynb (1,257 lines of code) (raw):
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ur8xi4C7S06n"
},
"outputs": [],
"source": [
"# Copyright 2024 Google LLC\n",
"#\n",
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JAPoU8Sm5E6e"
},
"source": [
"# Advanced RAG Techniques - Vertex RAG Engine Retrieval Quality Evaluation and Hyperparameters Tuning\n",
"\n",
"<table align=\"left\">\n",
" <td style=\"text-align: center\">\n",
" <a href=\"https://colab.research.google.com/github/GoogleCloudPlatform/generative-ai/blob/main/gemini/rag-engine/rag_engine_evaluation.ipynb\">\n",
" <img width=\"32px\" src=\"https://www.gstatic.com/pantheon/images/bigquery/welcome_page/colab-logo.svg\" alt=\"Google Colaboratory logo\"><br> Open in Colab\n",
" </a>\n",
" </td>\n",
" <td style=\"text-align: center\">\n",
" <a href=\"https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fgenerative-ai%2Fmain%2Fgemini%2Frag-engine%2Frag_engine_evaluation.ipynb\">\n",
" <img width=\"32px\" src=\"https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN\" alt=\"Google Cloud Colab Enterprise logo\"><br> Open in Colab Enterprise\n",
" </a>\n",
" </td>\n",
" <td style=\"text-align: center\">\n",
" <a href=\"https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/generative-ai/main/gemini/rag-engine/rag_engine_evaluation.ipynb\">\n",
" <img src=\"https://www.gstatic.com/images/branding/gcpiconscolors/vertexai/v1/32px.svg\" alt=\"Vertex AI logo\"><br> Open in Vertex AI Workbench\n",
" </a>\n",
" </td>\n",
" <td style=\"text-align: center\">\n",
" <a href=\"https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/rag-engine/rag_engine_evaluation.ipynb\">\n",
" <img width=\"32px\" src=\"https://www.svgrepo.com/download/217753/github.svg\" alt=\"GitHub logo\"><br> View on GitHub\n",
" </a>\n",
" </td>\n",
"</table>\n",
"\n",
"<div style=\"clear: both;\"></div>\n",
"\n",
"<b>Share to:</b>\n",
"\n",
"<a href=\"https://www.linkedin.com/sharing/share-offsite/?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/rag-engine/rag_engine_evaluation.ipynb\" target=\"_blank\">\n",
" <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/8/81/LinkedIn_icon.svg\" alt=\"LinkedIn logo\">\n",
"</a>\n",
"\n",
"<a href=\"https://bsky.app/intent/compose?text=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/rag-engine/rag_engine_evaluation.ipynb\" target=\"_blank\">\n",
" <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/7/7a/Bluesky_Logo.svg\" alt=\"Bluesky logo\">\n",
"</a>\n",
"\n",
"<a href=\"https://twitter.com/intent/tweet?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/rag-engine/rag_engine_evaluation.ipynb\" target=\"_blank\">\n",
" <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/5/5a/X_icon_2.svg\" alt=\"X logo\">\n",
"</a>\n",
"\n",
"<a href=\"https://reddit.com/submit?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/rag-engine/rag_engine_evaluation.ipynb\" target=\"_blank\">\n",
" <img width=\"20px\" src=\"https://redditinc.com/hubfs/Reddit%20Inc/Brand/Reddit_Logo.png\" alt=\"Reddit logo\">\n",
"</a>\n",
"\n",
"<a href=\"https://www.facebook.com/sharer/sharer.php?u=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/rag-engine/rag_engine_evaluation.ipynb\" target=\"_blank\">\n",
" <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/5/51/Facebook_f_logo_%282019%29.svg\" alt=\"Facebook logo\">\n",
"</a> "
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "84f0f73a0f76"
},
"source": [
"| | |\n",
"|-----------|---------------------------------------- |\n",
"| Author(s) | [Ed Tsoi](https://github.com/edtsoi430) |"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tvgnzT1CKxrO"
},
"source": [
"## Overview\n",
"\n",
"Retrieval Quality is arguably the most important component of a Retrieval Augmented Generation (RAG) application. Not only does it directly impact the quality of the generated response, in some cases poor retrieval could also lead to irrelevant, incomplete or hallucinated output.\n",
"\n",
"This notebook aims to provide guidelines on:\n",
"> **You'll learn how to:**\n",
"> * Evaluate retrieval quality using the [BEIR-fiqa 2018 dataset](https://arxiv.org/abs/2104.08663) (or your own!).\n",
"> * Understand the impact of key parameters on retrieval performance. (e.g. embedding model, chunk size)\n",
"> * Tune hyperparameters to improve accuracy of the RAG system.\n",
"\n",
"**Note:** This notebook assumes that you already have an understanding on how to implement a RAG system with Vertex AI RAG Engine. For more general instructions on how to use Vertex AI RAG Engine, please refer to the [RAG Engine API Documentation](https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/rag-api).\n",
"\n",
"We'll explore how these hyperparameters influence retrieval:\n",
"\n",
"| Parameter | Description |\n",
"|---------------------------|-------------------------------------------------------------------------------------|\n",
"| Chunk Size | Size of each chunk (in tokens). Affects granularity of retrieval. |\n",
"| Chunk Overlap | Overlap between chunks. Helps capture relevant information across chunk boundaries. |\n",
"| Top K | Maximum number of retrieved contexts. Balance recall and precision. |\n",
"| Vector Distance threshold | Filters contexts based on similarity. A stricter threshold prioritizes precision. |\n",
"| Embedding model | Model used to convert text to embeddings. Significantly impacts retrieval accuracy. |\n",
"\n",
"### How exactly could we use this notebook to improve the RAG system?\n",
"\n",
"* **Hyperparameters Tuning:** There are a couple of hyperparameters that could impact retrieval quality:\n",
"\n",
"| Parameter | Description |\n",
"|------------|----------------------|\n",
"| Chunk Size | When documents are ingested into an index, they are split into chunks. The `chunk_size` parameter (in tokens) specifies the size of each chunk. |\n",
"| Chunk Overlap | By default, documents are split into chunks with a certain amount of overlap to improve relevance and retrieval quality. |\n",
"| Top K | Controls the maximum number of contexts that are retrieved. |\n",
"| Vector Distance threshold | Only contexts with a distance smaller than the threshold are considered. |\n",
"| Embedding model | The embedding model used to convert input text into embeddings for retrieval.|\n",
"\n",
"You may use this notebook to evaluate your retrieval quality, and see how changing certain parameters (top k, chunk size) impact or improve your retrieval quality (`recall@k`, `precision@k`, `ndcg@k`).\n",
"\n",
"* **Response Quality Evaluation:** Once you have optimized the retrieval metrics, you can understand how it impacts response quality using the [Evaluation Service API Notebook](https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/evaluation/evaluate_rag_gen_ai_evaluation_service_sdk.ipynb)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "61RBz8LLbxCR"
},
"source": [
"## Get started"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "No17Cw5hgx12"
},
"source": [
"### Install Vertex AI SDK and other required packages\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "tFy3H3aPgx12"
},
"outputs": [],
"source": [
"%pip install --upgrade --user --quiet google-cloud-aiplatform beir"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "R5Xep4W9lq-Z"
},
"source": [
"### Restart runtime\n",
"\n",
"To use the newly installed packages in this Jupyter runtime, you must restart the runtime. You can do this by running the cell below, which restarts the current kernel.\n",
"\n",
"The restart might take a minute or longer. After it's restarted, continue to the next step."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "XRvKdaPDTznN"
},
"outputs": [
{
"data": {
"text/plain": [
"{'status': 'ok', 'restart': True}"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import IPython\n",
"\n",
"app = IPython.Application.instance()\n",
"app.kernel.do_shutdown(True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SbmM4z7FOBpM"
},
"source": [
"<div class=\"alert alert-block alert-warning\">\n",
"<b>⚠️ The kernel is going to restart. Wait until it's finished before continuing to the next step. ⚠️</b>\n",
"</div>\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dmWOrTJ3gx13"
},
"source": [
"### Authenticate your notebook environment (Colab only)\n",
"\n",
"If you're running this notebook on Google Colab, run the cell below to authenticate your environment."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "NyKGtVQjgx13"
},
"outputs": [],
"source": [
"import sys\n",
"\n",
"if \"google.colab\" in sys.modules:\n",
" from google.colab import auth\n",
"\n",
" auth.authenticate_user()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DF4l8DTdWgPY"
},
"source": [
"### Set Google Cloud project information and initialize the Vertex AI SDK\n",
"\n",
"To get started using Vertex AI, you must have an existing Google Cloud project and [enable the Vertex AI API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com).\n",
"\n",
"Learn more about [setting up a project and a development environment](https://cloud.google.com/vertex-ai/docs/start/cloud-environment)."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "Nqwi-5ufWp_B"
},
"outputs": [],
"source": [
"# Use the environment variable if the user doesn't provide Project ID.\n",
"import os\n",
"\n",
"import vertexai\n",
"\n",
"PROJECT_ID = \"[your-project-id]\" # @param {type: \"string\", placeholder: \"[your-project-id]\", isTemplate: true}\n",
"\n",
"if not PROJECT_ID or PROJECT_ID == \"[your-project-id]\":\n",
" PROJECT_ID = str(os.environ.get(\"GOOGLE_CLOUD_PROJECT\"))\n",
"\n",
"LOCATION = os.environ.get(\"GOOGLE_CLOUD_REGION\", \"us-central1\")\n",
"\n",
"vertexai.init(project=PROJECT_ID, location=LOCATION)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "b9caed7b620d"
},
"outputs": [],
"source": [
"!gcloud auth application-default login\n",
"!gcloud auth application-default set-quota-project {PROJECT_ID}\n",
"!gcloud config set project {PROJECT_ID}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "09720c707f1c"
},
"source": [
"### Import libraries"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {
"id": "e614bd6fe56b"
},
"outputs": [],
"source": [
"from collections.abc import MutableSequence\n",
"import math\n",
"import os\n",
"import re\n",
"import time\n",
"\n",
"from beir import util\n",
"from beir.datasets.data_loader import GenericDataLoader\n",
"from google.cloud import storage\n",
"from google.cloud.aiplatform_v1beta1.types import Context, RetrieveContextsResponse\n",
"import numpy as np\n",
"from tqdm import tqdm\n",
"from vertexai.preview import rag"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "j0sHOGwdTDXZ"
},
"source": [
"### Define helper function for processing dataset."
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {
"id": "_k4EWn3TTWVh"
},
"outputs": [],
"source": [
"def convert_beir_to_rag_corpus(\n",
" corpus: dict[str, dict[str, str]], output_dir: str\n",
") -> None:\n",
" \"\"\"\n",
" Convert a BEIR corpus to Vertex RAG corpus format with a maximum of 10,000\n",
" files per subdirectory.\n",
"\n",
" For each document in the BEIR corpus, we will create a new txt where:\n",
" * doc_id will be the file name\n",
" * doc_content will be the document text prepended by title (if any).\n",
"\n",
" Args:\n",
" corpus: BEIR corpus\n",
" output_dir (str): Directory where the converted corpus will be saved\n",
"\n",
" Returns:\n",
" None (will write output to disk)\n",
" \"\"\"\n",
" # Create the output directory if it doesn't exist\n",
" os.makedirs(output_dir, exist_ok=True)\n",
"\n",
" file_count, subdir_count = 0, 0\n",
" current_subdir = os.path.join(output_dir, f\"{subdir_count}\")\n",
" os.makedirs(current_subdir, exist_ok=True)\n",
"\n",
" # Convert each file in the corpus\n",
" for doc_id, doc_content in corpus.items():\n",
" # Combine title and text (if title exists)\n",
" full_text = doc_content.get(\"title\", \"\")\n",
" if full_text:\n",
" full_text += \"\\n\\n\"\n",
" full_text += doc_content[\"text\"]\n",
"\n",
" # Create a new file for each file.\n",
" file_path = os.path.join(current_subdir, f\"{doc_id}.txt\")\n",
" with open(file_path, \"w\", encoding=\"utf-8\") as file:\n",
" file.write(full_text)\n",
"\n",
" file_count += 1\n",
"\n",
" # Create a new subdirectory if the current one has reached the limit\n",
" if file_count >= 10000:\n",
" subdir_count += 1\n",
" current_subdir = os.path.join(output_dir, f\"{subdir_count}\")\n",
" os.makedirs(current_subdir, exist_ok=True)\n",
" file_count = 0\n",
"\n",
" print(f\"Conversion complete. {len(corpus)} files saved in {output_dir}\")\n",
"\n",
"\n",
"def count_files_in_gcs_bucket(gcs_path: str) -> int:\n",
" \"\"\"\n",
" Counts the number of files in a Google Cloud Storage path,\n",
" excluding directories and hidden files.\n",
"\n",
" Args:\n",
" gcs_path: The full GCS path, including the bucket name and any prefix.\n",
" * Example: 'gs://my-bucket/my-folder'\n",
"\n",
" Returns:\n",
" The number of files in the GCS path.\n",
" \"\"\"\n",
"\n",
" # Split the GCS path into bucket name and prefix\n",
" bucket_name, prefix = gcs_path.replace(\"gs://\", \"\").split(\"/\", 1)\n",
"\n",
" storage_client = storage.Client()\n",
" bucket = storage_client.bucket(bucket_name)\n",
"\n",
" count = 0\n",
" blobs = bucket.list_blobs(prefix=prefix)\n",
" for blob in blobs:\n",
" if not blob.name.endswith(\"/\") and not any(\n",
" part.startswith(\".\") for part in blob.name.split(\"/\")\n",
" ): # Exclude directories and hidden files\n",
" count += 1\n",
"\n",
" return count\n",
"\n",
"\n",
"def count_directories_after_split(gcs_path: str) -> int:\n",
" \"\"\"\n",
" Counts the number of directories in a Google Cloud Storage path.\n",
"\n",
" Args:\n",
" gcs_path: The full GCS path, including the bucket name and any prefix.\n",
"\n",
" Returns:\n",
" The number of directories in the GCS path.\n",
" \"\"\"\n",
" num_files_in_path = count_files_in_gcs_bucket(gcs_path)\n",
" num_directories = math.ceil(num_files_in_path / 10000)\n",
" return num_directories\n",
"\n",
"\n",
"def import_rag_files_from_gcs(\n",
" paths: list[str], chunk_size: int, chunk_overlap: int, corpus_name: str\n",
") -> None:\n",
" \"\"\"Imports files from Google Cloud Storage to a RAG corpus.\n",
"\n",
" Args:\n",
" paths: A list of GCS paths to import files from.\n",
" chunk_size: The size of each chunk to import.\n",
" chunk_overlap: The overlap between consecutive chunks.\n",
" corpus_name: The name of the RAG corpus to import files into.\n",
"\n",
" Returns:\n",
" None\n",
" \"\"\"\n",
" total_imported, total_num_of_files = 0, 0\n",
"\n",
" for path in paths:\n",
" num_files_to_be_imported = count_files_in_gcs_bucket(path)\n",
" total_num_of_files += num_files_to_be_imported\n",
" max_retries, attempt, imported = 10, 0, 0\n",
" while attempt < max_retries and imported < num_files_to_be_imported:\n",
" response = rag.import_files(\n",
" corpus_name,\n",
" [path],\n",
" chunk_size=chunk_size,\n",
" chunk_overlap=chunk_overlap,\n",
" timeout=20000,\n",
" max_embedding_requests_per_min=1400,\n",
" )\n",
" imported += response.imported_rag_files_count or 0\n",
" attempt += 1\n",
" total_imported += imported\n",
"\n",
" print(f\"{total_imported} files out of {total_num_of_files} imported!\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "m9A3qWPmxd2Y"
},
"source": [
"# For step 1, please choose only one of the following options:\n",
"- **1.1 (Option A, Recommended):** Create RagCorpus and perform data ingestion using the provided public GCS bucket (BEIR-fiqa dataset only).\n",
"\n",
"- **1.2 (Option B):** Create RAG Corpus, choose a custom beir dataset and upload/ingest data into the RagCorpus on your own.\n",
"\n",
"- **1.3 (Option C):** Bring your own existing `RagCorpus` (insert `RAG_CORPUS_ID` here).\n",
"\n",
"**Do not run all these cells together.**"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8aTLg2-scmJ7"
},
"source": [
"# 1.1 - Option A (Recommended): Create RagCorpus and perform data ingestion using the provided public GCS bucket (BEIR-fiqa dataset only).\n",
"* This option is recommended to save you time from having to upload evaluation dataset to GCS before we import them into the `RagCorpus`.\n",
"* However, if you would like more flexibility on which BEIR dataset to use, you could go with option B below to upload data to your desired GCS location.\n",
"* If you would like to bring your own rag corpus, simply skip to Option C and specify the rag corpus id."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pKvNjMekrHEU"
},
"source": [
"### Create a `RagCorpus` with the specified configuration (for evaluation)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1asbN-xmrGQ_"
},
"outputs": [],
"source": [
"# See the list of current supported embedding models here: https://cloud.google.com/vertex-ai/generative-ai/docs/rag-overview#supported-embedding-models\n",
"# Select embedding model as desired.\n",
"embedding_model_config = rag.EmbeddingModelConfig(\n",
" publisher_model=\"publishers/google/models/text-embedding-005\" # @param {type:\"string\", isTemplate: true},\n",
")\n",
"\n",
"rag_corpus = rag.create_corpus(\n",
" display_name=\"test-corpus\",\n",
" description=\"A test corpus where we import the BEIR-FiQA-2018 dataset\",\n",
" embedding_model_config=embedding_model_config,\n",
")\n",
"\n",
"print(rag_corpus)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "piSi7bk7Q80-"
},
"source": [
"### Copy beir-fiqa dataset from the public path to a storage bucket in your project."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "IXVvmuR5Q7Wg"
},
"outputs": [],
"source": [
"CURRENT_BUCKET_PATH = \"gs://<INSERT_GCS_PATH_HERE>\" # @param {type:\"string\"},\n",
"\n",
"PUBLIC_BEIR_FIQA_GCS_PATH = (\n",
" \"gs://github-repo/generative-ai/gemini/rag-engine/rag_engine_evaluation/beir-fiqa\"\n",
")\n",
"\n",
"!gsutil -m rsync -r -d $PUBLIC_BEIR_FIQA_GCS_PATH $CURRENT_BUCKET_PATH"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5U_MReVRkWr9"
},
"source": [
"### Import evaluation dataset files into `RagCorpus` (configure chunk size, chunk overlap etc as desired)"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {
"id": "AjkOIl_TkVeo"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"57638 files out of 57638 imported!\n"
]
}
],
"source": [
"num_subdirectories = count_directories_after_split(CURRENT_BUCKET_PATH)\n",
"paths = [CURRENT_BUCKET_PATH + f\"/{i}/\" for i in range(num_subdirectories)]\n",
"\n",
"chunk_size = 512 # @param {type:\"integer\"}\n",
"chunk_overlap = 102 # @param {type:\"integer\"}\n",
"\n",
"import_rag_files_from_gcs(\n",
" paths=paths,\n",
" chunk_size=chunk_size,\n",
" chunk_overlap=chunk_overlap,\n",
" corpus_name=rag_corpus.name,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EdvJRUWRNGHE"
},
"source": [
"# 1.2 - Option B: Create RAG Corpus, choose a custom beir dataset and upload/ingest data into the RagCorpus on your own.\n",
"\n",
"* Choose this option if you would like to have more flexibility on which dataset to use. The public, uploaded data in option 1.1 is for `BEIR-fiqa` only.\n",
"* If you would like to bring your own existing `RagCorpus` (with imported files), skip to Option C below."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5303c05f7aa6"
},
"source": [
"### Create a `RagCorpus` with the specified configuration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6fc324893334"
},
"outputs": [],
"source": [
"# See the list of current supported embedding models here: https://cloud.google.com/vertex-ai/generative-ai/docs/rag-overview#supported-embedding-models\n",
"# You may adjust the embedding model here if you would like.\n",
"embedding_model_config = rag.EmbeddingModelConfig(\n",
" publisher_model=\"publishers/google/models/text-embedding-005\" # @param {type:\"string\", isTemplate: true},\n",
")\n",
"\n",
"rag_corpus = rag.create_corpus(\n",
" display_name=\"test-corpus\",\n",
" description=\"A test corpus where we import the BEIR-FiQA-2018 dataset\",\n",
" embedding_model_config=embedding_model_config,\n",
")\n",
"\n",
"print(rag_corpus)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "e43229f3ad4f"
},
"source": [
"### Load BEIR Fiqa dataset (test split).\n",
"- Configure dataset of choice."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cf93d5f0ce00"
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f4be16c587c04421a7bedbd60803e546",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"datasets/fiqa.zip: 0%| | 0.00/17.1M [00:00<?, ?iB/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e23c337944144b06b47eb44e5b20899a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/57638 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Successfully loaded the fiqa dataset with 57638 files and 648 queries!\n"
]
}
],
"source": [
"# Download and load a BEIR dataset\n",
"dataset = \"fiqa\" # @param [\"arguana\", \"climate-fever\", \"cqadupstack\", \"dbpedia-entity\", \"fever\", \"fiqa\", \"germanquad\", \"hotpotqa\", \"mmarco\", \"mrtydi\", \"msmarco-v2\", \"msmarco\", \"nfcorpus\", \"nq-train\", \"nq\", \"quora\", \"scidocs\", \"scifact\", \"trec-covid-beir\", \"trec-covid-v2\", \"trec-covid\", \"vihealthqa\", \"webis-touche2020\"]\n",
"url = (\n",
" f\"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset}.zip\"\n",
")\n",
"out_dir = \"datasets\"\n",
"data_path = util.download_and_unzip(url, out_dir)\n",
"\n",
"# Load the dataset\n",
"corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split=\"test\")\n",
"print(\n",
" f\"Successfully loaded the {dataset} dataset with {len(corpus)} files and {len(queries)} queries!\"\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Me5Ee8yCTDpo"
},
"source": [
"### Convert BEIR corpus to `RagCorpus` format and upload to GCS bucket."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "D0j_ceMjXddK"
},
"outputs": [],
"source": [
"CONVERTED_DATASET_PATH = f\"/converted_dataset_{dataset}\"\n",
"# Convert BEIR corpus to RAG format.\n",
"convert_beir_to_rag_corpus(corpus, CONVERTED_DATASET_PATH)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cQd2c173rPeP"
},
"source": [
"#### Create a test bucket for uploading BEIR evaluation dataset to (or use an existing bucket of your choice)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0mG26u-QrOO2"
},
"outputs": [],
"source": [
"# Optionally rename bucket name here.\n",
"BUCKET_NAME = \"beir-test-bucket\" # @param {type: \"string\"}\n",
"!gsutil mb gs://{BUCKET_NAME}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wIKdALWnp-s5"
},
"source": [
"#### Upload to specified GCS bucket (Modify the GCS bucket path to desired location)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Lk3AGhX1bJas"
},
"outputs": [],
"source": [
"GCS_BUCKET_PATH = \"gs://{BUCKET_NAME}/beir-fiqa\" # @param {type: \"string\"}\n",
"\n",
"!echo \"Uploading files from ${CONVERTED_DATASET_PATH} to ${GCS_BUCKET_PATH}\"\n",
"# Upload RAG format dataset to GCS bucket.\n",
"!gsutil -m rsync -r -d $CONVERTED_DATASET_PATH $GCS_BUCKET_PATH"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KTCo-81oTDzx"
},
"source": [
"### Import evaluation dataset files into `RagCorpus`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "v0f2i3PBwQcP"
},
"outputs": [],
"source": [
"num_subdirectories = count_directories_after_split(GCS_BUCKET_PATH)\n",
"paths = [GCS_BUCKET_PATH + f\"/{i}/\" for i in range(num_subdirectories)]\n",
"\n",
"chunk_size = 512 # @param {type:\"integer\"}\n",
"chunk_overlap = 102 # @param {type:\"integer\"}\n",
"\n",
"import_rag_files_from_gcs(\n",
" paths=paths,\n",
" chunk_size=chunk_size,\n",
" chunk_overlap=chunk_overlap,\n",
" corpus_name=rag_corpus.name,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2EioRjxUN2aw"
},
"source": [
"# 1.3 - Option C: Bring your own existing `RagCorpus` (insert `RAG_CORPUS_ID` here)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "qDCv3clrN616"
},
"outputs": [],
"source": [
"# Specify your rag corpus ID here that you want to use.\n",
"RAG_CORPUS_ID = \"\" # @param {type: \"string\"}\n",
"\n",
"rag_corpus = rag.get_corpus(\n",
" name=f\"projects/{PROJECT_ID}/locations/{LOCATION}/ragCorpora/{RAG_CORPUS_ID}\"\n",
")\n",
"\n",
"print(rag_corpus)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DXsmEoZNOPLo"
},
"source": [
"# 2. Run Retrieval Quality Evaluation\n",
"\n",
"For Retrieval Quality Evaluation, we focus on the following metrics:\n",
"\n",
"- **Recall@k:**\n",
" - Measures how many of the relevant documents/chunks are successfully retrieved within the top k results\n",
" - Helps evaluate the retrieval component's ability to find ALL relevant information\n",
"- **Precision@k:**\n",
" - Measures the proportion of retrieved documents that are actually relevant within the top k results\n",
" - Helps evaluate how \"focused\" your retrieval is\n",
"- **nDCG@K:**\n",
" - Measures both relevance AND ranking quality\n",
" - Takes into account the position of relevant documents\n",
"\n",
"Follow the Notebook to get these metrics numbers for your configurations, and to optimize your settings."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "m5sz02OdOPYk"
},
"source": [
"### Define evaluation helper function."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "VUHxSZ7hwZ6k"
},
"outputs": [],
"source": [
"def extract_doc_id(file_path: str) -> str | None:\n",
" \"\"\"Extracts the document ID (filename without extension) from a file path.\n",
"\n",
" Handles various potential file name formats and extensions\n",
" like .txt, .pdf, .html, etc.\n",
"\n",
" Args:\n",
" file_path: The path to the file.\n",
"\n",
" Returns:\n",
" The document ID (filename without extension) extracted from the file path.\n",
" \"\"\"\n",
" try:\n",
" # Split the path by directory separators\n",
" parts = file_path.split(\"/\")\n",
" # Get the filename\n",
" filename = parts[-1]\n",
" # Remove the extension (if any)\n",
" filename = re.sub(r\"\\.\\w+$\", \"\", filename) # Removes .txt, .pdf, .html, etc.\n",
" return filename\n",
" except:\n",
" pass # Handle any unexpected errors during extraction\n",
" return None\n",
"\n",
"\n",
"# RAG Engine helper function to extract doc_id, snippet, and score.\n",
"\n",
"\n",
"def extract_retrieval_details(\n",
" response: RetrieveContextsResponse,\n",
") -> tuple[str, str, float]:\n",
" \"\"\"Extracts the document ID, snippet, and score from a retrieval response.\n",
"\n",
" Args:\n",
" response: The retrieval response object.\n",
"\n",
" Returns:\n",
" A tuple containing the document ID, retrieved snippet, and distance score.\n",
" \"\"\"\n",
" doc_id = extract_doc_id(response.source_uri)\n",
" retrieved_snippet = response.text\n",
" distance = response.distance\n",
" return (doc_id, retrieved_snippet, distance)\n",
"\n",
"\n",
"# RAG Engine helper function for retrieval.\n",
"\n",
"\n",
"def rag_api_retrieve(\n",
" query: str, corpus_name: str, top_k: int\n",
") -> MutableSequence[Context]:\n",
" \"\"\"Retrieves relevant contexts from a RAG corpus using the RAG API.\n",
"\n",
" Args:\n",
" query: The query text.\n",
" corpus_name: The name of the RAG corpus, in the format of \"projects/{PROJECT_ID}/locations/{LOCATION}/ragCorpora/{CORPUS_ID}\".\n",
" top_k: The number of top results to retrieve.\n",
"\n",
" Returns:\n",
" A list of retrieved contexts.\n",
" \"\"\"\n",
" return rag.retrieval_query(\n",
" rag_resources=[rag.RagResource(rag_corpus=corpus_name)],\n",
" text=query,\n",
" similarity_top_k=top_k,\n",
" vector_distance_threshold=0.5,\n",
" ).contexts.contexts\n",
"\n",
"\n",
"def calculate_document_level_recall_precision(\n",
" retrieved_response: MutableSequence[Context], cur_qrel: dict[str, int]\n",
") -> tuple[float, float]:\n",
" \"\"\"Calculates the recall and precision for a list of retrieved contexts.\n",
"\n",
" Args:\n",
" retrieved_response: A list of retrieved contexts.\n",
" cur_qrel: A dictionary of ground truth relevant documents for the current query.\n",
"\n",
" Returns:\n",
" A tuple containing the recall and precision scores.\n",
" \"\"\"\n",
" if not retrieved_response:\n",
" return (0, 0)\n",
"\n",
" relevant_retrieved_unique = set()\n",
" num_relevant_retrieved_snippet = 0\n",
" for res in retrieved_response:\n",
" doc_id, text, score = extract_retrieval_details(res)\n",
" if doc_id in cur_qrel:\n",
" relevant_retrieved_unique.add(doc_id)\n",
" num_relevant_retrieved_snippet += 1\n",
" recall = (\n",
" len(relevant_retrieved_unique) / len(cur_qrel.keys())\n",
" if len(cur_qrel.keys()) > 0\n",
" else 0\n",
" )\n",
" precision = (\n",
" num_relevant_retrieved_snippet / len(retrieved_response)\n",
" if len(retrieved_response) > 0\n",
" else 0\n",
" )\n",
" return (recall, precision)\n",
"\n",
"\n",
"def calculate_document_level_metrics(\n",
" queries: dict[str, str],\n",
" qrels: dict[str, dict[str, int]],\n",
" k_values: list[int],\n",
" corpus_name: str,\n",
") -> None:\n",
" \"\"\"Calculates and prints the average recall, precision, and NDCG for a set of queries at different top_k values.\n",
"\n",
" Args:\n",
" queries: A dictionary of queries with query IDs as keys and query text as values.\n",
" qrels: A dictionary of ground truth relevant documents for each query.\n",
" k_values: A list of top_k values to evaluate.\n",
" corpus_name: The name of the RAG corpus, in the format of \"projects/{PROJECT_ID}/locations/{LOCATION}/ragCorpora/{CORPUS_ID}\".\n",
"\n",
" Returns:\n",
" None\n",
" \"\"\"\n",
"\n",
" for top_k in k_values:\n",
" start_time = time.time()\n",
" total_recall, total_precision, total_ndcg = 0, 0, 0\n",
" print(f\"Computing metrics for top_k value: {top_k}\")\n",
" print(f\"Total number of queries: {len(queries)}\")\n",
" for query_id, query in tqdm(\n",
" queries.items(),\n",
" total=len(queries),\n",
" desc=f\"Processing Queries (top_k={top_k})\",\n",
" ):\n",
" response = rag_api_retrieve(query, corpus_name, top_k)\n",
"\n",
" recall, precision = calculate_document_level_recall_precision(\n",
" response, qrels[query_id]\n",
" )\n",
" ndcg = ndcg_at_k(response, qrels[query_id], top_k)\n",
"\n",
" total_recall += recall\n",
" total_precision += precision\n",
" total_ndcg += ndcg\n",
"\n",
" end_time = time.time()\n",
" execution_time = end_time - start_time\n",
" num_queries = len(queries)\n",
" average_recall, average_precision, average_ndcg = (\n",
" total_recall / num_queries,\n",
" total_precision / num_queries,\n",
" total_ndcg / num_queries,\n",
" )\n",
" print(f\"\\nAverage Recall@{top_k}: {average_recall:.4f}\")\n",
" print(f\"Average Precision@{top_k}: {average_precision:.4f}\")\n",
" print(f\"Average nDCG@{top_k}: {average_ndcg:.4f}\")\n",
" print(f\"Execution time: {execution_time} seconds.\")\n",
" print(\"=============================================\")\n",
"\n",
"\n",
"def dcg_at_k_with_zero_padding_if_needed(r: list[int], k: int) -> float:\n",
" \"\"\"Calculates the Discounted Cumulative Gain (DCG) at a given rank k.\n",
"\n",
" Args:\n",
" r: A list of relevance scores.\n",
" k: The rank at which to calculate DCG.\n",
"\n",
" Returns:\n",
" The DCG at rank k.\n",
" \"\"\"\n",
" r = np.asarray(r)[:k]\n",
" if r.size:\n",
" # Pad with zeros if r is shorter than k\n",
" if r.size < k:\n",
" r = np.pad(r, (0, k - r.size))\n",
" return np.sum(np.subtract(np.power(2, r), 1) / np.log2(np.arange(2, k + 2)))\n",
" return 0.0\n",
"\n",
"\n",
"def ndcg_at_k(\n",
" retriever_results: MutableSequence[Context],\n",
" ground_truth_relevances: dict[str, int],\n",
" k: int,\n",
") -> float:\n",
" \"\"\"Calculates the Normalized Discounted Cumulative Gain (NDCG) at a given rank k.\n",
"\n",
" Args:\n",
" retriever_results: A list of retrieved results.\n",
" ground_truth_relevances: A dictionary of ground truth relevance scores for each document.\n",
" k: The rank at which to calculate NDCG.\n",
"\n",
" Returns:\n",
" The NDCG at rank k.\n",
" \"\"\"\n",
" if not retriever_results:\n",
" return 0\n",
"\n",
" # Prepare retriever results\n",
" retrieved_relevances = []\n",
" for res in retriever_results[:k]:\n",
" doc_id, text, score = extract_retrieval_details(res)\n",
" if doc_id in ground_truth_relevances:\n",
" retrieved_relevances.append(ground_truth_relevances[doc_id])\n",
" else:\n",
" retrieved_relevances.append(0) # Assume irrelevant if not in ground truth\n",
"\n",
" # Calculate DCG\n",
" dcg = dcg_at_k_with_zero_padding_if_needed(retrieved_relevances, k)\n",
" # Calculate IDCG\n",
" ideal_relevances = sorted(ground_truth_relevances.values(), reverse=True)\n",
" idcg = dcg_at_k_with_zero_padding_if_needed(ideal_relevances, k)\n",
"\n",
" return dcg / idcg if idcg > 0 else 0.0"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qyBPq_8fOPbL"
},
"source": [
"### Run Retrieval Quality Evaluation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7092Mp2syWPG"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Computing metrics for top_k value: 5\n",
"Total number of queries: 648\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processing Queries (top_k=5): 100%|██████████| 648/648 [44:47<00:00, 4.15s/it]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Average Recall@5: 0.5608\n",
"Average Precision@5: 0.2713\n",
"Average nDCG@5: 0.4450\n",
"Execution time: 2687.608230829239 seconds.\n",
"=============================================\n",
"Computing metrics for top_k value: 10\n",
"Total number of queries: 648\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processing Queries (top_k=10): 100%|██████████| 648/648 [37:31<00:00, 3.48s/it]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Average Recall@10: 0.6571\n",
"Average Precision@10: 0.1679\n",
"Average nDCG@10: 0.4039\n",
"Execution time: 2251.886693954468 seconds.\n",
"=============================================\n",
"Computing metrics for top_k value: 100\n",
"Total number of queries: 648\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processing Queries (top_k=100): 100%|██████████| 648/648 [38:48<00:00, 3.59s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Average Recall@100: 0.8801\n",
"Average Precision@100: 0.0253\n",
"Average nDCG@100: 0.2592\n",
"Execution time: 2328.4095141887665 seconds.\n",
"=============================================\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"calculate_document_level_metrics(\n",
" queries, qrels, [5, 10, 100], corpus_name=rag_corpus.name\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sfL6wSwyJMTe"
},
"source": [
"# 3. Next steps\n",
"* Once we're done with evaluation, we should carefully examine the metrics number are tune the hypeparameters. Below are some suggestions on how to optimize the hyperparameters to get the best retrieval quality.\n",
"\n",
"### How to optimize Recall:\n",
"* If your recall metrics number is too low, consider the following steps:\n",
" * **Reducing chunk size:** Sometimes important information might be buried within large chunks, making it more difficult to retrieve relevant context. Try reducing the chunk size.\n",
" * **Increasing chunk overlap:** If the chunk overlap is too small, some relevant information at the edge might be lost. Consider increasing the chunk overlap (chunk overlap of 20% of chunk size is generally a good start.)\n",
" * **Increasing top-K:** If your top k is too small, the retriever might miss some relevant information due to a too restrictive context.\n",
"\n",
"### How to optimize Precision:\n",
"* If your precision number is low, consider:\n",
" * **Reducing top-K:** Your top k might be too large, adding a lot of unwanted noise to the retrieved contexts.\n",
" * **Reducing chunk overlap:** Sometimes, too large of a chunk overlap could result in duplicate information.\n",
" * **Increasing chunk size:** If your chunk size is too small, it might lack sufficient context resulting in a low precision score.\n",
"\n",
"### How to optimize nDCG:\n",
"* If your nDCG number is low, consider:\n",
" * **Changing your embedding model:** your embedding model might not capturing relevance well. Consider using a different embedding model (e.g. if your documents are multilingual, consider using a mulilingual embedding model). For more information on the currently supported embedding models, see documentation [here](https://cloud.google.com/vertex-ai/generative-ai/docs/rag-overview#supported-embedding-models).\n",
"\n",
"### Evaluate Response Quality\n",
"* If you want to evaluate response quality (generated answers) on top of retrieval quality, please refer to the [Gen AI Evaluation Service - RAG Evaluation Notebook](https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/evaluation/evaluate_rag_gen_ai_evaluation_service_sdk.ipynb)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2a4e033321ad"
},
"source": [
"# 4. Cleaning up (Delete `RagCorpus`)\n",
"\n",
"Once we are done with evaluation, we should clean up the `RagCorpus` to free up resources since we don't need it anymore."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SGIUz542z8To"
},
"outputs": [],
"source": [
"rag.delete_corpus(rag_corpus.name)"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [
"j0sHOGwdTDXZ",
"EdvJRUWRNGHE",
"2EioRjxUN2aw"
],
"name": "rag_engine_evaluation.ipynb",
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}