{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ur8xi4C7S06n"
},
"outputs": [],
"source": [
"# Copyright 2024 Google LLC\n",
"#\n",
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JAPoU8Sm5E6e"
},
"source": [
"# Hugging Face DLCs: Using Gemma for running evaluations with Vertex AI Gen AI Evaluation\n",
"\n",
"<table align=\"left\">\n",
" <td style=\"text-align: center\">\n",
" <a href=\"https://colab.research.google.com/github/GoogleCloudPlatform/generative-ai/blob/main/open-models/evaluation/vertex_ai_tgi_gemma_with_genai_evaluation.ipynb\">\n",
" <img width=\"32px\" src=\"https://www.gstatic.com/pantheon/images/bigquery/welcome_page/colab-logo.svg\" alt=\"Google Colaboratory logo\"><br> Open in Colab\n",
" </a>\n",
" </td>\n",
" <td style=\"text-align: center\">\n",
" <a href=\"https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fgenerative-ai%2Fmain%2Fopen-models%2Fevaluation%2Fvertex_ai_tgi_gemma_with_genai_evaluation.ipynb\">\n",
" <img width=\"32px\" src=\"https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN\" alt=\"Google Cloud Colab Enterprise logo\"><br> Open in Colab Enterprise\n",
" </a>\n",
" </td>\n",
" <td style=\"text-align: center\">\n",
" <a href=\"https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/generative-ai/main/open-models/evaluation/vertex_ai_tgi_gemma_with_genai_evaluation.ipynb\">\n",
" <img src=\"https://www.gstatic.com/images/branding/gcpiconscolors/vertexai/v1/32px.svg\" alt=\"Vertex AI logo\"><br> Open in Vertex AI Workbench\n",
" </a>\n",
" </td>\n",
" <td style=\"text-align: center\">\n",
" <a href=\"https://github.com/GoogleCloudPlatform/generative-ai/blob/main/open-models/evaluation/vertex_ai_tgi_gemma_with_genai_evaluation.ipynb\">\n",
" <img width=\"32px\" src=\"https://www.svgrepo.com/download/217753/github.svg\" alt=\"GitHub logo\"><br> View on GitHub\n",
" </a>\n",
" </td>\n",
"</table>\n",
"\n",
"<div style=\"clear: both;\"></div>\n",
"\n",
"<b>Share to:</b>\n",
"\n",
"<a href=\"https://www.linkedin.com/sharing/share-offsite/?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/open-models/evaluation/vertex_ai_tgi_gemma_with_genai_evaluation.ipynb\" target=\"_blank\">\n",
" <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/8/81/LinkedIn_icon.svg\" alt=\"LinkedIn logo\">\n",
"</a>\n",
"\n",
"<a href=\"https://bsky.app/intent/compose?text=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/open-models/evaluation/vertex_ai_tgi_gemma_with_genai_evaluation.ipynb\" target=\"_blank\">\n",
" <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/7/7a/Bluesky_Logo.svg\" alt=\"Bluesky logo\">\n",
"</a>\n",
"\n",
"<a href=\"https://twitter.com/intent/tweet?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/open-models/evaluation/vertex_ai_tgi_gemma_with_genai_evaluation.ipynb\" target=\"_blank\">\n",
" <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/5/5a/X_icon_2.svg\" alt=\"X logo\">\n",
"</a>\n",
"\n",
"<a href=\"https://reddit.com/submit?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/open-models/evaluation/vertex_ai_tgi_gemma_with_genai_evaluation.ipynb\" target=\"_blank\">\n",
" <img width=\"20px\" src=\"https://redditinc.com/hubfs/Reddit%20Inc/Brand/Reddit_Logo.png\" alt=\"Reddit logo\">\n",
"</a>\n",
"\n",
"<a href=\"https://www.facebook.com/sharer/sharer.php?u=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/open-models/evaluation/vertex_ai_tgi_gemma_with_genai_evaluation.ipynb\" target=\"_blank\">\n",
" <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/5/51/Facebook_f_logo_%282019%29.svg\" alt=\"Facebook logo\">\n",
"</a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "h-zz1_z92CIJ"
},
"source": [
"| | |\n",
"|-|-|\n",
"|Author(s) | [Ivan Nardini](https://github.com/inardini) |"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tvgnzT1CKxrO"
},
"source": [
"## Overview\n",
"\n",
"Assessing the performance of Large Language Models (LLMs) remains a complex task, especially when it comes to integrating them into production systems. Unlike conventional software and non-generative machine learning models, evaluating LLMs is subjective, challenging to automate, and prone to highly visible errors.\n",
"\n",
"To tackle these challenges, Vertex AI offers a comprehensive evaluation framework through its Gen AI Evaluation service. This framework encompasses the entire LLM lifecycle, from prompt engineering and model comparison to operationalizing automated model evaluation in production environments.\n",
"\n",
"Learn more about [Vertex AI Gen AI Evaluation service](https://cloud.google.com/vertex-ai/generative-ai/docs/models/evaluate-models).\n",
"\n",
"## Objective\n",
"\n",
"In this tutorial, you learn how to use the Vertex AI Gen AI Evaluation framework to evaluate Gemma 2 in a summarization task.\n",
"\n",
"This tutorial uses the following Google Cloud ML services and resources:\n",
"\n",
"- Vertex AI Model Garden\n",
"- Vertex AI Prediction\n",
"- Vertex AI Model Eval\n",
"\n",
"The steps performed include:\n",
"\n",
"- Evaluate Gemma 2 for summarization task.\n",
"- Use Gemma 2 as LLM-as-Judge to evaluate generated summaries."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "61RBz8LLbxCR"
},
"source": [
"## Get started"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "No17Cw5hgx12"
},
"source": [
"### Install Vertex AI SDK for Python and other required packages\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "tFy3H3aPgx12"
},
"outputs": [],
"source": [
"%pip install --upgrade --user --quiet google-cloud-aiplatform[evaluation]\n",
"%pip install --upgrade --user --quiet fsspec datasets\n",
"%pip install --upgrade --user --quiet plotly"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "R5Xep4W9lq-Z"
},
"source": [
"### Restart runtime (Colab only)\n",
"\n",
"To use the newly installed packages, you must restart the runtime on Google Colab."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "XRvKdaPDTznN"
},
"outputs": [],
"source": [
"import sys\n",
"\n",
"if \"google.colab\" in sys.modules:\n",
"\n",
" import IPython\n",
"\n",
" app = IPython.Application.instance()\n",
" app.kernel.do_shutdown(True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SbmM4z7FOBpM"
},
"source": [
"<div class=\"alert alert-block alert-warning\">\n",
"<b>⚠️ The kernel is going to restart. Wait until it is finished before continuing to the next step. ⚠️</b>\n",
"</div>\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dmWOrTJ3gx13"
},
"source": [
"### Authenticate your notebook environment (Colab only)\n",
"\n",
"Authenticate your environment on Google Colab.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NyKGtVQjgx13"
},
"outputs": [],
"source": [
"import sys\n",
"\n",
"if \"google.colab\" in sys.modules:\n",
"\n",
" from google.colab import auth\n",
"\n",
" auth.authenticate_user()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "f9af3e57f89a"
},
"source": [
"### Authenticate your Hugging Face account\n",
"\n",
"As [`google/gemma-2-9b-it`](https://huggingface.co/google/gemma-2-9b-it) is a gated model, you are required to review and agree to Google usage license on the Hugging Face Hub for any of the models from the [Gemma 2 release collection](https://huggingface.co/collections/google/gemma-release-65d5efbccdbb8c4202ec078b), and the access request will be processed inmediately.\n",
"\n",
"Once this is done, you need to generate a new user access token with read-only access so that the weights can be downloaded from the Hub in the Hugging Face DLC for TGI.\n",
"\n",
"> Note that the user access token can only be generated via [the Hugging Face Hub UI](https://huggingface.co/settings/tokens/new), where you can either select read-only access to your account, or follow the recommendations and generate a fine-grained token with read-only access to [`google/gemma-2-9b-it`](https://huggingface.co/google/google/gemma-2-9b-it)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4c31c7272804"
},
"source": [
"Then you can install the `huggingface_hub` that comes with a CLI that will be used for the authentication with the token generated in advance. So that then the token can be safely retrieved via `huggingface_hub.get_token`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8d836e0210fe"
},
"outputs": [],
"source": [
"from huggingface_hub import interpreter_login\n",
"\n",
"interpreter_login()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "c71a4314c250"
},
"source": [
"Read more about [Hugging Face Security](https://huggingface.co/docs/hub/en/security), specifically about [Hugging Face User Access Tokens](https://huggingface.co/docs/hub/en/security-tokens)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Dy5VO78rzX8c"
},
"source": [
"### Requirements"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "R8Zm9y0hxU5O"
},
"source": [
"#### Set Project ID and Location\n",
"\n",
"To get started using Vertex AI, you must have an existing Google Cloud project and [enable these APIs](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com,artifactregistry.googleapis.com).\n",
"\n",
"Learn more about [setting up a project and a development environment](https://cloud.google.com/vertex-ai/docs/start/cloud-environment)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6Z4TNTjpik7q"
},
"outputs": [],
"source": [
"# Use the environment variable if the user does not provide Project ID.\n",
"import os\n",
"\n",
"PROJECT_ID = \"[your-project-id]\" # @param {type: \"string\", placeholder: \"[your-project-id]\", isTemplate: true}\n",
"\n",
"if not PROJECT_ID or PROJECT_ID == \"[your-project-id]\":\n",
" PROJECT_ID = str(os.environ.get(\"GOOGLE_CLOUD_PROJECT\"))\n",
"\n",
"PROJECT_NUMBER = !gcloud projects describe {PROJECT_ID} --format=\"get(projectNumber)\"[0]\n",
"PROJECT_NUMBER = PROJECT_NUMBER[0]\n",
"\n",
"LOCATION = os.environ.get(\"GOOGLE_CLOUD_REGION\", \"us-central1\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "set_service_account"
},
"source": [
"#### Set Service Account and permissions\n",
"\n",
"You will need to have the Vertex AI User (roles/aiplatform.user) IAM role.\n",
"\n",
"For more information about granting roles, see [Manage access](https://cloud.google.com/iam/docs/granting-changing-revoking-access).\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VX9tpdtuQI5L"
},
"source": [
"> If you run following commands using Vertex AI Workbench, run directly in the terminal.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ssUJJqXJJHgC"
},
"outputs": [],
"source": [
"SERVICE_ACCOUNT = f\"{PROJECT_NUMBER}-compute@developer.gserviceaccount.com\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "wqOHg5aid6HP"
},
"outputs": [],
"source": [
"! gcloud projects add-iam-policy-binding {PROJECT_ID} \\\n",
" --member=serviceAccount:{SERVICE_ACCOUNT} \\\n",
" --role=roles/aiplatform.user --condition=None"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EjebjdiNxe_D"
},
"source": [
"### Initiate Vertex AI SDK for Python\n",
"\n",
"Initiate Vertex AI client session."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "GReulMr0xjZs"
},
"outputs": [],
"source": [
"import vertexai\n",
"\n",
"vertexai.init(project=PROJECT_ID, location=LOCATION)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UBQgjn5wOvFq"
},
"source": [
"### Import libraries\n",
"\n",
"Import relevant libraries."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-TmyCxUSOvFq"
},
"outputs": [],
"source": [
"import json\n",
"import logging\n",
"import random\n",
"import string\n",
"from typing import Any\n",
"import warnings\n",
"\n",
"from IPython.display import Markdown, display\n",
"import datasets\n",
"from google.cloud import aiplatform\n",
"from huggingface_hub import get_token\n",
"import pandas as pd\n",
"import plotly.graph_objects as go\n",
"from tenacity import retry, wait_random_exponential\n",
"from transformers import AutoTokenizer\n",
"from vertexai import generative_models\n",
"from vertexai.evaluation import CustomMetric, EvalTask\n",
"from vertexai.generative_models import (\n",
" Content,\n",
" GenerationConfig,\n",
" GenerativeModel,\n",
" Part,\n",
" SafetySetting,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tfQ7sPtOjZOw"
},
"source": [
"### Library settings"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "RjWUgU1TjZOw"
},
"outputs": [],
"source": [
"logging.getLogger(\"urllib3.connectionpool\").setLevel(logging.ERROR)\n",
"warnings.filterwarnings(\"ignore\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "F_Gw6YLeOvFq"
},
"source": [
"### Helper functions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "x8imb3UdOvFq"
},
"outputs": [],
"source": [
"def generate_uuid(length: int = 8) -> str:\n",
" \"\"\"Generate a uuid of a specified length (default=8).\"\"\"\n",
" return \"\".join(random.choices(string.ascii_lowercase + string.digits, k=length))\n",
"\n",
"\n",
"def init_new_model(\n",
" model_name: str,\n",
" generation_config: GenerationConfig | None = None,\n",
" safety_settings: list[SafetySetting] | None = None,\n",
" **kwargs: Any,\n",
") -> GenerativeModel:\n",
" \"\"\"Initialize a new model with configurable generation and safety settings.\"\"\"\n",
"\n",
" if generation_config is None:\n",
" generation_config = GenerationConfig(\n",
" candidate_count=1, max_output_tokens=2048, temperature=0\n",
" )\n",
" if safety_settings is None:\n",
" safety_settings = [\n",
" generative_models.SafetySetting(\n",
" category=generative_models.HarmCategory.HARM_CATEGORY_HATE_SPEECH,\n",
" method=generative_models.SafetySetting.HarmBlockMethod.SEVERITY,\n",
" threshold=generative_models.HarmBlockThreshold.BLOCK_NONE,\n",
" ),\n",
" generative_models.SafetySetting(\n",
" category=generative_models.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,\n",
" method=generative_models.SafetySetting.HarmBlockMethod.SEVERITY,\n",
" threshold=generative_models.HarmBlockThreshold.BLOCK_NONE,\n",
" ),\n",
" generative_models.SafetySetting(\n",
" category=generative_models.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,\n",
" method=generative_models.SafetySetting.HarmBlockMethod.SEVERITY,\n",
" threshold=generative_models.HarmBlockThreshold.BLOCK_NONE,\n",
" ),\n",
" generative_models.SafetySetting(\n",
" category=generative_models.HarmCategory.HARM_CATEGORY_HARASSMENT,\n",
" method=generative_models.SafetySetting.HarmBlockMethod.SEVERITY,\n",
" threshold=generative_models.HarmBlockThreshold.BLOCK_NONE,\n",
" ),\n",
" ]\n",
"\n",
" model = GenerativeModel(\n",
" model_name=model_name,\n",
" generation_config=generation_config,\n",
" safety_settings=safety_settings,\n",
" **kwargs,\n",
" )\n",
" return model\n",
"\n",
"\n",
"@retry(wait=wait_random_exponential(multiplier=1, max=120))\n",
"async def async_generate(\n",
" prompt: str,\n",
" model: GenerativeModel,\n",
" **kwargs: Any,\n",
") -> str | None:\n",
" \"\"\"Generates a response from the model, optionally handling function calls.\"\"\"\n",
"\n",
" user_prompt_content = Content(role=\"user\", parts=[Part.from_text(prompt)])\n",
"\n",
" try:\n",
" # Initial generation - potentially calling a function.\n",
" response = await model.generate_content_async(\n",
" prompt,\n",
" **kwargs,\n",
" )\n",
"\n",
" # Extract and return text if generation was successful\n",
" if response and response.candidates and response.candidates[0].content.parts:\n",
" return (\n",
" response.candidates[0].content.parts[0].text\n",
" ) # More robust text extraction\n",
" return None\n",
"\n",
" except Exception as e: # pylint: disable=broad-except\n",
" print(f\"Error calling the model: {e}\") # Include the actual error message\n",
" return \"Could not call the model. Please try it again in a few minutes.\"\n",
"\n",
"\n",
"def display_eval_report(\n",
" eval_result: pd.DataFrame, title: str, metrics: list[str] = None\n",
") -> None:\n",
" \"\"\"Display the evaluation results.\"\"\"\n",
"\n",
" summary_metrics, report_df = eval_result.summary_metrics, eval_result.metrics_table\n",
" metrics_df = pd.DataFrame.from_dict(summary_metrics, orient=\"index\").T\n",
" if metrics:\n",
" metrics_df = metrics_df.filter(\n",
" [\n",
" metric\n",
" for metric in metrics_df.columns\n",
" if any(selected_metric in metric for selected_metric in metrics)\n",
" ]\n",
" )\n",
" report_df = report_df.filter(\n",
" [\n",
" metric\n",
" for metric in report_df.columns\n",
" if any(selected_metric in metric for selected_metric in metrics)\n",
" ]\n",
" )\n",
"\n",
" # Display the title with Markdown for emphasis\n",
" display(Markdown(f\"## {title}\"))\n",
"\n",
" # Display the metrics DataFrame\n",
" display(Markdown(\"### Summary Metrics\"))\n",
" display(metrics_df)\n",
"\n",
" # Display the detailed report DataFrame\n",
" display(Markdown(\"### Report Metrics\"))\n",
" display(report_df)\n",
"\n",
"\n",
"def display_explanations(\n",
" df: pd.DataFrame, metrics: list[str] = None, n: int = 1\n",
") -> None:\n",
" \"\"\"Display the explanations for the evaluation results.\"\"\"\n",
"\n",
" # Sample the DataFrame\n",
" df = df.sample(n=n)\n",
"\n",
" # Filter the DataFrame based on the selected metrics\n",
" if metrics:\n",
" df = df.filter(\n",
" [\"instruction\", \"context\", \"reference\", \"completed_prompt\", \"response\"]\n",
" + [\n",
" metric\n",
" for metric in df.columns\n",
" if any(selected_metric in metric for selected_metric in metrics)\n",
" ]\n",
" )\n",
"\n",
" # Display the explanations using Markdown for consistent styling\n",
" for index, row in df.iterrows():\n",
" display(Markdown(\"---\")) # Section separator\n",
" for col in df.columns:\n",
" display(Markdown(f\"### {col}\"))\n",
" display(Markdown(f\"{row[col]}\"))\n",
"\n",
"\n",
"def plot_bar_plot(\n",
" eval_result: pd.DataFrame, title: str, metrics: list[str] = None\n",
") -> None:\n",
" fig = go.Figure()\n",
" data = []\n",
"\n",
" summary_metrics = eval_result.summary_metrics\n",
" if metrics:\n",
" summary_metrics = {\n",
" k: summary_metrics[k]\n",
" for k, v in summary_metrics.items()\n",
" if any(selected_metric in k for selected_metric in metrics)\n",
" }\n",
"\n",
" data.append(\n",
" go.Bar(\n",
" x=list(summary_metrics.keys()),\n",
" y=list(summary_metrics.values()),\n",
" name=title,\n",
" )\n",
" )\n",
"\n",
" fig = go.Figure(data=data)\n",
"\n",
" # Change the bar mode\n",
" fig.update_layout(barmode=\"group\")\n",
" fig.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "abWMeETlldVD"
},
"source": [
"## Initiate Gemma 2 on Vertex AI from Hugging Face Hub\n",
"\n",
"To use Gemma 2 with Vertex AI Gen AI evaluation, you need to deploy the model on Vertex AI.\n",
"\n",
"To deploy Gemma 2 on Vertex AI from Hugging Face Hub, register the model on Vertex AI Model Registry using Hugging Face Deep Learning Container. This requires to specify the container image for serving the model and configure essential environment variables. Before deploying the model, create an endpoint, a dedicated resource on Vertex AI that serves as an entry point for predictions. Finally, deploy the registered model to the newly created endpoint.\n",
"\n",
"Learn more about serving open models on Vertex AI using Hugging Face Deep Learning Container, check out [this tutorial](https://github.com/GoogleCloudPlatform/generative-ai/blob/main/open-models/serving/vertex_ai_text_generation_inference_gemma.ipynb)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "vtn8g0q3ltVG"
},
"outputs": [],
"source": [
"gemma_model = aiplatform.Model.upload(\n",
" display_name=\"google--gemma-2-9b-it\",\n",
" serving_container_image_uri=\"us-docker.pkg.dev/deeplearning-platform-release/gcr.io/huggingface-text-generation-inference-cu124.2-3.ubuntu2204.py311\",\n",
" serving_container_environment_variables={\n",
" \"MODEL_ID\": \"google/gemma-2-9b-it\",\n",
" \"NUM_SHARD\": \"2\",\n",
" \"MAX_INPUT_TOKENS\": \"4095\",\n",
" \"MAX_TOTAL_TOKENS\": \"4096\",\n",
" \"MAX_BATCH_PREFILL_TOKENS\": \"4145\",\n",
" \"HUGGING_FACE_HUB_TOKEN\": get_token(),\n",
" },\n",
" serving_container_ports=[8080],\n",
")\n",
"gemma_model.wait()\n",
"\n",
"deployed_gemma_model = gemma_model.deploy(\n",
" endpoint=aiplatform.Endpoint.create(display_name=\"google--gemma-2-9b-it-endpoint\"),\n",
" machine_type=\"g2-standard-24\",\n",
" accelerator_type=\"NVIDIA_L4\",\n",
" accelerator_count=2,\n",
")"
]
},
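{
"cell_type": "markdown",
"metadata": {
"id": "reuse_existing_endpoint_md"
},
"source": [
"Deploying Gemma 2 9B can take a while. If you already deployed the model to an endpoint in a previous session, you can optionally skip the cell above and attach to the existing endpoint instead. The sketch below assumes you replace the placeholder with your endpoint ID (or full endpoint resource name) from the Vertex AI console."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "reuse_existing_endpoint_code"
},
"outputs": [],
"source": [
"# Optional: reuse an endpoint that already serves Gemma 2 instead of redeploying.\n",
"REUSE_EXISTING_ENDPOINT = False  # @param {type:\"boolean\"}\n",
"EXISTING_ENDPOINT_ID = \"[your-endpoint-id]\"  # @param {type: \"string\", isTemplate: true}\n",
"\n",
"if REUSE_EXISTING_ENDPOINT:\n",
"    deployed_gemma_model = aiplatform.Endpoint(endpoint_name=EXISTING_ENDPOINT_ID)\n",
"    print(deployed_gemma_model.resource_name)"
]
},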
{
"cell_type": "markdown",
"metadata": {
"id": "rYRsYAlj1soI"
},
"source": [
"## Using Gemma 2 with Vertex AI Gen AI Evaluation\n",
"\n",
"To run evaluations using Gemma 2 with Vertex AI Gen AI Evaluation, you use the `EvalTask` class.\n",
"\n",
"The `EvalTask` requires an evaluation dataset (DataFrame, dictionary, or URI) and a list of [supported metrics](https://cloud.google.com/vertex-ai/generative-ai/docs/models/determine-eval). Datasets can use standard column names like prompt, reference, response, and baseline_model_response, customizable via parameters like response_column_name.\n",
"\n",
"EvalTask supports three scenarios: bring-your-own-response (BYOR), inference without a prompt template (using a prompt column), and inference with a prompt template (using columns matching template variables). And those scenarios are compatible with Gemini, 3P models, and custom functions, supporting various metrics.\n",
"\n",
"After defining your EvalTask, use `evaluate()` method to run the evaluation, optionally providing a model, prompt template, logging configuration, and other parameters. See the [Gen AI Evaluation package](https://cloud.google.com/vertex-ai/generative-ai/docs/reference/python/latest/vertexai.evaluation) documentation for more details.\n",
"\n",
"This tutorial shows the two main ways to use Gemma 2 with Vertex AI Gen AI Evaluation:\n",
"\n",
"1. Gemma 2 as model to evaluate (`Evaluate Gemma 2` scenario)\n",
"2. Gemma 2 as model (`Gemma 2 as LLM-as-Judge` scenario)"
]
},
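{
"cell_type": "markdown",
"metadata": {
"id": "eval_task_byor_sketch_md"
},
"source": [
"As a quick illustration of the bring-your-own-response scenario and of column mapping, the minimal sketch below evaluates two pre-generated responses stored under a non-standard column name (`model_output` is a hypothetical name chosen for this example) by passing `response_column_name` to `evaluate()`. It scores the `fluency` metric with the default autorater, so running it issues a small number of evaluation requests; the two scenarios that follow are the ones this tutorial walks through in full."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "eval_task_byor_sketch_code"
},
"outputs": [],
"source": [
"# Minimal bring-your-own-response (BYOR) sketch: responses are already in the dataset,\n",
"# so no model is passed to `evaluate()`. The `model_output` column is arbitrary and is\n",
"# mapped to the expected response column via `response_column_name`.\n",
"byor_df = pd.DataFrame(\n",
"    {\n",
"        \"prompt\": [\n",
"            \"Summarize in one sentence: The cat sat on the mat all afternoon.\",\n",
"            \"Summarize in one sentence: Heavy rain flooded the old town square overnight.\",\n",
"        ],\n",
"        \"model_output\": [\n",
"            \"A cat spent the whole afternoon lounging on a mat.\",\n",
"            \"Overnight rain left the old town square under water.\",\n",
"        ],\n",
"    }\n",
")\n",
"\n",
"byor_eval_task = EvalTask(dataset=byor_df, metrics=[\"fluency\"])\n",
"byor_eval_result = byor_eval_task.evaluate(response_column_name=\"model_output\")\n",
"byor_eval_result.summary_metrics"
]
},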
{
"cell_type": "markdown",
"metadata": {
"id": "9y1DC2oWJ3B_"
},
"source": [
"### Scenario 1: `Evaluate Gemma 2` for summarization\n",
"\n",
"To evaluate Gemma 2 for text summarization using Vertex AI Gen AI evaluation, cover the following steps:\n",
"\n",
"1. Prepare the dataset\n",
"2. Define a model function\n",
"3. Set a base prompt and metrics\n",
"4. Initiate an `EvalTask`\n",
"5. Run an evaluation job"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lQissTy93yet"
},
"source": [
"#### Prepare the evaluation dataset\n",
"\n",
"To start, prepare the evaluation dataset.\n",
"\n",
"The XSum dataset is loaded and preprocessed for evaluation. Documents and summaries longer than 4096 tokens are filtered out, columns are renamed to \"context\" and \"reference\", and the \"id\" column is removed.\n",
"\n",
"A random 10-sample subset is created for efficient evaluation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ranK-Np0zrEG"
},
"outputs": [],
"source": [
"eval_model_dataset = datasets.load_dataset(\"xsum\", split=\"test\", trust_remote_code=True)\n",
"\n",
"eval_model_dataset = (\n",
" eval_model_dataset.filter(lambda example: len(example[\"document\"]) < 2048)\n",
" .filter(lambda example: len(example[\"summary\"]) < 2048)\n",
" .rename_columns({\"document\": \"context\", \"summary\": \"reference\"})\n",
" .remove_columns([\"id\"])\n",
")\n",
"\n",
"n = 10 # @param {type: \"integer\", placeholder: \"10\", isTemplate: true}\n",
"eval_model_sample_df = (\n",
" eval_model_dataset.shuffle(seed=8)\n",
" .select(random.sample(range(0, len(eval_model_dataset)), n))\n",
" .to_pandas()\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "dZpLWBsSzSvZ"
},
"outputs": [],
"source": [
"eval_model_sample_df.head()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WTbWEZzmOpUc"
},
"source": [
"#### Define a model function\n",
"\n",
"Define a model function which is a wrapper to generate predictions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "XyTJDWLBLJWj"
},
"outputs": [],
"source": [
"tokenizer = AutoTokenizer.from_pretrained(\"google/gemma-2-9b-it\")\n",
"\n",
"generation_config = {\n",
" \"max_new_tokens\": 256,\n",
" \"do_sample\": True,\n",
" \"temperature\": 0.2,\n",
"}\n",
"\n",
"\n",
"def gemma_fn(prompt, generation_config=generation_config):\n",
" formatted_prompt = tokenizer.apply_chat_template(\n",
" [\n",
" {\"role\": \"user\", \"content\": prompt},\n",
" ],\n",
" tokenize=False,\n",
" add_generation_prompt=True,\n",
" )\n",
"\n",
" instance = {\"inputs\": formatted_prompt, \"parameters\": generation_config}\n",
" output = deployed_gemma_model.predict(instances=[instance])\n",
" generated_text = output.predictions[0]\n",
" return generated_text"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rrGh1K_S_uTk"
},
"source": [
"#### Set base prompt and metrics to evaluate your task\n",
"\n",
"Define the prompt template and metrics to use to evaluate the summarization task. Vertex AI Gen AI Evalutions provides several metric prompt templates for model-based evaluation you can use. Check out [the documentation](https://cloud.google.com/vertex-ai/generative-ai/docs/models/metrics-templates) to know more."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0MgvVQY2O3Hu"
},
"outputs": [],
"source": [
"prompt_template = (\n",
" \"Summarize the following article in one sentence: {context}.\\nSummary:\"\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7cBx1kZF_1G6"
},
"outputs": [],
"source": [
"metrics = [\"rouge_l_sum\", \"summarization_quality\", \"fluency\"]"
]
},
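{
"cell_type": "markdown",
"metadata": {
"id": "metric_prompt_template_peek_md"
},
"source": [
"Optionally, you can inspect the prompt template that backs a built-in model-based metric such as `summarization_quality`. The sketch below assumes the `MetricPromptTemplateExamples` helper from the Gen AI Evaluation SDK; it only prints the template locally and does not call any service."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "metric_prompt_template_peek_code"
},
"outputs": [],
"source": [
"# Optional: print the built-in metric prompt template used by the autorater.\n",
"from vertexai.evaluation import MetricPromptTemplateExamples\n",
"\n",
"print(MetricPromptTemplateExamples.get_prompt_template(\"summarization_quality\"))"
]
},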
{
"cell_type": "markdown",
"metadata": {
"id": "8iWfF1DPqVQU"
},
"source": [
"#### Run the evaluation\n",
"\n",
"To run evaluations for prompt templates, you run an evaluation job repeatedly against an evaluation dataset and its associated metrics. With EvalTask, you leverage integration with Vertex AI Experiments to track settings and results for each evaluation run."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_HUstMFvqUwl"
},
"outputs": [],
"source": [
"run_id = generate_uuid()\n",
"experiment_name = \"eval-gemma-base-prompt-sum\"\n",
"experiment_run_name = f\"{experiment_name}-{run_id}\"\n",
"\n",
"eval_task = EvalTask(\n",
" dataset=eval_model_sample_df,\n",
" metrics=metrics,\n",
" experiment=experiment_name,\n",
")\n",
"\n",
"eval_result = eval_task.evaluate(\n",
" model=gemma_fn,\n",
" prompt_template=prompt_template,\n",
" experiment_run_name=experiment_run_name,\n",
")"
]
},
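{
"cell_type": "markdown",
"metadata": {
"id": "list_experiment_runs_md"
},
"source": [
"Because `EvalTask` logs every run to Vertex AI Experiments, you can list the tracked runs and their summary metrics. Below is a minimal sketch, assuming the `aiplatform.get_experiment_df` helper; it only reads experiment metadata."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "list_experiment_runs_code"
},
"outputs": [],
"source": [
"# List the evaluation runs tracked in Vertex AI Experiments for this experiment.\n",
"experiment_runs_df = aiplatform.get_experiment_df(experiment=experiment_name)\n",
"experiment_runs_df.head()"
]
},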
{
"cell_type": "markdown",
"metadata": {
"id": "3gTNR5s4reWQ"
},
"source": [
"#### Display Evaluation reports and explanations\n",
"\n",
"Display detailed evaluation reports, explanations, and useful charts to summarize key metrics in an informative manner."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "eMHH_R83pZ3S"
},
"outputs": [],
"source": [
"display_eval_report(eval_result, \"Gemma 2 evaluation report\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6Vft3PqFTsZR"
},
"outputs": [],
"source": [
"display_explanations(eval_result.metrics_table, metrics=[\"fluency\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Kr9eCzXLHRmJ"
},
"outputs": [],
"source": [
"plot_bar_plot(\n",
" eval_result,\n",
" title=\"Evaluate Gemma 2\",\n",
" metrics=[\"summarization_quality/mean\", \"fluency/mean\"],\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FgkKr6RfUHsf"
},
"source": [
"### Scenario 2: `Gemma 2 as LLM-as-Judge` to evaluate generated summaries\n",
"\n",
"To use Gemma 2 as LLM-as-Judge for text summarization using Vertex AI Gen AI evaluation, cover the following steps:\n",
"\n",
"1. Define a Model function (see above)\n",
"2. Define a Custom metric to set Gemma 2 as autorater\n",
"3. Initiate an `EvalTask`\n",
"4. Run an evaluation job"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vye6aiQ3t1Je"
},
"source": [
"#### Prepare the dataset\n",
"\n",
"In this scenario, generate summaries to evaluate using Gemini API on Vertex AI by leveraging concurrent prediction requests for increased efficiency. This approach is particularly useful when evaluating against a large dataset of summaries.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PJVyFaTwt3Rr"
},
"outputs": [],
"source": [
"eval_model_sample_df[\"prompt\"] = eval_model_sample_df.apply(\n",
" lambda row: prompt_template.format(context=row[\"context\"]), axis=1\n",
")\n",
"gemini_llm = init_new_model(model_name=\"gemini-2.0-flash\")\n",
"gemini_predictions = [\n",
" async_generate(p, model=gemini_llm) for p in eval_model_sample_df[\"prompt\"]\n",
"]\n",
"gemini_predictions_col = await tqdm_asyncio.gather(*gemini_predictions)\n",
"eval_model_sample_df[\"response\"] = gemini_predictions_col"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_D6iCWcvuIFj"
},
"outputs": [],
"source": [
"eval_model_sample_df.head()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "D-5UUIx2UHsg"
},
"source": [
"#### Define a metric function to use Gemma 2 as an evaluator\n",
"\n",
"Define a custom model-based metric function, `catchiness_fn` in this case, to evaluate the \"catchiness\" of an AI-generated response given a user prompt.\n",
"\n",
"It uses a Gemma 2 as an evaluator based on the detailed prompt template."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4QV-jW7IcItT"
},
"outputs": [],
"source": [
"def catchiness_fn(instance: dict) -> dict:\n",
"\n",
" metric_prompt_template = \"\"\"\n",
"\n",
"# Instruction\n",
"You are an expert evaluator. Your task is to evaluate the catchiness of responses generated by AI models.\n",
"We will provide you with the user input (prompt) and an AI-generated response.\n",
"You should first read the user input carefully to understand the task, and then evaluate the catchiness of the response based on the criteria provided in the Evaluation section below.\n",
"Then you will assign the response a score (float) and a explanation (string) following the Rating Rubric and Evaluation Steps. Give step-by-step explanations for your rating, and only choose ratings from the Rating Rubric.\n",
"Finally ONLY return the score and the explanation in a JSON as shown in Examples section.\n",
"\n",
"# Evaluation\n",
"\n",
"## Metric Definition\n",
"Catchiness: The response uses creative language, vivid imagery, memorable phrasing, and a compelling tone to create a lasting impression on the reader. It might employ techniques like humor, wordplay, or strong emotional appeals. It should be relevant to the prompt and avoid being overly repetitive or generic.\n",
"\n",
"## Criteria\n",
"* **Creative Language:** Does the response utilize figurative language (metaphors, similes, personification, etc.), evocative descriptions, and interesting vocabulary? Is the language fresh and original?\n",
"* **Memorable Phrasing:** Does the response contain turns of phrase, slogans, or other linguistic devices that stick with the reader? Are there any particularly quotable lines?\n",
"* **Relevance:** Is the catchiness relevant to the response? Does it enhance the core message or distract from it?\n",
"\n",
"## Rating Rubric\n",
"5: (Exceptionally Catchy) The response is highly creative, uses vivid imagery and memorable phrasing, and maintains a compelling tone. It leaves a strong and lasting impression. It is perfectly relevant to the prompt and avoids generic language.\n",
"4: (Very Catchy) The response is creative and engaging, with clear use of imagery and memorable phrasing. It leaves a positive impression. It is relevant to the prompt and mostly avoids generic language.\n",
"3: (Moderately Catchy) The response shows some creativity and uses some imagery and memorable phrasing, but the impact is less pronounced. It is relevant to the prompt but might contain some generic language.\n",
"2: (Slightly Catchy) The response demonstrates limited creativity and uses minimal imagery or memorable phrasing. The impact is weak. It might be somewhat relevant to the prompt and contains quite a bit of generic language.\n",
"1: (Not Catchy) The response lacks creativity, vivid imagery, and memorable phrasing. It leaves no lasting impression. It might be irrelevant to the prompt and relies heavily on generic language.\n",
"\n",
"## Evaluation Steps\n",
"STEP 1: Assess Creative Language: Identify the use of figurative language, evocative descriptions, and interesting vocabulary. Judge the originality and freshness of the language.\n",
"STEP 2: Assess Vivid Imagery and Memorable Phrasing: Analyze the use of sensory details and identify any phrases or lines that are particularly memorable or quotable.\n",
"STEP 3: Assess Compelling Tone and Relevance: Determine the tone of the response and evaluate its appropriateness for the prompt and target audience. Assess the relevance of the catchy elements to the core message of the prompt.\n",
"STEP 4: Assess Avoidance of Repetition and Generic Language: Identify any instances of clichés, overused phrases, or repetitive sentence structures.\n",
"\n",
"# User Inputs and AI-generated Response\n",
"## User Inputs\n",
"### Prompt\n",
"{prompt}\n",
"\n",
"## AI-generated Response\n",
"{response}\n",
"\n",
"# Examples\n",
"```json {{\"score\": 5, \"explanation\": \"The summary is perfectly relevant to the prompt, highly creative, and avoids generic language. It uses vivid imagery, memorable phrasing, and a compelling tone to leave a strong and lasting impression.\"}} ```\n",
"```json {{\"score\": 3, \"explanation\": \"The summary is relevant to the prompt, but the impact is somewhat lessened by generic language and less vivid imagery and phrasing. It shows some creativity, though.\"}} ```\n",
"```json {{\"score\": 1, \"explanation\": \"The summary is irrelevant, lacks creativity, and fails to make a lasting impression. It relies on generic language and lacks vivid imagery or memorable phrasing.\"}} ```\n",
"\n",
"# Evaluation JSON:\n",
"\"\"\"\n",
"\n",
" default_result = {\"catchiness\": 0, \"explanation\": \"\"}\n",
"\n",
" def parse_json_output(json_string: str) -> dict:\n",
" \"\"\"Parses JSON output and extracts score and explanation.\"\"\"\n",
" try:\n",
" # Clean JSON string more robustly\n",
" cleaned_json = (\n",
" json_string.strip().removeprefix(\"```json\").removesuffix(\"```\")\n",
" )\n",
" data = json.loads(cleaned_json)\n",
"\n",
" return {\n",
" \"catchiness\": data.get(\"score\", 0),\n",
" \"explanation\": data.get(\"explanation\", \"\"),\n",
" }\n",
" except json.JSONDecodeError:\n",
" return default_result\n",
"\n",
" try:\n",
" # Input validation\n",
" if not isinstance(instance, dict) or not all(\n",
" k in instance for k in [\"prompt\", \"response\"]\n",
" ):\n",
" raise ValueError(\n",
" \"Instance must be a dict with 'prompt' and 'response' keys\"\n",
" )\n",
"\n",
" metric_prompt = metric_prompt_template.format(\n",
" prompt=instance[\"prompt\"], response=instance[\"response\"]\n",
" )\n",
"\n",
" rater_config = {\"max_new_tokens\": 256, \"temperature\": 0}\n",
"\n",
" eval_response = gemma_fn(metric_prompt, rater_config)\n",
" return parse_json_output(eval_response)\n",
"\n",
" except Exception as e:\n",
" return default_result"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NrFhJS5mzp1w"
},
"source": [
"Create a `Custom Metric` instance to evaluate generated summaries using Vertex AI Gen AI evaluation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "S_Jztdwpe0UT"
},
"outputs": [],
"source": [
"catchiness_metric = CustomMetric(\n",
" name=\"catchiness\",\n",
" metric_function=catchiness_fn,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "djQXvQobUHsg"
},
"source": [
"#### Set metrics to evaluate your task\n",
"\n",
"Define metrics to use to evaluate the summarization task."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "oLBTAvGNUHsg"
},
"outputs": [],
"source": [
"metrics = [\"rouge_l_sum\", \"fluency\", catchiness_metric]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fx4R8AwZUHsh"
},
"source": [
"#### Run the evaluation\n",
"\n",
"To run the evaluation job using Gemma as an evaluator."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NLdbGYHpUHsh"
},
"outputs": [],
"source": [
"run_id = generate_uuid()\n",
"experiment_name = \"gemma-judge-base-prompt-sum\"\n",
"experiment_run_name = f\"{experiment_name}-{run_id}\"\n",
"\n",
"eval_task = EvalTask(\n",
" dataset=eval_model_sample_df,\n",
" metrics=metrics,\n",
" experiment=experiment_name,\n",
")\n",
"\n",
"eval_result = eval_task.evaluate(\n",
" experiment_run_name=experiment_run_name,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vxMp1eGwUHsh"
},
"source": [
"#### Display Evaluation reports and explanations\n",
"\n",
"Visualize reports and useful charts to evaluate the model in summarization task."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "bTfYwO45UHsh"
},
"outputs": [],
"source": [
"display_eval_report(eval_result, \"Gemma 2 Judging result\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "axLGicoGUHsh"
},
"outputs": [],
"source": [
"display_explanations(eval_result.metrics_table, metrics=[\"catchiness\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NaAei_mCxcVt"
},
"outputs": [],
"source": [
"plot_bar_plot(\n",
" eval_result,\n",
" title=\"Evaluate Gemini using Gemma 2\",\n",
" metrics=[\"fluency/mean\", \"catchiness/mean\"],\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "d7cce28cc97e"
},
"source": [
"## Cleaning up"
]
},
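{
"cell_type": "markdown",
"metadata": {
"id": "cleanup_endpoint_md"
},
"source": [
"The deployed endpoint keeps accelerator resources reserved, so undeploy the model and delete both the endpoint and the registered model once you are done. This is a minimal sketch assuming the `gemma_model` and `deployed_gemma_model` objects created earlier are still in scope; set the flag to `True` to actually delete the resources. The cell after this one deletes the Vertex AI Experiments created by the evaluation runs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cleanup_endpoint_code"
},
"outputs": [],
"source": [
"delete_endpoint_and_model = False  # @param {type:\"boolean\", isTemplate: false}\n",
"\n",
"if delete_endpoint_and_model:\n",
"    # Undeploy all models from the endpoint, then delete the endpoint and the model resource.\n",
"    deployed_gemma_model.undeploy_all()\n",
"    deployed_gemma_model.delete()\n",
"    gemma_model.delete()"
]
},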
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ef73672573e1"
},
"outputs": [],
"source": [
"delete_experiment = False # @param {type:\"boolean\", isTemplate: false}\n",
"\n",
"if delete_experiment:\n",
" from google.cloud import aiplatform\n",
"\n",
" aiplatform.init(project=PROJECT_ID, location=LOCATION)\n",
" for experiment_name in [\n",
" \"eval-gemma-base-prompt-sum\",\n",
" \"gemma-judge-base-prompt-sum\",\n",
" ]:\n",
" experiment = aiplatform.Experiment(experiment_name=experiment_name)\n",
" experiment.delete()"
]
}
],
"metadata": {
"colab": {
"name": "vertex_ai_tgi_gemma_with_genai_evaluation.ipynb",
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}