{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-PKa6W4wdPWr"
},
"outputs": [],
"source": [
"# Copyright 2024 Google LLC\n",
"#\n",
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "idM_aIPheQDG"
},
"source": [
"# Stage 2: Building MVP: - 04 Evaluation\n",
"\n",
"\n",
"<table align=\"left\">\n",
" <td style=\"text-align: center\">\n",
" <a href=\"https://colab.research.google.com/github/GoogleCloudPlatform/generative-ai/blob/main/workshops/rag-ops/2.4_mvp_evaluation.ipynb\">\n",
" <img src=\"https://cloud.google.com/ml-engine/images/colab-logo-32px.png\" alt=\"Google Colaboratory logo\"><br> Open in Colab\n",
" </a>\n",
" </td>\n",
" <td style=\"text-align: center\">\n",
" <a href=\"https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fgenerative-ai%2Fmain%2Fworkshops%2Frag-ops%2F2.4_mvp_evaluation.ipynb\">\n",
" <img width=\"32px\" src=\"https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN\" alt=\"Google Cloud Colab Enterprise logo\"><br> Open in Colab Enterprise\n",
" </a>\n",
" </td> \n",
" <td style=\"text-align: center\">\n",
" <a href=\"https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/generative-ai/main/workshops/rag-ops/2.4_mvp_evaluation.ipynb\">\n",
" <img src=\"https://lh3.googleusercontent.com/UiNooY4LUgW_oTvpsNhPpQzsstV5W8F7rYgxgGBD85cWJoLmrOzhVs_ksK_vgx40SHs7jCqkTkCk=e14-rj-sc0xffffff-h130-w32\" alt=\"Vertex AI logo\"><br> Open in Workbench\n",
" </a>\n",
" </td>\n",
" <td style=\"text-align: center\">\n",
" <a href=\"https://github.com/GoogleCloudPlatform/generative-ai/blob/main/workshops/rag-ops/2.4_mvp_evaluation.ipynb\">\n",
" <img src=\"https://cloud.google.com/ml-engine/images/github-logo-32px.png\" alt=\"GitHub logo\"><br> View on GitHub\n",
" </a>\n",
" </td>\n",
"</table>\n",
"\n",
"<div style=\"clear: both;\"></div>\n",
"\n",
"<b>Share to:</b>\n",
"\n",
"<a href=\"https://www.linkedin.com/sharing/share-offsite/?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/workshops/rag-ops/2.4_mvp_evaluation.ipynb\" target=\"_blank\">\n",
" <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/8/81/LinkedIn_icon.svg\" alt=\"LinkedIn logo\">\n",
"</a>\n",
"\n",
"<a href=\"https://bsky.app/intent/compose?text=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/workshops/rag-ops/2.4_mvp_evaluation.ipynb\" target=\"_blank\">\n",
" <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/7/7a/Bluesky_Logo.svg\" alt=\"Bluesky logo\">\n",
"</a>\n",
"\n",
"<a href=\"https://twitter.com/intent/tweet?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/workshops/rag-ops/2.4_mvp_evaluation.ipynb\" target=\"_blank\">\n",
" <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/5/5a/X_icon_2.svg\" alt=\"X logo\">\n",
"</a>\n",
"\n",
"<a href=\"https://reddit.com/submit?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/workshops/rag-ops/2.4_mvp_evaluation.ipynb\" target=\"_blank\">\n",
" <img width=\"20px\" src=\"https://redditinc.com/hubfs/Reddit%20Inc/Brand/Reddit_Logo.png\" alt=\"Reddit logo\">\n",
"</a>\n",
"\n",
"<a href=\"https://www.facebook.com/sharer/sharer.php?u=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/workshops/rag-ops/2.4_mvp_evaluation.ipynb\" target=\"_blank\">\n",
" <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/5/51/Facebook_f_logo_%282019%29.svg\" alt=\"Facebook logo\">\n",
"</a> "
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "o0oOZgLKfhCi"
},
"source": [
"## Overview\n",
"\n",
"This notebook is the fourth in a series designed to guide you through building a Minimum Viable Product (MVP) for a Multimodal Retrieval Augmented Generation (RAG) system using the Vertex Gemini API.\n",
"\n",
"Having built a functional RAG system with retrieval and generation components in the previous notebook, we now turn our attention to evaluating its performance. This notebook focuses on assessing the quality of the generated answers using two key metrics, providing a comprehensive understanding of your system's strengths and weaknesses.\n",
"\n",
"**Here's what you'll achieve:**\n",
"\n",
"* **Implement Evaluation Metrics:** Gain a deep understanding of two crucial evaluation metrics for RAG systems: Answer Correctness and Context Recall. You'll implement these metrics from scratch, giving you full transparency into the underlying calculations and prompt engineering involved.\n",
"* **Analyze Model Performance:** Apply the implemented metrics to evaluate the answers generated by both Gemini 2.0 and Gemini 2.0 models. This analysis will provide valuable insights into the accuracy and relevance of the generated responses.\n",
"* **Compare and Contrast:** Visualize the performance of both models across all samples of the ground truth data using comparative plots. This visualization will help you identify trends, strengths, and areas for potential improvement in each model's performance.\n",
"* **Enhance Transparency and Understanding:** By coding the evaluation metrics from scratch, you gain a deeper understanding of how they work and can customize them to your specific needs. This transparency allows for more informed decision-making when refining your RAG system.\n",
"\n",
"This notebook provides a crucial step in the iterative development of your RAG system. By rigorously evaluating your system's performance, you can identify areas for improvement and optimize its ability to provide accurate and relevant information. This hands-on approach to evaluation empowers you to build a more robust and reliable RAG MVP.\n"
]
},
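{
"cell_type": "markdown",
"metadata": {},
"source": [
"**How the two metrics are computed (at a glance)**\n",
"\n",
"The exact prompts and code are implemented later in this notebook; the summary below is only a rough sketch of the calculations you will build.\n",
"\n",
"* **Answer Correctness:** Gemini decomposes the generated answer and the ground truth into statements labeled TP, FP, and FN. A statement-level F1 score, `TP / (TP + 0.5 * (FP + FN))`, is then blended with an embedding-based semantic similarity between the generated answer and the ground truth (weighted 0.9 and 0.1, respectively, in this notebook).\n",
"* **Context Recall:** the generated answer is split into sentences, and Gemini judges whether each sentence can be attributed to the retrieved context. The score is `attributed statements / total statements`."
]
},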
{
"cell_type": "markdown",
"metadata": {
"id": "6KGP8kNhfklW"
},
"source": [
"## Getting Started"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HGFDxQ7_flui"
},
"source": [
"### Install Vertex AI SDK for Python\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "YcWmeELUeTPq"
},
"outputs": [],
"source": [
"%pip install --upgrade --user --quiet google-cloud-aiplatform"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SYYzbNlmfvKJ"
},
"source": [
"### Restart runtime\n",
"\n",
"To use the newly installed packages in this Jupyter runtime, you must restart the runtime. You can do this by running the cell below, which restarts the current kernel."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "icVKaoLAfw6c"
},
"outputs": [],
"source": [
"import sys\n",
"\n",
"if \"google.colab\" in sys.modules:\n",
" import IPython\n",
"\n",
" app = IPython.Application.instance()\n",
" app.kernel.do_shutdown(True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oZoekSN8fy2E"
},
"source": [
"<div class=\"alert alert-block alert-warning\">\n",
"<b>⚠️ The kernel is going to restart. Please wait until it is finished before continuing to the next step. ⚠️</b>\n",
"</div>\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "E4rlvfF1f1RA"
},
"source": [
"### Authenticate your notebook environment (Colab only)\n",
"\n",
"If you are running this notebook on Google Colab, run the cell below to authenticate your environment.\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "XTaUAqKLf3N8"
},
"outputs": [],
"source": [
"import sys\n",
"\n",
"if \"google.colab\" in sys.modules:\n",
" from google.colab import auth\n",
"\n",
" auth.authenticate_user()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "iQJQgZKXf6Mx"
},
"source": [
"### Set Google Cloud project information, GCS Bucket and initialize Vertex AI SDK\n",
"\n",
"To get started using Vertex AI, you must have an existing Google Cloud project and [enable the Vertex AI API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com).\n",
"\n",
"Learn more about [setting up a project and a development environment](https://cloud.google.com/vertex-ai/docs/start/cloud-environment)."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "FqXU2Ptvf5LA"
},
"outputs": [],
"source": [
"import os\n",
"import sys\n",
"\n",
"from google.cloud import storage\n",
"import vertexai\n",
"\n",
"PROJECT_ID = \"[your-project-id]\" # @param {type:\"string\"}\n",
"LOCATION = \"us-central1\"\n",
"BUCKET_NAME = \"mlops-for-genai\"\n",
"\n",
"if PROJECT_ID == \"[your-project-id]\":\n",
" PROJECT_ID = str(os.environ.get(\"GOOGLE_CLOUD_PROJECT\"))\n",
"\n",
"if not PROJECT_ID or PROJECT_ID == \"[your-project-id]\" or PROJECT_ID == \"None\":\n",
" raise ValueError(\"Please set your PROJECT_ID\")\n",
"\n",
"\n",
"vertexai.init(project=PROJECT_ID, location=LOCATION)\n",
"\n",
"# Initialize cloud storage\n",
"storage_client = storage.Client(project=PROJECT_ID)\n",
"bucket = storage_client.bucket(BUCKET_NAME)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "6AbGiKORhxQ7"
},
"outputs": [],
"source": [
"# # Variables for data location. Do not change.\n",
"\n",
"PRODUCTION_DATA = \"multimodal-finanace-qa/data/unstructured/production/\"\n",
"PICKLE_FILE_NAME = \"training_data_results.pkl\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "s9-h_WOcgAKX"
},
"source": [
"### Import libraries\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "OaW7NsbHgAoo"
},
"outputs": [],
"source": [
"# Library\n",
"import nltk\n",
"\n",
"nltk.download(\"punkt\")\n",
"import pickle\n",
"\n",
"from google.cloud import storage\n",
"import numpy as np\n",
"import pandas as pd\n",
"from rich import print as rich_print\n",
"from vertexai.generative_models import GenerationConfig, GenerativeModel\n",
"from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "A9hij3nhgDz8"
},
"source": [
"### Load the Gemini 2.0 models\n",
"\n",
"To learn more about all [Gemini API models on Vertex AI](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-models).\n",
"\n",
"The Gemini model family has several model versions. You will start by using Gemini 2.0. Gemini 2.0 is a more lightweight, fast, and cost-efficient model. This makes it a great option for prototyping.\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "7xvFkagTgHMc"
},
"outputs": [],
"source": [
"MODEL_ID_FLASH = \"gemini-2.0-flash\" # @param {type:\"string\"}\n",
"MODEL_ID_PRO = \"gemini-2.0-flash\" # @param {type:\"string\"}\n",
"\n",
"\n",
"gemini_15_flash = GenerativeModel(MODEL_ID_FLASH)\n",
"gemini_15_pro = GenerativeModel(MODEL_ID_PRO)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"cellView": "form",
"id": "oS1ar31xYmkR"
},
"outputs": [],
"source": [
"# @title Helper Functions\n",
"\n",
"\n",
"def get_load_dataframes_from_gcs():\n",
" gcs_path = \"multimodal-finanace-qa/data/embeddings/index_db.pkl\"\n",
" # print(\"GCS PAth: \", gcs_path)\n",
" blob = bucket.blob(gcs_path)\n",
"\n",
" # Download the pickle file from GCS\n",
" blob.download_to_filename(f\"{PICKLE_FILE_NAME}\")\n",
"\n",
" # Load the pickle file into a list of dataframes\n",
" with open(f\"{PICKLE_FILE_NAME}\", \"rb\") as f:\n",
" dataframes = pickle.load(f)\n",
"\n",
" # Assign the dataframes to variables\n",
" (\n",
" index_db_final,\n",
" extracted_text_chunk_df,\n",
" video_metadata_chunk_df,\n",
" audio_metadata_chunk_df,\n",
" ) = dataframes\n",
"\n",
" return (\n",
" index_db_final,\n",
" extracted_text_chunk_df,\n",
" video_metadata_chunk_df,\n",
" audio_metadata_chunk_df,\n",
" )\n",
"\n",
"\n",
"def get_load_training_dataframes_from_gcs():\n",
" gcs_path = \"multimodal-finanace-qa/data/structured/\" + PICKLE_FILE_NAME\n",
" # print(\"GCS PAth: \", gcs_path)\n",
" blob = bucket.blob(gcs_path)\n",
"\n",
" # Download the pickle file from GCS\n",
" blob.download_to_filename(f\"{PICKLE_FILE_NAME}\")\n",
"\n",
" # Load the pickle file into a list of dataframes\n",
" with open(f\"{PICKLE_FILE_NAME}\", \"rb\") as f:\n",
" dataframes = pickle.load(f)\n",
"\n",
" # Assign the dataframes to variables\n",
" training_data_flash, training_data_pro = dataframes\n",
"\n",
" return training_data_flash, training_data_pro"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KKT4gg2ybwaR"
},
"source": [
""
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"id": "Jk5387FCjlAv"
},
"outputs": [],
"source": [
"# Get the data that has been extracted in the previous step: IndexDB.\n",
"# Make sure that you have ran the previous notebook: stage_2_mvp_chunk_embeddings.ipynb\n",
"\n",
"\n",
"(\n",
" index_db_final,\n",
" extracted_text_chunk_df,\n",
" video_metadata_chunk_df,\n",
" audio_metadata_chunk_df,\n",
") = get_load_dataframes_from_gcs()\n",
"training_data_flash, training_data_pro = get_load_training_dataframes_from_gcs()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"id": "isgFDZBZyZ_1"
},
"outputs": [],
"source": [
"index_db_final.head()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"id": "_8oBQwejBjqV"
},
"outputs": [],
"source": [
"training_data_flash.head(2)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"id": "-aA38s_JBlNB"
},
"outputs": [],
"source": [
"training_data_pro.head(2)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"id": "dtLLgOhH5yv-"
},
"outputs": [],
"source": [
"training_data_pro.shape"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ltfT0cm-lUEc"
},
"source": [
"### Evaluations"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"cellView": "form",
"id": "EOgzkYVGdeCg"
},
"outputs": [],
"source": [
"# @title Answer Correctness\n",
"\n",
"import numpy as np\n",
"\n",
"\n",
"def embed_text(\n",
" texts: list[str] = [\"banana muffins? \", \"banana bread? banana muffins?\"],\n",
" task: str = \"RETRIEVAL_DOCUMENT\",\n",
" model_name: str = \"text-embedding-005\",\n",
") -> list[list[float]]:\n",
" \"\"\"Embeds texts with a pre-trained, foundational model.\"\"\"\n",
" model = TextEmbeddingModel.from_pretrained(model_name)\n",
" inputs = [TextEmbeddingInput(text, task) for text in texts]\n",
" embeddings = model.get_embeddings(inputs)\n",
" return [embedding.values for embedding in embeddings][0]\n",
"\n",
"\n",
"def calculate_final_score(semantic_similarity, score, weights):\n",
" \"\"\"\n",
" Calculates the final score as a weighted average of semantic similarity and factual similarity.\n",
"\n",
" Args:\n",
" semantic_similarity: Float value representing semantic similarity.\n",
" score: Float value representing factual similarity (or another score).\n",
" weights: A list or tuple of two float values representing the weights for semantic similarity and factual similarity, respectively. The weights should sum to 1.\n",
"\n",
" Returns:\n",
" The final score as a float value.\n",
" \"\"\"\n",
"\n",
" # Ensure weights are valid\n",
" assert len(weights) == 2, \"Weights must have two values\"\n",
" assert sum(weights) == 1, \"Weights must sum to 1\"\n",
"\n",
" final_score = (weights[0] * semantic_similarity) + (weights[1] * score)\n",
" return round(final_score, 2)\n",
"\n",
"\n",
"def get_answer_correctness(data_samples, debug=False, num_runs=3, retry=2):\n",
"\n",
" question = data_samples[\"question\"]\n",
" answer = data_samples[\"answer\"]\n",
" ground_truth = data_samples[\"ground_truth\"]\n",
"\n",
" response_schema = {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"TP\": {\"type\": \"array\", \"items\": {\"type\": \"string\"}},\n",
" \"FP\": {\"type\": \"array\", \"items\": {\"type\": \"string\"}},\n",
" \"FN\": {\"type\": \"array\", \"items\": {\"type\": \"string\"}},\n",
" },\n",
" }\n",
"\n",
" prompt = f\"\"\"Task: Analyze the Question, Answer, and Ground Truth to identify and categorize information.\n",
"Provide your response in JSON format with the following structure:\n",
"{{\n",
" \"TP\": [\"list of true positive statements\"],\n",
" \"FP\": [\"list of false positive statements\"],\n",
" \"FN\": [\"list of false negative statements\"]\n",
"}}\n",
"\n",
"Here's how to categorize the information:\n",
"\n",
"* **TP (True Positive):** Statements that are factually correct and present in BOTH the Answer and Ground Truth.\n",
" *Example:* If the Ground Truth says \"The capital of France is Paris\" and the Answer says \"Paris is the capital of France\", then \"The capital of France is Paris\" is a TP.\n",
"* **FP (False Positive):** Statements present in the Answer BUT NOT factually correct according to the Ground Truth.\n",
" *Example:* If the Answer says \"The Earth is flat\", but the Ground Truth does not support this, then \"The Earth is flat\" is an FP.\n",
"* **FN (False Negative):** Factually correct statements present in the Ground Truth BUT missing from the Answer.\n",
" *Example:* If the Ground Truth says \"The sky is blue\" but the Answer doesn't mention this fact, then \"The sky is blue\" is an FN.\n",
"\n",
"Question: {question}\n",
"\n",
"Answer: {answer}\n",
"\n",
"Ground Truth: {ground_truth}\n",
"\n",
"Analyze the answer and extract the TP, FP, and FN based on factual correctness.\n",
"\"\"\"\n",
"\n",
" cg_model = GenerativeModel(\n",
" model_name=\"gemini-2.0-flash\",\n",
" generation_config=GenerationConfig(\n",
" response_mime_type=\"application/json\", response_schema=response_schema\n",
" ),\n",
" )\n",
"\n",
" scores = []\n",
" for _ in range(num_runs):\n",
" for _ in range(retry):\n",
" response = cg_model.generate_content(prompt)\n",
" try:\n",
" tp = len(eval(response.text)[\"TP\"])\n",
" fp = len(eval(response.text)[\"FP\"])\n",
" fn = len(eval(response.text)[\"FN\"])\n",
" break # Break out of the retry loop if successful\n",
" except KeyError:\n",
" print(\"Retrying...\") # Indicate a retry is happening\n",
"\n",
" f1_score = tp / (tp + 0.5 * (fp + fn)) if tp > 0 else 0\n",
"\n",
" answer_embedding = embed_text(answer)\n",
" ground_truth_embedding = embed_text(ground_truth)\n",
" semantic_similarity = round(\n",
" np.dot(answer_embedding[0], ground_truth_embedding[0]), 2\n",
" )\n",
"\n",
" weights = (\n",
" 0.1,\n",
" 0.9,\n",
" ) # weightage to semantic similarity, weightage for factual similarity\n",
"\n",
" final_score = calculate_final_score(semantic_similarity, f1_score, weights)\n",
" scores.append(final_score)\n",
"\n",
" if debug:\n",
" rich_print(response.text)\n",
" print(\"F1 Score: \", f1_score)\n",
" print(\"Semantic Similarity: \", semantic_similarity)\n",
" print(\"Final Score: Answer Correctness :\", final_score)\n",
"\n",
" return np.max(scores) # Return the average score\n",
"\n",
"\n",
"def get_answer_correctness_row_wise(row, num_runs=3, retry=2):\n",
" data_samples = {\n",
" \"question\": [row[\"question\"]],\n",
" \"answer\": [row[\"gen_answer\"]],\n",
" \"ground_truth\": [row[\"answer\"]],\n",
" }\n",
"\n",
" score = get_answer_correctness(data_samples, num_runs=num_runs, retry=retry)\n",
" return score"
]
},
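{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before applying the metric to the full dataset, the next cell is a minimal, offline sketch of how the final Answer Correctness score is assembled. The TP/FP/FN counts and the semantic-similarity value are made-up, illustrative numbers (in the real metric they come from Gemini's statement classification and from `embed_text`); only the F1 formula and `calculate_final_score` defined above are exercised."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical statement counts, purely for illustration\n",
"# (in the metric above these come from Gemini's TP/FP/FN classification).\n",
"toy_tp, toy_fp, toy_fn = 4, 1, 2\n",
"\n",
"# Statement-level F1, as used in get_answer_correctness\n",
"toy_f1_score = toy_tp / (toy_tp + 0.5 * (toy_fp + toy_fn)) if toy_tp > 0 else 0\n",
"\n",
"# Hypothetical semantic similarity (in the metric above this is the dot product\n",
"# of the answer and ground-truth embeddings returned by embed_text).\n",
"toy_semantic_similarity = 0.85\n",
"\n",
"# Blend the two signals with the same weights used later in this notebook:\n",
"# 0.1 for semantic similarity, 0.9 for the statement-level F1.\n",
"toy_score = calculate_final_score(toy_semantic_similarity, toy_f1_score, (0.1, 0.9))\n",
"print(\"Toy F1 score:\", round(toy_f1_score, 2))\n",
"print(\"Toy Answer Correctness:\", toy_score)"
]
},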
{
"cell_type": "markdown",
"metadata": {
"id": "HZ__LP7nl-iZ"
},
"source": [
""
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"cellView": "form",
"id": "L9nDuI6jhKad"
},
"outputs": [],
"source": [
"# @title Context Recall\n",
"\n",
"\n",
"import nltk\n",
"\n",
"\n",
"def get_context_precision_prompt(question, answer_chunk_list, context):\n",
" return f\"\"\"Task: Given a context, an answer (broken down into statements), and a question, carefully analyze EACH statement in the answer and classify if it can be logically inferred or directly attributed to the given context.\n",
"\n",
" Provide your response in JSON format with the following structure for each statement in the answer:\n",
" [\n",
" {{\n",
" \"statement\": \"<the original statement from the answer>\",\n",
" \"attributed\": <1 if attributed to the context, 0 otherwise>,\n",
" \"reason\": \"<a concise explanation for the attribution>\"\n",
" }},\n",
" {{\n",
" \"statement\": \"<the original statement from the answer>\",\n",
" \"attributed\": <1 if attributed to the context, 0 otherwise>,\n",
" \"reason\": \"<a concise explanation for the attribution>\"\n",
" }},\n",
" ...\n",
" ]\n",
"\n",
" Instructions:\n",
" - Focus on the meaning and implications of both the context and EACH statement in the answer.\n",
" - Consider if the statement is a reasonable deduction or a paraphrase of information present in the context.\n",
" - If the statement introduces new information not mentioned or implied in the context, it should be classified as not attributed.\n",
" - Ensure that you provide a response for EVERY statement in the answer, without any repetitions or omissions\n",
"\n",
"\n",
" Question: {question}\n",
" Answer Statements: {answer_chunk_list}\n",
" Context: {context}\n",
" \"\"\"\n",
"\n",
"\n",
"def calculate_context_precision(question, answer, context, num_runs=3, retry=2):\n",
" \"\"\"\n",
" Calculates context precision based on the provided question, answer, and context.\n",
"\n",
" Args:\n",
" question: The question asked.\n",
" answer: The answer given.\n",
" context: The context provided.\n",
" num_runs: The number of times to run the evaluation.\n",
" retry: The number of times to retry if the format is incorrect.\n",
"\n",
" Returns:\n",
" The calculated context precision value (between 0 and 1).\n",
" \"\"\"\n",
" response_schema = {\n",
" \"type\": \"array\",\n",
" \"items\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"statement\": {\"type\": \"string\"},\n",
" \"reason\": {\"type\": \"string\"},\n",
" \"attributed\": {\"type\": \"integer\"},\n",
" },\n",
" },\n",
" }\n",
"\n",
" cg_model = GenerativeModel(\n",
" model_name=\"gemini-2.0-flash\",\n",
" generation_config=GenerationConfig(\n",
" response_mime_type=\"application/json\", response_schema=response_schema\n",
" ),\n",
" )\n",
"\n",
" answer_chunk_list = nltk.sent_tokenize(answer)\n",
"\n",
" precision_scores = []\n",
" for _ in range(num_runs):\n",
" for _ in range(retry):\n",
" try:\n",
" response = cg_model.generate_content(\n",
" get_context_precision_prompt(question, answer_chunk_list, context)\n",
" ).text\n",
" evaluation_results = eval(response)\n",
"\n",
" # Check if the response has the expected format\n",
" if isinstance(evaluation_results, list) and all(\n",
" isinstance(result, dict) and \"attributed\" in result\n",
" for result in evaluation_results\n",
" ):\n",
" break # Break out of the retry loop if successful\n",
" else:\n",
" print(\"Retrying...\")\n",
" except (SyntaxError, ValueError): # Catch errors in eval\n",
" print(\"Retrying...\")\n",
"\n",
" # Count the total number of statements and the number of statements attributed to the context\n",
" attributed_statements = sum(\n",
" 1 for result in evaluation_results if result.get(\"attributed\") == 1\n",
" )\n",
" total_statements = len(evaluation_results)\n",
"\n",
" # Calculate context precision\n",
" context_precision = (\n",
" attributed_statements / total_statements if total_statements > 0 else 0\n",
" )\n",
" precision_scores.append(context_precision)\n",
"\n",
" return np.max(precision_scores) # Return the average\n",
"\n",
"\n",
"def get_context_precision_row_wise(row, num_runs=3, retry=2):\n",
" question = row[\"question\"]\n",
" answer = row[\"gen_answer\"]\n",
" context = \"\\n\".join([each_cit[\"content\"] for each_cit in row[\"citation\"][:20]])\n",
"\n",
" context_precision = calculate_context_precision(\n",
" question, answer, context, num_runs=num_runs, retry=retry\n",
" )\n",
" return context_precision"
]
},
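{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick, optional sanity check, you can run the metric on a tiny synthetic example before applying it to the full dataset. The question, answer, and context below are made up purely for illustration; the cell makes a small number of Gemini API calls."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Synthetic example, purely for illustration: one supported and one unsupported sentence.\n",
"toy_question = \"How did revenue change this year?\"\n",
"toy_answer = \"Revenue grew by 12% year over year. The company also opened five new offices.\"\n",
"toy_context = \"The annual report states that revenue increased by 12% compared to the previous year.\"\n",
"\n",
"# The second sentence is not supported by the context, so the score should be well below 1.\n",
"toy_context_recall = calculate_context_precision(\n",
"    toy_question, toy_answer, toy_context, num_runs=1, retry=2\n",
")\n",
"print(\"Toy context recall:\", toy_context_recall)"
]
},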
{
"cell_type": "markdown",
"metadata": {
"id": "IdJu6D2GmC3r"
},
"source": [
""
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"id": "GFXvzNRsdgNe"
},
"outputs": [],
"source": [
"%%time\n",
"\n",
"# Gemini 2.0\n",
"training_data_pro_subset = training_data_pro[\n",
" [\"question\", \"answer\", \"gen_answer\", \"citation\"]\n",
"]\n",
"\n",
"# Specify the desired number of runs and retries # re-run the cell if you get ServiceUnavailable: 503 502:Bad Gateway\n",
"num_runs = 3\n",
"retry_attempts = 3\n",
"\n",
"training_data_pro_subset[\"answer_correctness\"] = training_data_pro_subset.apply(\n",
" lambda row: get_answer_correctness_row_wise(\n",
" row, num_runs=num_runs, retry=retry_attempts\n",
" ),\n",
" axis=1,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"id": "hm6qpDy4oqht"
},
"outputs": [],
"source": [
"%%time\n",
"\n",
"# Gemini 2.0\n",
"training_data_flash_subset = training_data_flash[\n",
" [\"question\", \"answer\", \"gen_answer\", \"citation\"]\n",
"]\n",
"\n",
"# Specify the desired number of runs and retries # re-run the cell if you get ServiceUnavailable: 503 502:Bad Gateway\n",
"num_runs = 3\n",
"retry_attempts = 3\n",
"\n",
"training_data_flash_subset[\"answer_correctness\"] = training_data_flash_subset.apply(\n",
" lambda row: get_answer_correctness_row_wise(\n",
" row, num_runs=num_runs, retry=retry_attempts\n",
" ),\n",
" axis=1,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"id": "CJ2NL1ZQhO8y"
},
"outputs": [],
"source": [
"%%time\n",
"\n",
"# Specify the desired number of runs and retries # re-run the cell if you get ServiceUnavailable: 503 502:Bad Gateway\n",
"\n",
"num_runs = 3\n",
"retry_attempts = 3\n",
"\n",
"training_data_pro_subset[\"context_recall\"] = training_data_pro_subset.apply(\n",
" lambda row: get_context_precision_row_wise(\n",
" row, num_runs=num_runs, retry=retry_attempts\n",
" ),\n",
" axis=1,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"id": "cckwsIcsopVX"
},
"outputs": [],
"source": [
"%%time\n",
"\n",
"# Specify the desired number of runs and retries # re-run the cell if you get ServiceUnavailable: 503 502:Bad Gateway\n",
"\n",
"num_runs = 3\n",
"retry_attempts = 3\n",
"\n",
"training_data_flash_subset[\"context_recall\"] = training_data_flash_subset.apply(\n",
" lambda row: get_context_precision_row_wise(\n",
" row, num_runs=num_runs, retry=retry_attempts\n",
" ),\n",
" axis=1,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"id": "5E0dHNDiIxie"
},
"outputs": [],
"source": [
"training_data_pro_subset.head()"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"id": "PxMSTJHJAp6w"
},
"outputs": [],
"source": [
"training_data_flash_subset.head()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cVSm6E2HaaqY"
},
"source": [
"### Comparing the results"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"id": "NcJtnyEigSPj"
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"\n",
"# Assuming training_data_pro_subset and training_data_flash_subset are your DataFrames\n",
"\n",
"# Combine the data\n",
"combined_data = pd.DataFrame(\n",
" {\n",
" \"Gemini 2.0\": training_data_pro_subset[\"answer_correctness\"],\n",
" \"Gemini 2.0\": training_data_flash_subset[\"answer_correctness\"],\n",
" }\n",
")\n",
"\n",
"# Plot the combined data\n",
"combined_data.plot(kind=\"bar\", title=\"Answer Correctness Comparison\")\n",
"plt.xlabel(\"Index\") # Or any other relevant label for your x-axis\n",
"plt.ylabel(\"Answer Correctness\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"id": "9Svg8BhJaMpK"
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"\n",
"# Assuming training_data_pro_subset and training_data_flash_subset are your DataFrames\n",
"\n",
"# Combine the data\n",
"combined_data = pd.DataFrame(\n",
" {\n",
" \"Gemini 2.0\": training_data_pro_subset[\"context_recall\"],\n",
" \"Gemini 2.0\": training_data_flash_subset[\"context_recall\"],\n",
" }\n",
")\n",
"\n",
"# Plot the combined data\n",
"combined_data.plot(kind=\"bar\", title=\"Context Recall Comparison\")\n",
"plt.xlabel(\"Index\") # Or any other relevant label for your x-axis\n",
"plt.ylabel(\"Context Recall\")\n",
"plt.show()"
]
},
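{
"cell_type": "markdown",
"metadata": {},
"source": [
"The bar charts above compare the two runs sample by sample. The cell below is an optional sketch of a simple aggregate view, assuming the `answer_correctness` and `context_recall` columns computed earlier in this notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Aggregate the per-sample scores into a small summary table (mean per metric and model).\n",
"summary = pd.DataFrame(\n",
"    {\n",
"        \"Gemini Pro\": {\n",
"            \"answer_correctness (mean)\": training_data_pro_subset[\"answer_correctness\"].mean(),\n",
"            \"context_recall (mean)\": training_data_pro_subset[\"context_recall\"].mean(),\n",
"        },\n",
"        \"Gemini Flash\": {\n",
"            \"answer_correctness (mean)\": training_data_flash_subset[\"answer_correctness\"].mean(),\n",
"            \"context_recall (mean)\": training_data_flash_subset[\"context_recall\"].mean(),\n",
"        },\n",
"    }\n",
")\n",
"summary.round(2)"
]
},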
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"id": "Dhl1XQds8oY7"
},
"outputs": [],
"source": [
"index = 5\n",
"\n",
"print(\"*******The question: *******\\n\")\n",
"rich_print(training_data_pro[\"question\"][index])\n",
"\n",
"print(\"\\n*******The ground-truth answer:*******\\n\")\n",
"rich_print(training_data_pro[\"answer\"][index])\n",
"\n",
"print(\"\\n*******The generated answer - Gemini 2.0: *******\\n\")\n",
"rich_print(training_data_pro[\"gen_answer\"][index])\n",
"\n",
"print(\"\\n*******The answer_correctness - Gemini 2.0: *******\\n\")\n",
"rich_print(training_data_pro_subset[\"answer_correctness\"][index])\n",
"\n",
"print(\"\\n*******The context_recall - Gemini 2.0: *******\\n\")\n",
"rich_print(training_data_pro_subset[\"context_recall\"][index])\n",
"\n",
"print(\"\\n*******The generated answer - Gemini 2.0: *******\\n\")\n",
"rich_print(training_data_flash[\"gen_answer\"][index])\n",
"\n",
"print(\"\\n*******The answer_correctness - Gemini 2.0: *******\\n\")\n",
"rich_print(training_data_flash_subset[\"answer_correctness\"][index])\n",
"\n",
"print(\"\\n*******The context_recall - Gemini 2.0: *******\\n\")\n",
"rich_print(training_data_flash_subset[\"context_recall\"][index])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oyH6gsU91TRG"
},
"source": [
"### Save the intermediate Files"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "i6VI5v_R1TlX"
},
"outputs": [],
"source": [
"# # [Optional]\n",
"\n",
"# import pickle\n",
"\n",
"# pickle_file_name =\"training_data_eval_results.pkl\"\n",
"# data_to_dump = [training_data_pro_subset, training_data_flash_subset]\n",
"\n",
"# gcs_location = f\"gs://mlops-for-genai/multimodal-finanace-qa/data/structured/{pickle_file_name}\"\n",
"\n",
"# with open(f\"{pickle_file_name}\", \"wb\") as f:\n",
"# pickle.dump(data_to_dump, f)\n",
"\n",
"\n",
"# # Upload the pickle file to GCS\n",
"# !gsutil cp {pickle_file_name} {gcs_location}"
]
}
],
"metadata": {
"colab": {
"name": "2.4_mvp_evaluation.ipynb",
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}