{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "id": "ur8xi4C7S06n" }, "outputs": [], "source": [ "# Copyright 2024 Google LLC\n", "#\n", "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "JAPoU8Sm5E6e" }, "source": [ "# Responsible AI with Gemini API in Vertex AI: Safety ratings and thresholds\n", "\n", "<table align=\"left\">\n", " <td style=\"text-align: center\">\n", " <a href=\"https://colab.research.google.com/github/GoogleCloudPlatform/generative-ai/blob/main/gemini/responsible-ai/gemini_safety_ratings.ipynb\">\n", " <img width=\"32px\" src=\"https://cloud.google.com/ml-engine/images/colab-logo-32px.png\" alt=\"Google Colaboratory logo\"><br> Run in Colab\n", " </a>\n", " </td>\n", " <td style=\"text-align: center\">\n", " <a href=\"https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fgenerative-ai%2Fmain%2Fgemini%2Fresponsible-ai%2Fgemini_safety_ratings.ipynb\">\n", " <img width=\"32px\" src=\"https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN\" alt=\"Google Cloud Colab Enterprise logo\"><br> Run in Colab Enterprise\n", " </a>\n", " </td> \n", " <td style=\"text-align: center\">\n", " <a href=\"https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/responsible-ai/gemini_safety_ratings.ipynb\">\n", " <img width=\"32px\" src=\"https://cloud.google.com/ml-engine/images/github-logo-32px.png\" alt=\"GitHub logo\"><br> View on GitHub\n", " </a>\n", " </td>\n", " <td style=\"text-align: center\">\n", " <a href=\"https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/generative-ai/main/gemini/responsible-ai/gemini_safety_ratings.ipynb\">\n", " <img width=\"32px\" src=\"https://lh3.googleusercontent.com/UiNooY4LUgW_oTvpsNhPpQzsstV5W8F7rYgxgGBD85cWJoLmrOzhVs_ksK_vgx40SHs7jCqkTkCk=e14-rj-sc0xffffff-h130-w32\" alt=\"Vertex AI logo\"><br>\n", " Open in Vertex AI Workbench\n", " </a>\n", " </td> \n", "</table>\n", "\n", "<div style=\"clear: both;\"></div>\n", "\n", "<b>Share to:</b>\n", "\n", "<a href=\"https://www.linkedin.com/sharing/share-offsite/?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/responsible-ai/gemini_safety_ratings.ipynb\" target=\"_blank\">\n", " <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/8/81/LinkedIn_icon.svg\" alt=\"LinkedIn logo\">\n", "</a>\n", "\n", "<a href=\"https://bsky.app/intent/compose?text=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/responsible-ai/gemini_safety_ratings.ipynb\" target=\"_blank\">\n", " <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/7/7a/Bluesky_Logo.svg\" alt=\"Bluesky logo\">\n", "</a>\n", "\n", "<a 
href=\"https://twitter.com/intent/tweet?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/responsible-ai/gemini_safety_ratings.ipynb\" target=\"_blank\">\n", " <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/5/5a/X_icon_2.svg\" alt=\"X logo\">\n", "</a>\n", "\n", "<a href=\"https://reddit.com/submit?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/responsible-ai/gemini_safety_ratings.ipynb\" target=\"_blank\">\n", " <img width=\"20px\" src=\"https://redditinc.com/hubfs/Reddit%20Inc/Brand/Reddit_Logo.png\" alt=\"Reddit logo\">\n", "</a>\n", "\n", "<a href=\"https://www.facebook.com/sharer/sharer.php?u=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/responsible-ai/gemini_safety_ratings.ipynb\" target=\"_blank\">\n", " <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/5/51/Facebook_f_logo_%282019%29.svg\" alt=\"Facebook logo\">\n", "</a> " ] }, { "cell_type": "markdown", "metadata": { "id": "D7Isll3-PJQ1" }, "source": [ "| | |\n", "|-|-|\n", "|Author(s) | [Hussain Chinoy](https://github.com/ghchinoy), [Holt Skinner](https://github.com/holtskinner) |" ] }, { "cell_type": "markdown", "metadata": { "id": "tvgnzT1CKxrO" }, "source": [ "## Overview\n", "\n", "Large language models (LLMs) can translate language, summarize text, generate creative writing, generate code, power chatbots and virtual assistants, and complement search engines and recommendation systems. The incredible versatility of LLMs is also what makes it difficult to predict exactly what kinds of unintended or unforeseen outputs they might produce.\n", "\n", "Given these risks and complexities, the Gemini API in Vertex AI is designed with [Google's AI Principles](https://ai.google/responsibility/principles/) in mind. However, it is important for developers to understand and test their models to deploy safely and responsibly. To aid developers, Vertex AI Studio has built-in content filtering, safety ratings, and the ability to define safety filter thresholds that are right for their use cases and business.\n", "\n", "For more information, see the [Google Cloud Generative AI documentation on Responsible AI](https://cloud.google.com/vertex-ai/docs/generative-ai/learn/responsible-ai) and [Configuring safety attributes](https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/configure-safety-attributes)." 
] }, { "cell_type": "markdown", "metadata": { "id": "d975e698c9a4" }, "source": [ "### Objectives\n", "\n", "In this tutorial, you learn how to inspect the safety ratings returned from the Gemini API in Vertex AI using the Python SDK and how to set a safety threshold to filter responses from the Gemini API in Vertex AI.\n", "\n", "The steps performed include:\n", "\n", "- Call the Gemini API in Vertex AI and inspect safety ratings of the responses\n", "- Define a threshold for filtering safety ratings according to your needs" ] }, { "cell_type": "markdown", "metadata": { "id": "aed92deeb4a0" }, "source": [ "### Costs\n", "\n", "This tutorial uses billable components of Google Cloud:\n", "\n", "- Vertex AI\n", "\n", "Learn about [Vertex AI pricing](https://cloud.google.com/vertex-ai/pricing) and use the [Pricing Calculator](https://cloud.google.com/products/calculator/) to generate a cost estimate based on your projected usage.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "yjg3mPMSPJQ7" }, "source": [ "## Getting Started\n" ] }, { "cell_type": "markdown", "metadata": { "id": "HDBMQEnXsnRB" }, "source": [ "### Install Google Gen AI SDK for Python\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SBUtvsQHPJQ8" }, "outputs": [], "source": [ "%pip install --upgrade --quiet google-genai" ] }, { "cell_type": "markdown", "metadata": { "id": "R5Xep4W9lq-Z" }, "source": [ "### Restart current runtime\n", "\n", "To use the newly installed packages in this Jupyter runtime, you must restart the runtime. You can do this by running the cell below, which will restart the current kernel." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XRvKdaPDTznN" }, "outputs": [], "source": [ "# Restart kernel after installs so that your environment can access the new packages\n", "import IPython\n", "\n", "app = IPython.Application.instance()\n", "app.kernel.do_shutdown(True)" ] }, { "cell_type": "markdown", "metadata": { "id": "SbmM4z7FOBpM" }, "source": [ "<div class=\"alert alert-block alert-warning\">\n", "<b>⚠️ The kernel is going to restart. Please wait until it is finished before continuing to the next step. ⚠️</b>\n", "</div>\n" ] }, { "cell_type": "markdown", "metadata": { "id": "sBCra4QMA2wR" }, "source": [ "### Authenticate your notebook environment (Colab only)\n", "\n", "If you are running this notebook on Google Colab, run the following cell to authenticate your environment. This step is not required if you are using [Vertex AI Workbench](https://cloud.google.com/vertex-ai-workbench).\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "254614fa0c46" }, "outputs": [], "source": [ "import sys\n", "\n", "# Additional authentication is required for Google Colab\n", "if \"google.colab\" in sys.modules:\n", " # Authenticate user to Google Cloud\n", " from google.colab import auth\n", "\n", " auth.authenticate_user()" ] }, { "cell_type": "markdown", "metadata": { "id": "ef21552ccea8" }, "source": [ "### Set Google Cloud project information and create client\n", "\n", "To get started using Vertex AI, you must have an existing Google Cloud project and [enable the Vertex AI API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com).\n", "\n", "Learn more about [setting up a project and a development environment](https://cloud.google.com/vertex-ai/docs/start/cloud-environment)." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "603adbbf0532" }, "outputs": [], "source": [ "import os\n", "\n", "PROJECT_ID = \"[your-project-id]\" # @param {type: \"string\", placeholder: \"[your-project-id]\", isTemplate: true}\n", "if not PROJECT_ID or PROJECT_ID == \"[your-project-id]\":\n", " PROJECT_ID = str(os.environ.get(\"GOOGLE_CLOUD_PROJECT\"))\n", "\n", "LOCATION = os.environ.get(\"GOOGLE_CLOUD_REGION\", \"us-central1\")" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "e0047cf34fe7" }, "outputs": [], "source": [ "from google import genai\n", "\n", "client = genai.Client(vertexai=True, project=PROJECT_ID, location=LOCATION)" ] }, { "cell_type": "markdown", "metadata": { "id": "i7EUnXsZhAGF" }, "source": [ "### Import libraries\n" ] }, { "cell_type": "code", "execution_count": 36, "metadata": { "id": "eeH2sddasR1a" }, "outputs": [], "source": [ "from IPython.display import Markdown, display\n", "from google.genai.types import (\n", " GenerateContentConfig,\n", " GenerateContentResponse,\n", " SafetySetting,\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "dcc5ea7e29c9" }, "source": [ "### Helper functions" ] }, { "cell_type": "code", "execution_count": 75, "metadata": { "id": "8096b9de8383" }, "outputs": [], "source": [ "def print_safety_ratings(response: GenerateContentResponse) -> None:\n", " \"\"\"Displays safety ratings and related information in Markdown format.\"\"\"\n", " display(Markdown(\"### Safety Ratings\\n\"))\n", "\n", " if response.prompt_feedback:\n", " display(Markdown(f\"**Prompt Feedback:** {response.prompt_feedback}\"))\n", "\n", " candidate = response.candidates[0]\n", "\n", " table_header = (\n", " \"| Blocked | Category | Probability | Probability Score | Severity | Severity Score |\\n\"\n", " \"|---|---|---|---|---|---|\\n\"\n", " )\n", "\n", " table_rows = \"\\n\".join(\n", " f\"| {'✅' if not rating.blocked else '❌'} | `{rating.category}` | `{rating.probability}` | \"\n", " f\"`{rating.probability_score}` | `{rating.severity}` | `{rating.severity_score}` |\"\n", " for rating in candidate.safety_ratings\n", " )\n", "\n", " display(Markdown(table_header + table_rows))\n", "\n", " # Display finish reason and message if they exist\n", " if candidate.finish_reason:\n", " display(Markdown(f\"**Finish Reason:** `{candidate.finish_reason}`\"))\n", " if candidate.finish_message:\n", " display(Markdown(f\"**Finish Message:** `{candidate.error_message}`\"))" ] }, { "cell_type": "markdown", "metadata": { "id": "5rpgrqQrPJQ-" }, "source": [ "### Load the Gemini model\n", "\n", "Learn more about all [Gemini models on Vertex AI](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-models)." ] }, { "cell_type": "code", "execution_count": 72, "metadata": { "id": "5X9BCtm2PJQ-" }, "outputs": [], "source": [ "MODEL_ID = \"gemini-2.0-flash-001\" # @param {type: \"string\"}\n", "\n", "# Set parameters to reduce variability in responses\n", "generation_config = GenerateContentConfig(\n", " temperature=0,\n", " top_p=0.1,\n", " top_k=1,\n", " max_output_tokens=1024,\n", " seed=1,\n", " candidate_count=1,\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "HlHF7Oqw0zBc" }, "source": [ "## Generate text and show safety ratings" ] }, { "cell_type": "markdown", "metadata": { "id": "u7wSHFUtV48I" }, "source": [ "Start by generating a pleasant-sounding text response using Gemini." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "i-fAS7XV05Bp" }, "outputs": [], "source": [ "# Call Gemini API\n", "nice_prompt = \"Say three nice things.\"\n", "\n", "response = client.models.generate_content(\n", " model=MODEL_ID, config=generation_config, contents=nice_prompt\n", ")\n", "\n", "display(Markdown(response.text))" ] }, { "cell_type": "markdown", "metadata": { "id": "qXmMAbg0PJQ_" }, "source": [ "#### Inspecting the safety ratings" ] }, { "cell_type": "markdown", "metadata": { "id": "8EPQRdiG1BVv" }, "source": [ "Look at the `safety_ratings` of the response." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1z82p_bPSK5p" }, "outputs": [], "source": [ "print_safety_ratings(response)" ] }, { "cell_type": "markdown", "metadata": { "id": "71N4sjLtPJQ_" }, "source": [ "#### Understanding the safety ratings: category and probability" ] }, { "cell_type": "markdown", "metadata": { "id": "8bd5SnfOSR0n" }, "source": [ "You can see the safety ratings, including each `category` type and its associated `probability` label, as well as a `probability_score`. Additionally, safety ratings have been expanded to `severity` and `severity_score`.\n", "\n", "The `category` types include:\n", "\n", "* Hate speech: `HARM_CATEGORY_HATE_SPEECH`\n", "* Dangerous content: `HARM_CATEGORY_DANGEROUS_CONTENT`\n", "* Harassment: `HARM_CATEGORY_HARASSMENT`\n", "* Sexually explicit statements: `HARM_CATEGORY_SEXUALLY_EXPLICIT`\n", "\n", "The `probability` labels are:\n", "\n", "* `NEGLIGIBLE` - content has a negligible probability of being unsafe\n", "* `LOW` - content has a low probability of being unsafe\n", "* `MEDIUM` - content has a medium probability of being unsafe\n", "* `HIGH` - content has a high probability of being unsafe\n", "\n", "The `probability_score` has an associated confidence score between `0.0` and `1.0`.\n", "\n", "Each of the four safety attributes is assigned a safety rating (severity level) and a severity score ranging from `0.0` to `1.0`. The ratings and scores in the following table reflect the predicted severity of the content belonging to a given category." ] }, { "cell_type": "markdown", "metadata": { "id": "ncwjPVYfk19K" }, "source": [ "#### Comparing Probability and Severity\n", "\n", "There are two types of safety scores:\n", "\n", "* Safety scores based on **probability** of being unsafe\n", "* Safety scores based on **severity** of harmful content\n", "\n", "The probability safety attribute reflects the likelihood that an input or model response is associated with the respective safety attribute. The severity safety attribute reflects the magnitude of how harmful an input or model response might be.\n", "\n", "Content can have a low probability score and a high severity score, or a high probability score and a low severity score. For example, consider the following two sentences:\n", "\n", "- The robot punched me.\n", "- The robot slashed me up.\n", "\n", "The first sentence might cause a higher probability of being unsafe and the second sentence might have a higher severity in terms of violence. Because of this, it's important to carefully test and consider the appropriate level of blocking required to support your key use cases and also minimize harm to end users.\n", "\n", "#### Blocked responses\n", "\n", "If the response is blocked, you will see that the final candidate includes `blocked: True`, and also observe which of the safety ratings triggered the blocking of the response (e.g. `finish_reason: SAFETY`)." 
] }, { "cell_type": "markdown", "metadata": { "id": "k0rlZEpGPJRA" }, "source": [ "Try a prompt that might trigger one of these categories:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "pcw5s7Jo1Axm" }, "outputs": [], "source": [ "impolite_prompt = \"Write a list of 5 disrespectful things that I might say to the universe after stubbing my toe in the dark. Respond using profanity.\"\n", "\n", "responses = client.models.generate_content_stream(\n", " model=MODEL_ID, config=generation_config, contents=impolite_prompt\n", ")\n", "\n", "for response in responses:\n", " if response.text:\n", " print(response.text, end=\"\")\n", " else:\n", " print_safety_ratings(response)" ] }, { "cell_type": "markdown", "metadata": { "id": "zrPLIhgZ4etq" }, "source": [ "### Defining thresholds for safety ratings\n", "\n", "You may want to adjust the default safety filter thresholds depending on your business policies or use case. The Gemini API in Vertex AI provides you a way to pass in a threshold for each category.\n", "\n", "The list below shows the possible threshold labels:\n", "\n", "* `BLOCK_ONLY_HIGH` - block when high probability of unsafe content is detected\n", "* `BLOCK_MEDIUM_AND_ABOVE` - block when medium or high probability of content is detected\n", "* `BLOCK_LOW_AND_ABOVE` - block when low, medium, or high probability of unsafe content is detected\n", "* `BLOCK_NONE` - always show, regardless of probability of unsafe content" ] }, { "cell_type": "markdown", "metadata": { "id": "oYGKVnGePJRB" }, "source": [ "#### Set safety thresholds\n", "Below, the safety thresholds have been set to the most sensitive threshold: `BLOCK_LOW_AND_ABOVE`" ] }, { "cell_type": "code", "execution_count": 79, "metadata": { "id": "T0YohSf1PJRB" }, "outputs": [], "source": [ "generation_config.safety_settings = [\n", " SafetySetting(\n", " category=\"HARM_CATEGORY_DANGEROUS_CONTENT\", threshold=\"BLOCK_LOW_AND_ABOVE\"\n", " ),\n", " SafetySetting(\n", " category=\"HARM_CATEGORY_HATE_SPEECH\", threshold=\"BLOCK_LOW_AND_ABOVE\"\n", " ),\n", " SafetySetting(category=\"HARM_CATEGORY_HARASSMENT\", threshold=\"BLOCK_LOW_AND_ABOVE\"),\n", " SafetySetting(\n", " category=\"HARM_CATEGORY_SEXUALLY_EXPLICIT\", threshold=\"BLOCK_LOW_AND_ABOVE\"\n", " ),\n", "]" ] }, { "cell_type": "markdown", "metadata": { "id": "2tHldASqPJRB" }, "source": [ "#### Test thresholds\n", "\n", "Here you will reuse the impolite prompt from earlier together with the most sensitive safety threshold. It should block the response even with the `LOW` probability label." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Vq3at7EmPJRB" }, "outputs": [], "source": [ "impolite_prompt = \"Write a list of 5 disrespectful things that I might say to the universe after stubbing my toe in the dark:\"\n", "\n", "responses = client.models.generate_content_stream(\n", " model=MODEL_ID, config=generation_config, contents=impolite_prompt\n", ")\n", "\n", "for response in responses:\n", " if response.text:\n", " print(response.text, end=\"\")\n", " else:\n", " print_safety_ratings(response)" ] }, { "cell_type": "markdown", "metadata": { "id": "mYudAfc6gDi8" }, "source": [ "Let's look at how we understand block responses in the next section." 
] }, { "cell_type": "markdown", "metadata": { "id": "l2v6VnECf-fC" }, "source": [ "## Understanding Blocked Responses\n", "\n", "The documentation for [`FinishReason`](https://cloud.google.com/vertex-ai/docs/reference/rest/v1/GenerateContentResponse#finishreason) contains some more detailed explanations.\n", "\n", "For example, the previous response was blocked with the `finish_reason: SAFETY`, indicating that\n", "\n", "> The token generation was stopped as the response was flagged for safety reasons. NOTE: The `response.text` will be empty if content filters blocked the output.\n", "\n", "As of this writing, the table from the `FinishReason` have been reproduced below, but please look at the docs for the definitive explanations\n" ] }, { "cell_type": "markdown", "metadata": { "id": "FhbbwYhJijfa" }, "source": [ "Finish Reason | Explanation\n", "--- | ---\n", "`FINISH_REASON_UNSPECIFIED`\t| The finish reason is unspecified.\n", "`STOP` | Natural stop point of the model or provided stop sequence.\n", "`MAX_TOKENS` | The maximum number of tokens as specified in the request was reached.\n", "`SAFETY` | The token generation was stopped as the response was flagged for safety reasons. |\n", "`RECITATION` | The token generation was stopped as the response was flagged for unauthorized citations.\n", "`OTHER` | All other reasons that stopped the token generation\n", "`BLOCKLIST` | The token generation was stopped as the response was flagged for the terms which are included from the terminology blocklist.\n", "`PROHIBITED_CONTENT` | The token generation was stopped as the response was flagged for the prohibited contents.\n", "`SPII` | The token generation was stopped as the response was flagged for Sensitive Personally Identifiable Information (SPII) contents.\n", "`MALFORMED_FUNCTION_CALL` | The function call generated by the model is invalid." ] } ], "metadata": { "colab": { "name": "gemini_safety_ratings.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }