multimodal-dataset/intro_vertex_ai_multimodal_dataset.ipynb (945 lines of code) (raw):
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_TI7YxcMPkTi"
},
"outputs": [],
"source": [
"# Copyright 2025 Google LLC\n",
"#\n",
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ysLSgHUz9JAN"
},
"source": [
"# Intro to Vertex AI Multimodal Datasets\n",
"\n",
"<table align=\"left\">\n",
" <td style=\"text-align: center\">\n",
" <a href=\"https://colab.research.google.com/github/GoogleCloudPlatform/generative-ai/blob/main/multimodal-dataset/intro_vertex_ai_multimodal_dataset.ipynb\">\n",
" <img width=\"32px\" src=\"https://www.gstatic.com/pantheon/images/bigquery/welcome_page/colab-logo.svg\" alt=\"Google Colaboratory logo\"><br> Open in Colab\n",
" </a>\n",
" </td>\n",
" <td style=\"text-align: center\">\n",
" <a href=\"https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fgenerative-ai%2Fmain%2Fmultimodal-dataset%2Fintro_vertex_ai_multimodal_dataset.ipynb\">\n",
" <img width=\"32px\" src=\"https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN\" alt=\"Google Cloud Colab Enterprise logo\"><br> Open in Colab Enterprise\n",
" </a>\n",
" </td>\n",
" <td style=\"text-align: center\">\n",
" <a href=\"https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/generative-ai/main/multimodal-dataset/intro_vertex_ai_multimodal_dataset.ipynb\">\n",
" <img src=\"https://www.gstatic.com/images/branding/gcpiconscolors/vertexai/v1/32px.svg\" alt=\"Vertex AI logo\"><br> Open in Vertex AI Workbench\n",
" </a>\n",
" </td>\n",
" <td style=\"text-align: center\">\n",
" <a href=\"https://console.cloud.google.com/bigquery/import?url=https://github.com/GoogleCloudPlatform/generative-ai/blob/main/multimodal-dataset/intro_vertex_ai_multimodal_dataset.ipynb\">\n",
" <img src=\"https://www.gstatic.com/images/branding/gcpiconscolors/bigquery/v1/32px.svg\" alt=\"BigQuery Studio logo\"><br> Open in BigQuery Studio\n",
" </a>\n",
" </td>\n",
" <td style=\"text-align: center\">\n",
" <a href=\"https://github.com/GoogleCloudPlatform/generative-ai/blob/main/multimodal-dataset/intro_vertex_ai_multimodal_dataset.ipynb\">\n",
" <img width=\"32px\" src=\"https://www.svgrepo.com/download/217753/github.svg\" alt=\"GitHub logo\"><br> View on GitHub\n",
" </a>\n",
" </td>\n",
"</table>\n",
"\n",
"<div style=\"clear: both;\"></div>\n",
"\n",
"<b>Share to:</b>\n",
"\n",
"<a href=\"https://www.linkedin.com/sharing/share-offsite/?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/multimodal-dataset/intro_vertex_ai_multimodal_dataset.ipynb\" target=\"_blank\">\n",
" <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/8/81/LinkedIn_icon.svg\" alt=\"LinkedIn logo\">\n",
"</a>\n",
"\n",
"<a href=\"https://bsky.app/intent/compose?text=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/multimodal-dataset/intro_vertex_ai_multimodal_dataset.ipynb\" target=\"_blank\">\n",
" <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/7/7a/Bluesky_Logo.svg\" alt=\"Bluesky logo\">\n",
"</a>\n",
"\n",
"<a href=\"https://twitter.com/intent/tweet?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/multimodal-dataset/intro_vertex_ai_multimodal_dataset.ipynb\" target=\"_blank\">\n",
" <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/5/5a/X_icon_2.svg\" alt=\"X logo\">\n",
"</a>\n",
"\n",
"<a href=\"https://reddit.com/submit?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/multimodal-dataset/intro_vertex_ai_multimodal_dataset.ipynb\" target=\"_blank\">\n",
" <img width=\"20px\" src=\"https://redditinc.com/hubfs/Reddit%20Inc/Brand/Reddit_Logo.png\" alt=\"Reddit logo\">\n",
"</a>\n",
"\n",
"<a href=\"https://www.facebook.com/sharer/sharer.php?u=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/multimodal-dataset/intro_vertex_ai_multimodal_dataset.ipynb\" target=\"_blank\">\n",
" <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/5/51/Facebook_f_logo_%282019%29.svg\" alt=\"Facebook logo\">\n",
"</a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AIGJwpvY3ELC"
},
"source": [
"| Author |\n",
"| --- |\n",
"| [Frances Thoma](https://github.com/diskontinuum) |"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7veA0yodoHzx"
},
"source": [
"## Overview\n",
"\n",
"This notebook demonstrates how to use Vertex AI Multimodal Datasets to assemble Gemini requests, to run a validation and resource estimation for supervised fine-tuning, and to create tuning and batch prediction jobs.\n",
"\n",
"### Objectives\n",
"\n",
"- Preview the new Vertex AI Multimodal Datasets SDK\n",
"- Demo upcoming integrations\n",
"\n",
"### Costs\n",
"\n",
"This tutorial uses billable components of Google Cloud:\n",
"\n",
"* Vertex AI\n",
"* Cloud Storage\n",
"* BigQuery\n",
"\n",
"Learn about [Vertex AI pricing](https://cloud.google.com/vertex-ai/pricing), [Cloud Storage pricing](https://cloud.google.com/storage/pricing), and [BigQuery pricing](https://cloud.google.com/bigquery/pricing) and use the [Pricing Calculator](https://cloud.google.com/products/calculator/) to generate a cost estimate based on your projected usage.\n",
"\n",
"### Prerequisites\n",
"1. Make sure that [billing is enabled](https://cloud.google.com/billing/docs/how-to/modify-project) for your project.\n",
"\n",
"2. You must have an existing Google Cloud project and [enable the Vertex AI API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com). Learn more about [setting up a project and a development environment](https://cloud.google.com/vertex-ai/docs/start/cloud-environment).\n",
"\n",
"### Questions or Feedback\n",
"\n",
"You can reach out directly to the authors via `vertex-multimodal-dataset-external-feedback@google.com` for feedback or questions."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "t-49KwqiQiaw"
},
"source": [
"## Get Started"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "U3fA_vcl3Y8s"
},
"source": [
"### Install Vertex AI SDK and other required packages"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Md24tB4pYTSY"
},
"outputs": [],
"source": [
"%pip install --quiet --force-reinstall \"numpy<2.0\" google-cloud-aiplatform bigframes"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DbqCDqs5QlGh"
},
"source": [
"### Authenticate your notebook environment (Colab only)\n",
"\n",
"If you are running this notebook on Google Colab, run the cell below to authenticate your environment."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gx-AV_VkafnO"
},
"outputs": [],
"source": [
"import sys\n",
"\n",
"if \"google.colab\" in sys.modules:\n",
" from google.colab import auth\n",
"\n",
" auth.authenticate_user()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dllbnne2al7X"
},
"source": [
"- If you are running this notebook in a local development environment:\n",
" - Install the [Google Cloud SDK](https://cloud.google.com/sdk).\n",
" - Obtain authentication credentials. Create local credentials by running the following command and following the oauth2 flow (read more about the command [here](https://cloud.google.com/sdk/gcloud/reference/beta/auth/application-default/login)):\n",
"\n",
" ```bash\n",
" gcloud auth application-default login\n",
" ```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-xkKPNED9yPa"
},
"source": [
"### Import libraries"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Mp_q41d8RgAr"
},
"outputs": [],
"source": [
"import io\n",
"import json\n",
"\n",
"from PIL import Image\n",
"import bigframes.pandas as bpd\n",
"from google.cloud import storage\n",
"from google.cloud.aiplatform.preview import datasets\n",
"import google.cloud.bigquery as bq\n",
"from google.oauth2 import credentials\n",
"import pandas\n",
"import vertexai\n",
"from vertexai.batch_prediction import BatchPredictionJob\n",
"from vertexai.generative_models import Content, Part\n",
"from vertexai.preview.tuning import sft"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "s9OCHaQsliNJ"
},
"source": [
"### Set Google Cloud project information\n",
"\n",
"To get started using Vertex AI, you must have an existing Google Cloud project and [enable the Vertex AI API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com).\n",
"\n",
"Learn more about [setting up a project and a development environment](https://cloud.google.com/vertex-ai/docs/start/cloud-environment)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "HptYirwARjXW"
},
"outputs": [],
"source": [
"# Use the environment variable if the user doesn't provide Project ID.\n",
"import os\n",
"\n",
"PROJECT_ID = \"[your-project-id]\" # @param {type: \"string\", placeholder: \"[your-project-id]\", isTemplate: true}\n",
"if not PROJECT_ID or PROJECT_ID == \"[your-project-id]\":\n",
" PROJECT_ID = str(os.environ.get(\"GOOGLE_CLOUD_PROJECT\"))\n",
"\n",
"LOCATION = os.environ.get(\"GOOGLE_CLOUD_REGION\", \"us-central1\")\n",
"\n",
"vertexai.init(project=PROJECT_ID, location=LOCATION)\n",
"\n",
"# BigFrames settings\n",
"bpd.close_session()\n",
"bpd.options.bigquery.project = PROJECT_ID\n",
"bpd.options.bigquery.location = LOCATION"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "72tXYePAWDGJ"
},
"source": [
"### Data preparation\n",
"\n",
"The image files and labels used in this tutorial are from the flower dataset used in this [TensorFlow blog post](https://cloud.google.com/blog/products/gcp/how-to-classify-images-with-tensorflow-using-google-cloud-machine-learning-and-cloud-dataflow).\n",
"\n",
"The dataset contains 7338 images, each of which is annotated with one label across 5 different flower classes.\n",
"\n",
"The input images are stored in a public Cloud Storage bucket. This publicly-accessible bucket also contains a CSV file used to create the Vertex AI multimodal dataset. This file has two columns: the first column lists an image's URI in Cloud Storage, and the second column contains the image's label.\n",
"\n",
"In this notebook, we'll use subsets of the flower dataset, each with a fixed number of examples per category, and prepare training, tuning and test subsets\n",
" as DataFrame.\n",
"\n",
"**Tip:** Use the BigFrames library `bpd` instead of `pandas` for larger datasets."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "VVl2GGM2RJrU"
},
"outputs": [],
"source": [
"# @title Get Flowers dataset and set up splits\n",
"# Get data from GCS\n",
"csv = \"gs://cloud-samples-data/ai-platform/flowers/flowers.csv\"\n",
"all_images = pandas.read_csv(csv, names=[\"image_uris\", \"labels\"])\n",
"\n",
"# Prepare training, validation, and test set\n",
"CATEGORIES = [\"daisy\", \"dandelion\", \"roses\", \"sunflowers\", \"tulips\"]\n",
"TRAINING_CASES_PER_CATEGORY = 100 # @param {type: 'integer'}\n",
"VALIDATION_CASES_PER_CATEGORY = 100 # @param {type: 'integer'}\n",
"training_set = pandas.DataFrame()\n",
"validation_set = pandas.DataFrame()\n",
"\n",
"\n",
"for category in CATEGORIES:\n",
" same_labels = all_images[all_images[\"labels\"] == category]\n",
" if len(same_labels) < TRAINING_CASES_PER_CATEGORY + VALIDATION_CASES_PER_CATEGORY:\n",
" raise ValueError(\"Please reduce the number of cases per category.\")\n",
" training_set = pandas.concat(\n",
" (training_set, same_labels.iloc[:TRAINING_CASES_PER_CATEGORY]),\n",
" ignore_index=True,\n",
" )\n",
" validation_set = pandas.concat(\n",
" (\n",
" validation_set,\n",
" same_labels.iloc[\n",
" TRAINING_CASES_PER_CATEGORY : TRAINING_CASES_PER_CATEGORY\n",
" + VALIDATION_CASES_PER_CATEGORY\n",
" ],\n",
" ),\n",
" ignore_index=True,\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "MYYGPsA69g_-"
},
"outputs": [],
"source": [
"# @title Common Functions\n",
"\n",
"# Set Pandas display options to show all columns and full width for better inspection\n",
"pandas.set_option(\"display.max_columns\", None) # Show all columns\n",
"pandas.set_option(\"display.expand_frame_repr\", False) # Prevent line wrapping\n",
"pandas.set_option(\"display.max_colwidth\", None) # Show full column width\n",
"\n",
"# Dataset inspection helper\n",
"\n",
"\n",
"def show_dataset_info(dataset):\n",
" print(\" Resource name: \", dataset.resource_name)\n",
" print(\" Display name: \", dataset.display_name)\n",
" print(\" Schema URI: \", dataset.metadata_schema_uri)\n",
" print(\" BQ Table: \", dataset.bigquery_table)\n",
"\n",
"\n",
"# Helper is needed as long as tuning integration has not been rolled out yet.\n",
"def bq_table_to_jsonl_gcs(*, source_table_id: str, destination_bucket: str) -> str:\n",
" \"\"\"\n",
" Exports a BigQuery table with a single 'request' column to JSONL\n",
" (values only, no header) in GCS.\n",
" Args:\n",
" source_table_id: The source BigQuery table ID, e.g. `project.dataset.table`.\n",
" destination_bucket: The GCS bucket to export to.\n",
" Returns:\n",
" The GCS URI of the exported JSONL file.\n",
" \"\"\"\n",
" BQ_CLIENT = bq.Client(project=PROJECT_ID, location=LOCATION)\n",
" bucket_name = destination_bucket.split(\"/\")[2]\n",
" table_name = source_table_id.split(\".\")[2]\n",
" gcs_file_path = f\"temp-{table_name}.jsonl\"\n",
" query = f\"SELECT request FROM `{source_table_id}`\"\n",
" query_job = BQ_CLIENT.query(query)\n",
" results = query_job.result()\n",
"\n",
" jsonl_data = [\n",
" row.request for row in results\n",
" ] # Extract only the 'request' column values\n",
"\n",
" jsonl_string = \"\\n\".join(json.dumps(value) for value in jsonl_data)\n",
"\n",
" storage_client = storage.Client(project=PROJECT_ID)\n",
" bucket = storage_client.bucket(bucket_name)\n",
" blob = bucket.blob(gcs_file_path)\n",
"\n",
" blob.upload_from_string(jsonl_string, content_type=\"application/jsonlines\")\n",
" return f\"{destination_bucket}/temp-{table_name}.jsonl\"\n",
"\n",
"\n",
"def get_gcs_image(gcs_uri):\n",
" \"\"\"Download and show an image from Cloud Storage.\"\"\"\n",
" bearer_token = ! gcloud auth print-access-token\n",
" creds = credentials.Credentials(token=bearer_token[0])\n",
" storage_client = storage.Client(project=PROJECT_ID)\n",
" blob = storage.blob.Blob.from_string(gcs_uri, client=storage_client)\n",
" return Image.open(io.BytesIO(blob.download_as_bytes()))\n",
"\n",
"\n",
"def construct_gemini_example(\n",
" *, prompt: str = None, response: str = None, system_instructions: str = None\n",
") -> datasets.GeminiExample:\n",
" \"\"\"Helper method to create a GeminiExample object for single-turn cases.\n",
" Args:\n",
"\n",
" prompt: User input. Required.\n",
" response: Model response to user input. Optional.\n",
" system_instructions: System instructions for the model. Optional.\n",
" \"\"\"\n",
" contents = [Content(role=\"user\", parts=[Part.from_text(prompt)])]\n",
" if response:\n",
" contents.append(Content(role=\"model\", parts=[Part.from_text(response)]))\n",
" if system_instructions:\n",
" system_instructions_content = Content(\n",
" parts=[Part.from_text(system_instructions)]\n",
" )\n",
" return datasets.GeminiExample(\n",
" contents=contents, system_instruction=system_instructions_content\n",
" )\n",
" return datasets.GeminiExample(contents=contents)\n",
"\n",
"\n",
"def construct_template(\n",
" *,\n",
" prompt: str = None,\n",
" response: str = None,\n",
" system_instructions: str = None,\n",
" field_mapping: list[dict[str, str]] = None,\n",
") -> datasets.GeminiTemplateConfig:\n",
" \"\"\"Helper method to create a GeminiTemplateConfig object for single-turn cases.\n",
" Args:\n",
"\n",
" prompt: User input. Required.\n",
" response: Model response to user input. Optional.\n",
" system_instructions: System instructions for the model. Optional.\n",
" field_mapping: Mapping of placeholders to dataset columns. Optional.\n",
" \"\"\"\n",
" gemini_example = construct_gemini_example(\n",
" prompt=prompt, response=response, system_instructions=system_instructions\n",
" )\n",
" return datasets.GeminiTemplateConfig(\n",
" gemini_example=gemini_example, field_mapping=field_mapping\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "aKw7PCHpXwgn"
},
"source": [
"## User Journey Demo\n",
"\n",
"The user journey demonstrated here contains the following steps:\n",
"\n",
"1. Create Dataset\n",
"2. Assemble the dataset with a template and inspect assembly\n",
"3. Run a validation for tuning\n",
"4. Estimate Resources for tuning\n",
"5. Run tuning\n",
"6. Run batch prediction"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HybHbfY85MnR"
},
"source": [
"### 1. Create a dataset from a Pandas or BigFrames DataFrame"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fqewmHHY7TYO"
},
"source": [
"We prepared a DataFrame `training_set` with two columns:\n",
"\n",
"* `image_uris`: GCS URIs of flower images\n",
"* `labels`: Flower label (five flower categories, one label per image)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8DxI6amF8ok1"
},
"outputs": [],
"source": [
"flower_uri = training_set[\"image_uris\"].iloc[0]\n",
"flower_label = training_set[\"labels\"].iloc[0]\n",
"\n",
"display(get_gcs_image(flower_uri))\n",
"print(f\"Image URI: {flower_uri}\")\n",
"print(f\"Flower label: {flower_label}\")\n",
"training_set.head()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jSxMKdBiId8r"
},
"source": [
"Let's create a Vertex AI multimodal dataset from the prepared DataFrame."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Xc37pkL1SnxL"
},
"outputs": [],
"source": [
"flowers = datasets.MultimodalDataset.from_pandas(dataframe=training_set)\n",
"\n",
"show_dataset_info(flowers)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6Hjybx66Angq"
},
"source": [
"Inspect the new Vertex AI multimodal dataset.\n",
"In the near future this method will be available directly via the SDK."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9cmf9kNiAeTr"
},
"outputs": [],
"source": [
"flowers_df = bpd.read_gbq_table(flowers.bigquery_table.strip(\"bq://\"), use_cache=False)\n",
"flowers_df.head()"
]
},
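{
"cell_type": "markdown",
"metadata": {},
"source": [
"The dataset's `bigquery_table` value carries a `bq://` scheme, while `read_gbq_table` takes a plain `project.dataset.table` ID. A minimal sketch of a prefix-safe conversion follows; `table_id_from_uri` and the sample URI are hypothetical names used only for illustration.\n",
"\n",
"```py\n",
"def table_id_from_uri(bigquery_uri: str) -> str:\n",
"    # Drop only the leading 'bq://' scheme (str.removeprefix needs Python 3.9+).\n",
"    return bigquery_uri.removeprefix('bq://')\n",
"\n",
"\n",
"uri = 'bq://my-project.my_dataset.table_b'  # hypothetical table URI\n",
"table_id = table_id_from_uri(uri)\n",
"\n",
"# str.strip treats its argument as a set of characters, so it also eats the trailing 'b'.\n",
"stripped = uri.strip('bq://')\n",
"```\n",
"\n",
"Here `table_id` keeps the full table name, whereas `stripped` loses the final character of the table name."
]
},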
{
"cell_type": "markdown",
"metadata": {
"id": "1FfwiGkTgp6F"
},
"source": [
"**Other dataset creation options**\n",
"\n",
"Create from a BigQuery table.\n",
"\n",
"```py\n",
"my_dataset_from_bigquery = datasets.MultimodalDataset.from_bigquery(\n",
" bigquery_uri=f\"bq://projectId.datasetId.tableId\"\n",
")\n",
"```\n",
"\n",
"Create from a BigFrames DataFrame.\n",
"\n",
"```py\n",
"my_dataset_from_pandas = datasets.MultimodalDataset.from_bigframes(\n",
" dataframe=my_dataframe\n",
")\n",
"```\n",
"\n",
"Create from a GCS file in JSONL format for assembled input (the JSONL file contains Gemini requests, no assembly required).\n",
"\n",
"```py\n",
"my_dataset = datasets.MultimodalDataset.from_gemini_request_jsonl(\n",
" gcs_uri=gcs_uri_of_jsonl_file\n",
")\n",
"```\n",
"\n",
"List or load existing datasets.\n",
"\n",
"```py\n",
"# Get the most recently created dataset\n",
"first_dataset = datasets.MultimodalDataset.list()[0]\n",
"\n",
"# Load dataset based on dataset name\n",
"same_dataset = datasets.MultimodalDataset(first_dataset.name)\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vkD3LrKlgkaS"
},
"source": [
"### 2. Assemble the dataset with a template and inspect assembly\n",
"\n",
"To use our Flowers dataset with Gemini, let's assemble a full Gemini request referencing the images in our dataset."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PB5VRnpVIEGx"
},
"source": [
"We construct a template configuration by specifying the general prompt, response and system instructions and use placeholders in curly braces. During the assembly, the placeholders are replaced with the values of the dataset column that the placeholders denote."
]
},
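{
"cell_type": "markdown",
"metadata": {},
"source": [
"The substitution can be pictured with a minimal pure-Python sketch. This is an illustration only, not the service's implementation: the helper `fill_placeholders` and the sample rows are hypothetical, and the real assembly runs against the BigQuery table behind the dataset.\n",
"\n",
"```py\n",
"def fill_placeholders(template: str, row: dict) -> str:\n",
"    # Replace each {column} placeholder with that row's value (hypothetical helper).\n",
"    for column, value in row.items():\n",
"        template = template.replace('{' + column + '}', str(value))\n",
"    return template\n",
"\n",
"\n",
"# Hypothetical rows shaped like the flowers DataFrame.\n",
"rows = [\n",
"    {'image_uris': 'gs://my-bucket/daisy_1.jpg', 'labels': 'daisy'},\n",
"    {'image_uris': 'gs://my-bucket/tulip_7.jpg', 'labels': 'tulips'},\n",
"]\n",
"prompts = [fill_placeholders('This is the image: {image_uris}', row) for row in rows]\n",
"responses = [fill_placeholders('{labels}', row) for row in rows]\n",
"```\n",
"\n",
"Each row yields one prompt and one response. Placeholders that match no column are left untouched in this sketch."
]
},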
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gIJUKU6_qV99"
},
"outputs": [],
"source": [
"template_config = construct_template(\n",
" prompt=\"This is the image: {image_uris}\",\n",
" response=\"{labels}\",\n",
" system_instructions=\"You are a botanical image classifier. Analyze the provided image \"\n",
" \"and determine the most accurate classification of the flower.\"\n",
" f\"These are the only flower categories: {CATEGORIES}.\"\n",
" \"Return only one category per image.\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nD20fEieM4rk"
},
"source": [
"Here, the template is constructed using the local helper function `construct_template()`. Alternatively, it can be explicitly constructed from a Gemini example as below.\n",
"\n",
"It is also possible to specify a custom field mapping for the placeholders used in the Gemini example. Then the placeholders can have any name, and not necessarily the column name of the dataset column with the values that are being inserted (here image_uris and labels):\n",
"\n",
"```py\n",
"gemini_example = datasets.GeminiExample(\n",
" contents=[\n",
" Content(role=\"user\", parts=[Part.from_text(\"This is the image: {uri}\")]),\n",
" Content(role=\"model\", parts=[Part.from_text(\"{flower}\")]),\n",
" ],\n",
" system_instruction=Content(\n",
" parts=[\n",
" Part.from_text(\n",
" \"You are a botanical image classifier. Analyze the provided image \"\n",
" \"and determine the most accurate classification of the flower.\"\n",
" f\"These are the only flower categories: {CATEGORIES}.\"\n",
" \"Return only one category per image.\"\n",
" )\n",
" ]\n",
" ),\n",
")\n",
"\n",
"template_config = datasets.GeminiTemplateConfig(\n",
" gemini_example=gemini_example,\n",
" field_mapping={\"uri_placeholder\": \"image_uris\", \"flower_placeholder\": \"labels\"},\n",
")\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HQnaPYbXPtdI"
},
"source": [
"**Assemble and inspect the dataset.**\n",
"\n",
"The dataset assembly creates a BQ table with the assembled examples in a single `request` column. The assembly method below returns a tuple containing a table id (`str`) referencing the assembly BQ table, and a DataFrame (`bigframes.pandas.DataFrame`) for direct inspection.\n",
"The DataFrame and the BQ table referenced by the table id contain the assembled dataset in a single column `request`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "A5zGxYOYQSJE"
},
"outputs": [],
"source": [
"table_id, assembly = flowers.assemble(template_config=template_config)\n",
"\n",
"# Inspect assembled dataset\n",
"assembly.head()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "S0KeOWRvSj75"
},
"source": [
"It is also possible to attach the template and run the assembly without passing it:\n",
"\n",
"```py\n",
"my_questions.attach_template_config(template_config=template_config)\n",
"_, other_assembly = my_questions.assemble()\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Lk8zIbaQ55Bp"
},
"source": [
"### 3. Run a validation for tuning"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "m7UODTevZovT"
},
"source": [
"Validate a dataset for tuning.\n",
"Tuning dataset usages are: `SFT_VALIDATION`, `SFT_TRAINING`.\n",
"\n",
"First we attach the `template_config` and use it implicitly for all further tasks."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_4itps2hkqVS"
},
"outputs": [],
"source": [
"flowers.attach_template_config(template_config=template_config)\n",
"\n",
"validation = flowers.assess_tuning_validity(\n",
" model_name=\"gemini-2.0-flash-001\", dataset_usage=\"SFT_TRAINING\"\n",
")\n",
"\n",
"# Check if there are validation errors\n",
"validation.errors"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tt1n6Mwak864"
},
"source": [
"Let's validate a dataset with an incorrect `template_config`, e.g. using a `GeminiExample` that contains two consecutive `user` contents, instead of a `user` content followed by a `model` content."
]
},
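{
"cell_type": "markdown",
"metadata": {},
"source": [
"The kind of structural rule being violated can be sketched locally. This is only an illustration: `find_role_errors` is a hypothetical helper, and the actual checks behind `assess_tuning_validity` run in the service and are not reproduced here.\n",
"\n",
"```py\n",
"def find_role_errors(roles: list) -> list:\n",
"    # Report every turn that repeats the role of the turn before it.\n",
"    errors = []\n",
"    for index in range(1, len(roles)):\n",
"        if roles[index] == roles[index - 1]:\n",
"            errors.append(f'turn {index} repeats role {roles[index]!r}')\n",
"    return errors\n",
"\n",
"\n",
"valid_errors = find_role_errors(['user', 'model'])\n",
"invalid_errors = find_role_errors(['user', 'user'])\n",
"```\n",
"\n",
"A `user` turn followed by a `model` turn produces no errors, while two consecutive `user` turns produce one."
]
},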
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yscl1KH8XUTy"
},
"outputs": [],
"source": [
"invalid_gemini_example = datasets.GeminiExample(\n",
" contents=[\n",
" Content(role=\"user\", parts=[Part.from_text(\"This is the image: {image_uris}\")]),\n",
" # Consecutive content turn with the same role\n",
" Content(role=\"user\", parts=[Part.from_text(\".\")]),\n",
" ],\n",
")\n",
"invalid_configuration = datasets.GeminiTemplateConfig(\n",
" gemini_example=invalid_gemini_example\n",
")\n",
"\n",
"validation = flowers.assess_tuning_validity(\n",
" model_name=\"gemini-2.0-flash-001\",\n",
" dataset_usage=\"SFT_TRAINING\",\n",
" template_config=invalid_configuration,\n",
")\n",
"\n",
"validation.errors"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ayqG3aKX55P8"
},
"source": [
"### 4. Estimate resources for tuning"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "s8ufb8KWYE51"
},
"outputs": [],
"source": [
"tuning_resources = flowers.assess_tuning_resources(model_name=\"gemini-2.0-flash-001\")\n",
"print(tuning_resources)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "i5wAETJ055XQ"
},
"source": [
"### 5. Run Tuning\n",
"\n",
"In the future we'll provide an integration with BQ directly:\n",
"\n",
"```py\n",
"sft_tuning_job = tuning_service.train(\n",
" source_model=\"gemini-2.0-flash-001\",\n",
" train_dataset=flowers,\n",
")\n",
"```\n",
"\n",
"For now we use a helper to export the assembly BQ table to a JSONL file on GCS and provide the GCS URI as training dataset reference. Please provide a Google Cloud Storage bucket for the export."
]
},
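{
"cell_type": "markdown",
"metadata": {},
"source": [
"The export format is JSONL: one JSON document per line, with no header row. A minimal sketch of that serialization, assuming each request is already a Python dict; `to_jsonl` and the simplified request rows are hypothetical and stand in for the assembled `request` column.\n",
"\n",
"```py\n",
"import json\n",
"\n",
"\n",
"def to_jsonl(requests: list) -> str:\n",
"    # Serialize each request on its own line; json.dumps never emits raw newlines.\n",
"    return '\\n'.join(json.dumps(request) for request in requests)\n",
"\n",
"\n",
"# Hypothetical, heavily simplified request rows.\n",
"requests = [\n",
"    {'contents': [{'role': 'user', 'parts': [{'text': 'first'}]}]},\n",
"    {'contents': [{'role': 'user', 'parts': [{'text': 'second'}]}]},\n",
"]\n",
"jsonl = to_jsonl(requests)\n",
"lines = jsonl.splitlines()\n",
"```\n",
"\n",
"Every element of `lines` parses back into the original request with `json.loads`."
]
},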
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "XIzeW5Qw36Si"
},
"outputs": [],
"source": [
"# The following will be removed once the tuning integration has been completed.\n",
"# Set a GCS bucket for exporting your dataset.\n",
"tuning_destination_bucket = \"gs://my-tuning-export-bucket\" # @param {type:\"string\"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "M62UKX-3byET"
},
"outputs": [],
"source": [
"# Assemble Gemini request\n",
"assembly_table_id, _ = flowers.assemble()\n",
"# Export assembly as JSONL to GCS bucket\n",
"train_gcs_uri = bq_table_to_jsonl_gcs(\n",
" source_table_id=assembly_table_id, destination_bucket=tuning_destination_bucket\n",
")\n",
"\n",
"tuning_job = sft.train(\n",
" source_model=\"gemini-2.0-flash-001\",\n",
" train_dataset=train_gcs_uri,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "y41fcy6EII5C"
},
"source": [
"Let's also prepare and use the validation dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rZiHdiB4IGyS"
},
"outputs": [],
"source": [
"# Create a Vertex AI Multimodal dataset for the validation set\n",
"flowers_validation_dataset = datasets.MultimodalDataset.from_pandas(\n",
" dataframe=validation_set\n",
")\n",
"\n",
"# Assemble Gemini request\n",
"assembly_table_id, _ = flowers_validation_dataset.assemble(\n",
" template_config=template_config\n",
")\n",
"# Export assembly as JSONL to GCS bucket\n",
"validation_gcs_uri = bq_table_to_jsonl_gcs(\n",
" source_table_id=assembly_table_id, destination_bucket=tuning_destination_bucket\n",
")\n",
"\n",
"# Run tuning job with train and validation dataset\n",
"tuning_job = sft.train(\n",
" source_model=\"gemini-2.0-flash-001\",\n",
" train_dataset=train_gcs_uri,\n",
" validation_dataset=validation_gcs_uri,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "b0JLlzaAtoeX"
},
"source": [
"### 6. Batch Prediction\n",
"\n",
"In the future we'll provide an integration with Batch Prediction directly:\n",
"\n",
"```py\n",
"batch_prediction_job = BatchPredictionJob.submit(\n",
" source_model=\"gemini-2.0-flash-001\",\n",
" input_dataset=flowers,\n",
" output_uri_prefix=output_uri,\n",
")\n",
"```\n",
"\n",
"For now we provide the assembly BQ URI as input dataset:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "qOUNjb1AdzHo"
},
"outputs": [],
"source": [
"batch_prediction_job = BatchPredictionJob.submit(\n",
" source_model=\"gemini-2.0-flash-001\",\n",
" input_dataset=f\"bq://{assembly_table_id}\",\n",
")"
]
}
],
"metadata": {
"colab": {
"name": "intro_vertex_ai_multimodal_dataset.ipynb",
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}