{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "id": "ur8xi4C7S06n" }, "outputs": [], "source": [ "# Copyright 2024 Google LLC\n", "#\n", "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "JAPoU8Sm5E6e" }, "source": [ "# Intro to Controlled Generation with the Gemini API\n", "\n", "<table align=\"left\">\n", " <td style=\"text-align: center\">\n", " <a href=\"https://colab.research.google.com/github/GoogleCloudPlatform/generative-ai/blob/main/gemini/controlled-generation/intro_controlled_generation.ipynb\">\n", " <img src=\"https://cloud.google.com/ml-engine/images/colab-logo-32px.png\" alt=\"Google Colaboratory logo\"><br> Open in Colab\n", " </a>\n", " </td>\n", " <td style=\"text-align: center\">\n", " <a href=\"https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fgenerative-ai%2Fmain%2Fgemini%2Fcontrolled-generation%2Fintro_controlled_generation.ipynb\">\n", " <img width=\"32px\" src=\"https://cloud.google.com/ml-engine/images/colab-enterprise-logo-32px.png\" alt=\"Google Cloud Colab Enterprise logo\"><br> Open in Colab Enterprise\n", " </a>\n", " </td> \n", " <td style=\"text-align: center\">\n", " <a 
href=\"https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/generative-ai/main/gemini/controlled-generation/intro_controlled_generation.ipynb\">\n", " <img src=\"https://lh3.googleusercontent.com/UiNooY4LUgW_oTvpsNhPpQzsstV5W8F7rYgxgGBD85cWJoLmrOzhVs_ksK_vgx40SHs7jCqkTkCk=e14-rj-sc0xffffff-h130-w32\" alt=\"Vertex AI logo\"><br> Open in Workbench\n", " </a>\n", " </td>\n", " <td style=\"text-align: center\">\n", " <a href=\"https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/controlled-generation/intro_controlled_generation.ipynb\">\n", " <img src=\"https://cloud.google.com/ml-engine/images/github-logo-32px.png\" alt=\"GitHub logo\"><br> View on GitHub\n", " </a>\n", " </td>\n", " <td style=\"text-align: center\">\n", " <a href=\"https://goo.gle/3Pyftqr\">\n", " <img width=\"32px\" src=\"https://cdn.qwiklabs.com/assets/gcp_cloud-e3a77215f0b8bfa9b3f611c0d2208c7e8708ed31.svg\" alt=\"Google Cloud logo\"><br> Open in Cloud Skills Boost\n", " </a>\n", " </td>\n", "</table>\n", "\n", "<div style=\"clear: both;\"></div>\n", "\n", "<b>Share to:</b>\n", "\n", "<a href=\"https://www.linkedin.com/sharing/share-offsite/?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/controlled-generation/intro_controlled_generation.ipynb\" target=\"_blank\">\n", " <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/8/81/LinkedIn_icon.svg\" alt=\"LinkedIn logo\">\n", "</a>\n", "\n", "<a href=\"https://bsky.app/intent/compose?text=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/controlled-generation/intro_controlled_generation.ipynb\" target=\"_blank\">\n", " <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/7/7a/Bluesky_Logo.svg\" alt=\"Bluesky logo\">\n", "</a>\n", "\n", "<a 
href=\"https://twitter.com/intent/tweet?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/controlled-generation/intro_controlled_generation.ipynb\" target=\"_blank\">\n", " <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/5/5a/X_icon_2.svg\" alt=\"X logo\">\n", "</a>\n", "\n", "<a href=\"https://reddit.com/submit?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/controlled-generation/intro_controlled_generation.ipynb\" target=\"_blank\">\n", " <img width=\"20px\" src=\"https://redditinc.com/hubfs/Reddit%20Inc/Brand/Reddit_Logo.png\" alt=\"Reddit logo\">\n", "</a>\n", "\n", "<a href=\"https://www.facebook.com/sharer/sharer.php?u=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/controlled-generation/intro_controlled_generation.ipynb\" target=\"_blank\">\n", " <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/5/51/Facebook_f_logo_%282019%29.svg\" alt=\"Facebook logo\">\n", "</a> " ] }, { "cell_type": "markdown", "metadata": { "id": "84f0f73a0f76" }, "source": [ "| | |\n", "|-|-|\n", "|Author(s) | [Eric Dong](https://github.com/gericdong)|" ] }, { "cell_type": "markdown", "metadata": { "id": "tvgnzT1CKxrO" }, "source": [ "## Overview\n", "\n", "### Gemini\n", "\n", "Gemini is a family of generative AI models developed by Google DeepMind that is designed for multimodal use cases.\n", "\n", "### Controlled Generation\n", "\n", "Depending on your application, you may want the model response to a prompt to be returned in a structured data format, particularly if you are using the responses for downstream processes, such as downstream modules that expect a specific format as input. 
The Gemini API provides the controlled generation capability to constrain the model output to a structured format.\n", "\n", "This capability is available in the following models:\n", "\n", "- Gemini 2.0 Flash\n", "\n", "Learn more about [control generated output](https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/control-generated-output).\n", "\n", "\n", "### Objectives\n", "\n", "In this tutorial, you learn how to use the controlled generation capability in the Gemini API in Vertex AI to generate model responses in a structured data format.\n", "\n", "You will complete the following tasks:\n", "\n", "- Sending a prompt with a response schema\n", "- Using controlled generation in use cases requiring output constraints\n" ] }, { "cell_type": "markdown", "metadata": { "id": "61RBz8LLbxCR" }, "source": [ "## Get started" ] }, { "cell_type": "markdown", "metadata": { "id": "No17Cw5hgx12" }, "source": [ "### Install Google Gen AI SDK for Python" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "id": "tFy3H3aPgx12" }, "outputs": [], "source": [ "%pip install --upgrade --user --quiet google-genai" ] }, { "cell_type": "markdown", "metadata": { "id": "R5Xep4W9lq-Z" }, "source": [ "### Restart runtime\n", "\n", "To use the newly installed packages in this Jupyter runtime, you must restart the runtime. You can do this by running the cell below, which restarts the current kernel.\n", "\n", "The restart might take a minute or longer. After it's restarted, continue to the next step." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XRvKdaPDTznN" }, "outputs": [], "source": [ "import IPython\n", "\n", "app = IPython.Application.instance()\n", "app.kernel.do_shutdown(True)" ] }, { "cell_type": "markdown", "metadata": { "id": "SbmM4z7FOBpM" }, "source": [ "<div class=\"alert alert-block alert-warning\">\n", "<b>⚠️ The kernel is going to restart. 
Wait until it's finished before continuing to the next step. ⚠️</b>\n", "</div>\n" ] }, { "cell_type": "markdown", "metadata": { "id": "dmWOrTJ3gx13" }, "source": [ "### Authenticate your notebook environment (Colab only)\n", "\n", "If you're running this notebook on Google Colab, run the cell below to authenticate your environment." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NyKGtVQjgx13" }, "outputs": [], "source": [ "import sys\n", "\n", "if \"google.colab\" in sys.modules:\n", " from google.colab import auth\n", "\n", " auth.authenticate_user()" ] }, { "cell_type": "markdown", "metadata": { "id": "DF4l8DTdWgPY" }, "source": [ "### Set Google Cloud project information and initialize the Google Gen AI SDK\n", "\n", "To get started using Vertex AI, you must have an existing Google Cloud project and [enable the Vertex AI API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com).\n", "\n", "Learn more about [setting up a project and a development environment](https://cloud.google.com/vertex-ai/docs/start/cloud-environment)."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Nqwi-5ufWp_B" }, "outputs": [], "source": [ "import os\n", "\n", "from google import genai\n", "\n", "PROJECT_ID = \"[your-project-id]\" # @param {type: \"string\", placeholder: \"[your-project-id]\", isTemplate: true}\n", "\n", "if not PROJECT_ID or PROJECT_ID == \"[your-project-id]\":\n", " PROJECT_ID = str(os.environ.get(\"GOOGLE_CLOUD_PROJECT\"))\n", "\n", "LOCATION = os.environ.get(\"GOOGLE_CLOUD_REGION\", \"us-central1\")\n", "\n", "client = genai.Client(vertexai=True, project=PROJECT_ID, location=LOCATION)" ] }, { "cell_type": "markdown", "metadata": { "id": "EdvJRUWRNGHE" }, "source": [ "## Code Examples" ] }, { "cell_type": "markdown", "metadata": { "id": "09720c707f1c" }, "source": [ "### Import libraries" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "e45ea9a28734" }, "outputs": [], "source": [ "from google.genai.types import GenerateContentConfig, Part, SafetySetting" ] }, { "cell_type": "markdown", "metadata": { "id": "52aeea15a479" }, "source": [ "### Sending a prompt with a response schema\n", "\n", "The Gemini models allow you to define a response schema to specify the structure of a model's output, the field names, and the expected data type for each field. The response schema is specified in the `response_schema` parameter in `config`, and the model output will strictly follow that schema.\n", "\n", "You can provide the schemas as [Pydantic](https://docs.pydantic.dev/) models or a [JSON](https://www.json.org/json-en.html) string, and the model will respond as JSON or an [Enum](https://docs.python.org/3/library/enum.html) depending on the value set in `response_mime_type`.\n", "\n", "The following examples use Gemini 2.0 Flash (`gemini-2.0-flash`)."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "81cbb6bd51d8" }, "outputs": [], "source": [ "MODEL_ID = \"gemini-2.0-flash-001\" # @param {type:\"string\"}" ] }, { "cell_type": "markdown", "metadata": { "id": "fadc3bfaf346" }, "source": [ "#### Use a Pydantic object\n", "\n", "Define a response schema for the model output.\n", "\n", "When a model generates a response, it uses the field name and context from your prompt. As such, we recommend that you use a clear structure and unambiguous field names so that your intent is clear.\n", "\n", "If you aren't seeing the results you expect, add more context to your input prompts or revise your response schema. For example, review the model's response without controlled generation to see how the model responds. You can then update your response schema so that it better fits the model's output." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "e4402900f4c6" }, "outputs": [], "source": [ "from pydantic import BaseModel\n", "\n", "\n", "class Recipe(BaseModel):\n", " name: str\n", " description: str\n", " ingredients: list[str]" ] }, { "cell_type": "markdown", "metadata": { "id": "ae54e8bf8dcb" }, "source": [ "When prompting the model to generate the content, pass the schema to the `response_schema` field of `config`. \n", "\n", "You also need to specify the model output format in the `response_mime_type` field. Output formats such as `application/json` and `text/x.enum` are supported."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "be986a1f342f" }, "outputs": [], "source": [ "response = client.models.generate_content(\n", " model=MODEL_ID,\n", " contents=\"List a few popular cookie recipes and their ingredients.\",\n", " config=GenerateContentConfig(\n", " response_mime_type=\"application/json\",\n", " response_schema=list[Recipe],\n", " ),\n", ")\n", "\n", "print(response.text)" ] }, { "cell_type": "markdown", "metadata": { "id": "fed9413c2c56" }, "source": [ "You can either parse the response string as JSON, or use the `parsed` field to get the response as an object or dictionary." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cf75d83da42d" }, "outputs": [], "source": [ "cookie_recipes: list[Recipe] = response.parsed\n", "print(cookie_recipes)" ] }, { "cell_type": "markdown", "metadata": { "id": "766346c046f9" }, "source": [ "#### Use an OpenAPI Schema\n", "\n", "Define a response schema for the model output. Use only the supported fields as listed below. All other fields are ignored.\n", "\n", "- `enum`\n", "- `items`\n", "- `maxItems`\n", "- `nullable`\n", "- `properties`\n", "- `required`\n", "\n", "By default, fields are optional, meaning the model can populate the fields or skip them. You can set fields as required to force the model to provide a value. If there's insufficient context in the associated input prompt, the model generates responses mainly based on the data it was trained on."
] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "af3fa1fbff4f" }, "outputs": [], "source": [ "response_schema = {\n", " \"type\": \"ARRAY\",\n", " \"items\": {\n", " \"type\": \"ARRAY\",\n", " \"items\": {\n", " \"type\": \"OBJECT\",\n", " \"properties\": {\n", " \"rating\": {\"type\": \"INTEGER\"},\n", " \"flavor\": {\"type\": \"STRING\"},\n", " \"sentiment\": {\n", " \"type\": \"STRING\",\n", " \"enum\": [\"POSITIVE\", \"NEGATIVE\", \"NEUTRAL\"],\n", " },\n", " \"explanation\": {\"type\": \"STRING\"},\n", " },\n", " \"required\": [\"rating\", \"flavor\", \"sentiment\", \"explanation\"],\n", " },\n", " },\n", "}" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "id": "5db8b91d5be0" }, "outputs": [], "source": [ "prompt = \"\"\"\n", " Analyze the following product reviews, output the sentiment classification, and give an explanation.\n", "\n", " - \"Absolutely loved it! Best ice cream I've ever had.\" Rating: 4, Flavor: Strawberry Cheesecake\n", " - \"Quite good, but a bit too sweet for my taste.\" Rating: 1, Flavor: Mango Tango\n", "\"\"\"\n", "\n", "response = client.models.generate_content(\n", " model=MODEL_ID,\n", " contents=prompt,\n", " config=GenerateContentConfig(\n", " response_mime_type=\"application/json\",\n", " response_schema=response_schema,\n", " ),\n", ")\n", "product_reviews: list = response.parsed\n", "print(product_reviews)" ] }, { "cell_type": "markdown", "metadata": { "id": "69450c61bc07" }, "source": [ "### Using controlled generation in use cases requiring output constraints\n", "\n", "Controlled generation can be used to ensure that model outputs adhere to a specific structure (e.g., JSON), constrain the model to a pure multiple-choice selection (e.g., sentiment classification), or follow a certain style or set of guidelines.\n", "\n", "Let's use controlled generation in the following use cases that require output constraints."
] }, { "cell_type": "markdown", "metadata": { "id": "eba9ef4d4b50" }, "source": [ "#### **Example**: Generate game character profile\n", "\n", "In this example, you instruct the model to create a game character profile with some specific requirements, and constrain the model output to a structured format. This example also demonstrates how to configure the `response_schema` and `response_mime_type` fields in `config` in conjunction with `safety_settings`." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "id": "1411f729f2f7" }, "outputs": [], "source": [ "response_schema = {\n", " \"type\": \"ARRAY\",\n", " \"items\": {\n", " \"type\": \"OBJECT\",\n", " \"properties\": {\n", " \"name\": {\"type\": \"STRING\"},\n", " \"age\": {\"type\": \"INTEGER\"},\n", " \"occupation\": {\"type\": \"STRING\"},\n", " \"background\": {\"type\": \"STRING\"},\n", " \"playable\": {\"type\": \"BOOLEAN\"},\n", " \"children\": {\n", " \"type\": \"ARRAY\",\n", " \"items\": {\n", " \"type\": \"OBJECT\",\n", " \"properties\": {\n", " \"name\": {\"type\": \"STRING\"},\n", " \"age\": {\"type\": \"INTEGER\"},\n", " },\n", " \"required\": [\"name\", \"age\"],\n", " },\n", " },\n", " },\n", " \"required\": [\"name\", \"age\", \"occupation\", \"children\"],\n", " },\n", "}\n", "\n", "prompt = \"\"\"\n", " Generate a character profile for a video game, including the character's name, age, occupation, background, names of their\n", " three children, and whether they can be controlled by the player.\n", "\"\"\"\n", "\n", "response = client.models.generate_content(\n", " model=MODEL_ID,\n", " contents=prompt,\n", " config=GenerateContentConfig(\n", " response_mime_type=\"application/json\",\n", " response_schema=response_schema,\n", " safety_settings=[\n", " SafetySetting(\n", " category=\"HARM_CATEGORY_DANGEROUS_CONTENT\",\n", " threshold=\"BLOCK_LOW_AND_ABOVE\",\n", " ),\n", " SafetySetting(\n", " category=\"HARM_CATEGORY_HARASSMENT\",\n", " threshold=\"BLOCK_LOW_AND_ABOVE\",\n", " ),\n", " SafetySetting(\n", " category=\"HARM_CATEGORY_HATE_SPEECH\",\n", " threshold=\"BLOCK_LOW_AND_ABOVE\",\n", " ),\n", " SafetySetting(\n", " category=\"HARM_CATEGORY_SEXUALLY_EXPLICIT\",\n", " threshold=\"BLOCK_LOW_AND_ABOVE\",\n", " ),\n", " ],\n", " ),\n", ")\n", "character: list = response.parsed\n", "print(character)" ] }, { "cell_type": "markdown", "metadata": { "id": "e02769d61054" }, "source": [ "#### **Example**: Extract errors from log data\n", "\n", "In this example, you use the model to pull out specific error messages from unstructured log data, extract key information, and constrain the model output to a structured format.\n", "\n", "Some properties are set to nullable so the model can return a null value when it doesn't have enough context to generate a meaningful response.\n" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "id": "007c0394cadc" }, "outputs": [], "source": [ "response_schema = {\n", " \"type\": \"ARRAY\",\n", " \"items\": {\n", " \"type\": \"OBJECT\",\n", " \"properties\": {\n", " \"timestamp\": {\"type\": \"STRING\"},\n", " \"error_code\": {\"type\": \"INTEGER\", \"nullable\": True},\n", " \"error_message\": {\"type\": \"STRING\"},\n", " },\n", " \"required\": [\"timestamp\", \"error_message\", \"error_code\"],\n", " },\n", "}\n", "\n", "prompt = \"\"\"\n", "[15:43:28] ERROR: Could not process image upload: Unsupported file format. (Error Code: 308)\n", "[15:44:10] INFO: Search index updated successfully.\n", "[15:45:02] ERROR: Service dependency unavailable (payment gateway). Retrying... (Error Code: 5522)\n", "[15:45:33] ERROR: Application crashed due to out-of-memory exception. 
(Error Code: 9001)\n", "\"\"\"\n", "\n", "response = client.models.generate_content(\n", " model=MODEL_ID,\n", " contents=prompt,\n", " config=GenerateContentConfig(\n", " response_mime_type=\"application/json\",\n", " response_schema=response_schema,\n", " ),\n", ")\n", "\n", "log_data: list = response.parsed\n", "print(log_data)" ] }, { "cell_type": "markdown", "metadata": { "id": "10971b23afcf" }, "source": [ "#### Example: Detect objects in images\n", "\n", "You can also use controlled generation in multimodal use cases. In this example, you instruct the model to detect objects in the images and output the results in JSON format. 
These images are stored in a Google Cloud Storage bucket.\n", "\n", "- [office-desk.jpeg](https://storage.googleapis.com/cloud-samples-data/generative-ai/image/office-desk.jpeg)\n", "- [gardening-tools.jpeg](https://storage.googleapis.com/cloud-samples-data/generative-ai/image/gardening-tools.jpeg)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "id": "1f3e9935e2da" }, "outputs": [], "source": [ "response_schema = {\n", " \"type\": \"ARRAY\",\n", " \"items\": {\n", " \"type\": \"ARRAY\",\n", " \"items\": {\n", " \"type\": \"OBJECT\",\n", " \"properties\": {\n", " \"object\": {\"type\": \"STRING\"},\n", " },\n", " },\n", " },\n", "}\n", "\n", "prompt = \"Generate a list of objects in the images.\"\n", "\n", "response = client.models.generate_content(\n", " model=MODEL_ID,\n", " contents=[\n", " Part.from_uri(\n", " file_uri=\"gs://cloud-samples-data/generative-ai/image/office-desk.jpeg\",\n", " mime_type=\"image/jpeg\",\n", " ),\n", " Part.from_uri(\n", " file_uri=\"gs://cloud-samples-data/generative-ai/image/gardening-tools.jpeg\",\n", " mime_type=\"image/jpeg\",\n", " ),\n", " prompt,\n", " ],\n", " config=GenerateContentConfig(\n", " response_mime_type=\"application/json\",\n", " response_schema=response_schema,\n", " ),\n", ")\n", "object_list = response.parsed\n", "print(object_list)" ] }, { "cell_type": "markdown", "metadata": { "id": "8e47be074e75" }, "source": [ "#### Example: Respond with a single plain text enum value\n", "\n", "This example identifies the genre of a movie based on its description. The output is one plain-text enum value that the model selects from a list of values that are defined in the response schema. 
Note that in this example, the `response_mime_type` field is set to `text/x.enum`.\n" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "id": "f4f0a68dda49" }, "outputs": [], "source": [ "response_schema = {\"type\": \"STRING\", \"enum\": [\"drama\", \"comedy\", \"documentary\"]}\n", "\n", "prompt = (\n", " \"The film aims to educate and inform viewers about real-life subjects, events, or people. \"\n", " \"It offers a factual record of a particular topic by combining interviews, historical footage, \"\n", " \"and narration. The primary purpose of the film is to present information and provide insights \"\n", " \"into various aspects of reality.\"\n", ")\n", "\n", "response = client.models.generate_content(\n", " model=MODEL_ID,\n", " contents=prompt,\n", " config=GenerateContentConfig(\n", " response_mime_type=\"text/x.enum\",\n", " response_schema=response_schema,\n", " ),\n", ")\n", "print(response.text)" ] } ], "metadata": { "colab": { "name": "intro_controlled_generation.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }