{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "64H_Y4stWy1s"
},
"outputs": [],
"source": [
"# Copyright 2025 Google LLC\n",
"#\n",
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jw8YVWK573xh"
},
"source": [
"# Vertex AI Model Garden - ShieldGemma 2 Local Inference\n",
"\n",
"\n",
"<table><tbody><tr>\n",
" <td style=\"text-align: center\">\n",
" <a href=\"https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fvertex-ai-samples%2Fmain%2Fnotebooks%2Fcommunity%2Fmodel_garden%2Fmodel_garden_shieldgemma2_local_inference.ipynb\">\n",
" <img alt=\"Google Cloud Colab Enterprise logo\" src=\"https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN\" width=\"32px\"><br> Run in Colab Enterprise\n",
" </a>\n",
" </td>\n",
" <td style=\"text-align: center\">\n",
" <a href=\"https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_shieldgemma2_local_inference.ipynb\">\n",
" <img alt=\"GitHub logo\" src=\"https://cloud.google.com/ml-engine/images/github-logo-32px.png\" width=\"32px\"><br> View on GitHub\n",
" </a>\n",
" </td>\n",
"</tr></tbody></table>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JKsp-Ckg73xh"
},
"source": [
"## Overview\n",
"\n",
"This notebook demonstrates how to install the necessary libraries and run local inference with the ShieldGemma 2 model in a [Colab Enterprise Instance](https://cloud.google.com/colab/docs).\n",
"\n",
"The **ShieldGemma 2** model is trained to detect key harms detailed in the [model card](https://ai.google.dev/gemma/docs/shieldgemma/model_card_2). This guide demonstrates how to use Hugging Face Transformers to build robust data and models.\n",
"\n",
"Note that `ShieldGemma 2` is trained to classify only one harm type at a time, so you will need to make a separate call to `ShieldGemma 2` for each harm type you want to check against. You may have additional that you can use model tuning techniques on `ShieldGemma 2`.\n",
"\n",
"### ShieldGemma 2 Model Licensing\n",
"\n",
"* ShieldGemma 2 Model is available under the [Gemma Terms of Use](https://ai.google.dev/gemma/terms). For full details, refer to the linked documentation that governs the use of this model.\n",
"\n",
"\n",
"### Objective\n",
"\n",
"* Run local inference with the ShieldGemma 2 model for image safety classification.\n",
"\n",
"\n",
"### Costs\n",
"\n",
"This tutorial uses billable components of Google Cloud:\n",
"\n",
"* Vertex AI\n",
"\n",
"Learn about [Vertex AI pricing](https://cloud.google.com/vertex-ai/pricing) and use the [Pricing Calculator](https://cloud.google.com/products/calculator/) to generate a cost estimate based on your projected usage."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Q2wFKYqyVYQM"
},
"source": [
"## Supported safety checks\n",
"\n",
"**ShieldGemma2** is a model trained on Gemma 3's 4B IT checkpoint and is trained to detect and predict violations of key harm types listed below:\n",
"\n",
"* **Dangerous Content**: The image shall not contain content that facilitates or encourages activities that could cause real-world harm (e.g., building firearms and explosive devices, promotion of terrorism, instructions for suicide).\n",
"\n",
"* **Sexually Explicit**: The image shall not contain content that depicts explicit or graphic sexual acts (e.g., pornography, erotic nudity, depictions of rape or sexual assault).\n",
"\n",
"* **Violence/Gore**: The image shall not contain content that depicts shocking, sensational, or gratuitous violence (e.g., excessive blood and gore, gratuitous violence against animals, extreme injury or moment of death).\n",
"\n",
"This serves as a foundation, but users can provide customized safety policies as input to the model, allowing for fine-grained control and specific use-case requirements."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-Q7PzjwK73xh"
},
"source": [
"## Before you begin\n",
"\n",
"Make sure you are connecting to a [Colab Enterprise runtime](https://cloud.google.com/colab/docs/connect-to-runtime) with GPU. If not, we recommend [creating a runtime template](https://cloud.google.com/colab/docs/create-runtime-template) with 1 `NVIDIA_L4` GPU (e.g. with the machine type `g2-standard-8` or similar)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "DonjAWGYF2hp"
},
"outputs": [],
"source": [
"# @title Access the ShieldGemma 2 model\n",
"\n",
"# @markdown 1. Go to the [ShieldGemma 2 model card on Hugging Face](https://huggingface.co/gg-hf-g/shieldgemma-2-4b-it) and accept the agreement if not already.\n",
"\n",
"# @markdown 2. Provide a Hugging Face User Access Token (read) below to access the ShieldGemma 2 model.\n",
"# @markdown You can follow the [Hugging Face documentation](https://huggingface.co/docs/hub/en/security-tokens) to create a **read** access token.\n",
"\n",
"HF_TOKEN = \"\" # @param {type:\"string\", isTemplate:true}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "tdukqhXhVYQM"
},
"outputs": [],
"source": [
"# @title Install transformers with ShieldGemma 2 support\n",
"\n",
"! pip install git+https://github.com/huggingface/transformers@v4.50.0"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "u3ym-XHH73xh"
},
"source": [
"## Image Safety Classification with ShieldGemma 2\n",
"\n",
"The following code classifies images for safety with the ShieldGemma 2 model.\n",
"\n",
"**Input:**\n",
"\n",
"Image + Prompt Instruction with policy definition.\n",
"\n",
"**Output:**\n",
"\n",
"Probability of `Yes`/`No` tokens, with a higher score indicating the model's higher confidence that the image violates the specified policy. `Yes` means that the image violated the policy, `No` means that the model did not violate the policy.\n"
]
},
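{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough illustration of the output format described above, the following sketch turns rows of `[Yes, No]` probabilities into per-policy decisions. The helper name `summarize_scores`, the 0.5 threshold, the policy keys, and the sample numbers are assumptions made for this example; they are not taken from the ShieldGemma 2 API or from real model output."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch only: a hypothetical helper, not part of the ShieldGemma 2 API.\n",
"def summarize_scores(probabilities, policy_names, threshold=0.5):\n",
"    \"\"\"Map rows of [P(Yes), P(No)] to per-policy violation decisions.\"\"\"\n",
"    if len(probabilities) != len(policy_names):\n",
"        raise ValueError(\"Expected one probability row per policy.\")\n",
"    summary = {}\n",
"    for name, (p_yes, _p_no) in zip(policy_names, probabilities):\n",
"        # Flag the image when P(Yes) reaches the (assumed) threshold.\n",
"        summary[name] = {\"p_yes\": p_yes, \"violates\": p_yes >= threshold}\n",
"    return summary\n",
"\n",
"\n",
"# Made-up probabilities for illustration; policy keys are assumed names.\n",
"example_rows = [[0.02, 0.98], [0.01, 0.99], [0.91, 0.09]]\n",
"example_policies = [\"dangerous\", \"sexual\", \"violence\"]\n",
"for policy, result in summarize_scores(example_rows, example_policies).items():\n",
"    print(policy, result)"
]
},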
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "8Oc4G3b573xh"
},
"outputs": [],
"source": [
"# @title Load model weights\n",
"import torch\n",
"from transformers import AutoProcessor, ShieldGemma2ForImageClassification\n",
"\n",
"model_id = \"google/shieldgemma-2-4b-it\"\n",
"\n",
"processor = AutoProcessor.from_pretrained(model_id, token=HF_TOKEN)\n",
"model = ShieldGemma2ForImageClassification.from_pretrained(\n",
" model_id,\n",
" torch_dtype=torch.bfloat16,\n",
" token=HF_TOKEN,\n",
")\n",
"model = model.to(torch.device(\"cuda\"))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "KjoRriS373xh"
},
"outputs": [],
"source": [
"import requests\n",
"# @title Run local inference with input image\n",
"from PIL import Image\n",
"\n",
"# @markdown URL to the input image.\n",
"image_url = \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/p-blog/candy.JPG\" # @param {type: \"string\"}\n",
"# @markdown Click \"Show Code\" to see the more details.\n",
"\n",
"image = Image.open(requests.get(image_url, stream=True).raw)\n",
"model_inputs = processor(images=[image], return_tensors=\"pt\").to(model.device)\n",
"\n",
"with torch.inference_mode():\n",
" scores = model(**model_inputs)\n",
"\n",
"print(scores.probabilities)"
]
}
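,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because ShieldGemma 2 classifies one harm type per call, one way to check several harm types is to loop over the policies. The sketch below assumes that the processor accepts a `policies` keyword argument and that the built-in policy keys are `dangerous`, `sexual`, and `violence`; verify both against the Transformers documentation for your installed version. It reuses `processor`, `model`, and `image` from the cells above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch under stated assumptions: the `policies` keyword and the policy keys\n",
"# below are assumed, not confirmed by this notebook.\n",
"assumed_policy_keys = [\"dangerous\", \"sexual\", \"violence\"]\n",
"\n",
"for policy in assumed_policy_keys:\n",
"    # One call per harm type, as the model classifies one policy at a time.\n",
"    policy_inputs = processor(\n",
"        images=[image], policies=[policy], return_tensors=\"pt\"\n",
"    ).to(model.device)\n",
"    with torch.inference_mode():\n",
"        policy_scores = model(**policy_inputs)\n",
"    # Column 0 is taken to be the `Yes` (violation) probability.\n",
"    p_yes = policy_scores.probabilities[0, 0].item()\n",
"    print(f\"{policy}: P(Yes) = {p_yes:.4f}\")"
]
}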
],
"metadata": {
"colab": {
"name": "model_garden_shieldgemma2_local_inference.ipynb",
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}