In [None]:
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Intro to Vertex AI Multimodal Datasets

<table align="left">
  <td style="text-align: center">
    <a href="https://colab.research.google.com/github/GoogleCloudPlatform/generative-ai/blob/main/multimodal-dataset/intro_vertex_ai_multimodal_dataset.ipynb">
      <img width="32px" src="https://www.gstatic.com/pantheon/images/bigquery/welcome_page/colab-logo.svg" alt="Google Colaboratory logo"><br> Open in Colab
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fgenerative-ai%2Fmain%2Fmultimodal-dataset%2Fintro_vertex_ai_multimodal_dataset.ipynb">
      <img width="32px" src="https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN" alt="Google Cloud Colab Enterprise logo"><br> Open in Colab Enterprise
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/generative-ai/main/multimodal-dataset/intro_vertex_ai_multimodal_dataset.ipynb">
      <img src="https://www.gstatic.com/images/branding/gcpiconscolors/vertexai/v1/32px.svg" alt="Vertex AI logo"><br> Open in Vertex AI Workbench
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/bigquery/import?url=https://github.com/GoogleCloudPlatform/generative-ai/blob/main/multimodal-dataset/intro_vertex_ai_multimodal_dataset.ipynb">
      <img src="https://www.gstatic.com/images/branding/gcpiconscolors/bigquery/v1/32px.svg" alt="BigQuery Studio logo"><br> Open in BigQuery Studio
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://github.com/GoogleCloudPlatform/generative-ai/blob/main/multimodal-dataset/intro_vertex_ai_multimodal_dataset.ipynb">
      <img width="32px" src="https://www.svgrepo.com/download/217753/github.svg" alt="GitHub logo"><br> View on GitHub
    </a>
  </td>
</table>

<div style="clear: both;"></div>

<b>Share to:</b>

<a href="https://www.linkedin.com/sharing/share-offsite/?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/multimodal-dataset/intro_vertex_ai_multimodal_dataset.ipynb" target="_blank">
  <img width="20px" src="https://upload.wikimedia.org/wikipedia/commons/8/81/LinkedIn_icon.svg" alt="LinkedIn logo">
</a>

<a href="https://bsky.app/intent/compose?text=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/multimodal-dataset/intro_vertex_ai_multimodal_dataset.ipynb" target="_blank">
  <img width="20px" src="https://upload.wikimedia.org/wikipedia/commons/7/7a/Bluesky_Logo.svg" alt="Bluesky logo">
</a>

<a href="https://twitter.com/intent/tweet?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/multimodal-dataset/intro_vertex_ai_multimodal_dataset.ipynb" target="_blank">
  <img width="20px" src="https://upload.wikimedia.org/wikipedia/commons/5/5a/X_icon_2.svg" alt="X logo">
</a>

<a href="https://reddit.com/submit?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/multimodal-dataset/intro_vertex_ai_multimodal_dataset.ipynb" target="_blank">
  <img width="20px" src="https://redditinc.com/hubfs/Reddit%20Inc/Brand/Reddit_Logo.png" alt="Reddit logo">
</a>

<a href="https://www.facebook.com/sharer/sharer.php?u=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/multimodal-dataset/intro_vertex_ai_multimodal_dataset.ipynb" target="_blank">
  <img width="20px" src="https://upload.wikimedia.org/wikipedia/commons/5/51/Facebook_f_logo_%282019%29.svg" alt="Facebook logo">
</a>

| Author |
| --- |
| [Frances Thoma](https://github.com/diskontinuum) |

## Overview

This notebook demonstrates how to use Vertex AI Multimodal Datasets to assemble Gemini requests, to run a validation and resource estimation for supervised fine-tuning, and to create tuning and batch prediction jobs.

### Objectives

- Preview the new Vertex AI Multimodal Datasets SDK
- Demo upcoming integrations

### Costs

This tutorial uses billable components of Google Cloud:

* Vertex AI
* Cloud Storage
* BigQuery

Learn about [Vertex AI pricing](https://cloud.google.com/vertex-ai/pricing), [Cloud Storage pricing](https://cloud.google.com/storage/pricing), and [BigQuery pricing](https://cloud.google.com/bigquery/pricing) and use the [Pricing Calculator](https://cloud.google.com/products/calculator/) to generate a cost estimate based on your projected usage.

### Prerequisites
1. Make sure that [billing is enabled](https://cloud.google.com/billing/docs/how-to/modify-project) for your project.

2. You must have an existing Google Cloud project and [enable the Vertex AI API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com). Learn more about [setting up a project and a development environment](https://cloud.google.com/vertex-ai/docs/start/cloud-environment).

### Questions or Feedback

You can reach out directly to the authors via `vertex-multimodal-dataset-external-feedback@google.com` for feedback or questions.

## Get Started

### Install Vertex AI SDK and other required packages

In [None]:
# Install the Vertex AI SDK and BigFrames into the kernel's environment.
# NOTE(review): numpy is pinned below 2.0, presumably for compatibility with
# one of these packages — confirm. After --force-reinstall the kernel may need
# a restart before the new packages can be imported.
%pip install --quiet --force-reinstall "numpy<2.0" google-cloud-aiplatform bigframes

### Authenticate your notebook environment (Colab only)

If you are running this notebook on Google Colab, run the cell below to authenticate your environment.

In [None]:
import sys

if "google.colab" in sys.modules:
    from google.colab import auth

    auth.authenticate_user()

- If you are running this notebook in a local development environment:
  - Install the [Google Cloud SDK](https://cloud.google.com/sdk).
  - Obtain authentication credentials. Create local credentials by running the following command and following the oauth2 flow (read more about the command [here](https://cloud.google.com/sdk/gcloud/reference/beta/auth/application-default/login)):

    ```bash
    gcloud auth application-default login
    ```

### Import libraries

In [None]:
import io
import json

from PIL import Image
import bigframes.pandas as bpd
from google.cloud import storage
from google.cloud.aiplatform.preview import datasets
import google.cloud.bigquery as bq
from google.oauth2 import credentials
import pandas
import vertexai
from vertexai.batch_prediction import BatchPredictionJob
from vertexai.generative_models import Content, Part
from vertexai.preview.tuning import sft

### Set Google Cloud project information

To get started using Vertex AI, you must have an existing Google Cloud project and [enable the Vertex AI API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com).

Learn more about [setting up a project and a development environment](https://cloud.google.com/vertex-ai/docs/start/cloud-environment).

In [None]:
# Use the environment variable if the user doesn't provide Project ID.
import os

PROJECT_ID = "[your-project-id]"  # @param {type: "string", placeholder: "[your-project-id]", isTemplate: true}
if not PROJECT_ID or PROJECT_ID == "[your-project-id]":
    # os.environ.get() returns None when the variable is unset; converting that
    # with str() would silently yield the bogus project ID "None", so fail early.
    PROJECT_ID = os.environ.get("GOOGLE_CLOUD_PROJECT")
    if not PROJECT_ID:
        raise ValueError(
            "Set PROJECT_ID above or the GOOGLE_CLOUD_PROJECT environment variable."
        )

LOCATION = os.environ.get("GOOGLE_CLOUD_REGION", "us-central1")

vertexai.init(project=PROJECT_ID, location=LOCATION)

# BigFrames settings: close any existing session before changing the
# project/location options it uses.
bpd.close_session()
bpd.options.bigquery.project = PROJECT_ID
bpd.options.bigquery.location = LOCATION

### Data preparation

The image files and labels used in this tutorial are from the flower dataset used in this [TensorFlow blog post](https://cloud.google.com/blog/products/gcp/how-to-classify-images-with-tensorflow-using-google-cloud-machine-learning-and-cloud-dataflow).

The dataset contains 7338 images, each of which is annotated with one label across 5 different flower classes.

The input images are stored in a public Cloud Storage bucket. This publicly-accessible bucket also contains a CSV file used to create the Vertex AI multimodal dataset. This file has two columns: the first column lists an image's URI in Cloud Storage, and the second column contains the image's label.

In this notebook, we'll use subsets of the flower dataset with a fixed number of examples per category, and prepare training and validation subsets as pandas DataFrames.

**Tip:** Use the BigFrames library `bpd` instead of `pandas` for larger datasets.

In [None]:
# @title Get Flowers dataset and set up splits
# Get data from GCS
csv = "gs://cloud-samples-data/ai-platform/flowers/flowers.csv"
all_images = pandas.read_csv(csv, names=["image_uris", "labels"])

# Prepare training and validation sets. Per category, the first
# TRAINING_CASES_PER_CATEGORY rows go to training and the next
# VALIDATION_CASES_PER_CATEGORY rows go to validation, so the sets never overlap.
CATEGORIES = ["daisy", "dandelion", "roses", "sunflowers", "tulips"]
TRAINING_CASES_PER_CATEGORY = 100  # @param {type: 'integer'}
VALIDATION_CASES_PER_CATEGORY = 100  # @param {type: 'integer'}
cases_per_category = TRAINING_CASES_PER_CATEGORY + VALIDATION_CASES_PER_CATEGORY

training_parts = []
validation_parts = []
for category in CATEGORIES:
    same_labels = all_images[all_images["labels"] == category]
    if len(same_labels) < cases_per_category:
        raise ValueError(
            f"Category {category!r} has only {len(same_labels)} images but "
            f"{cases_per_category} are needed. "
            "Please reduce the number of cases per category."
        )
    training_parts.append(same_labels.iloc[:TRAINING_CASES_PER_CATEGORY])
    validation_parts.append(
        same_labels.iloc[TRAINING_CASES_PER_CATEGORY:cases_per_category]
    )

# Concatenate once instead of growing the DataFrames inside the loop;
# ignore_index gives each split a fresh 0..n-1 index.
training_set = pandas.concat(training_parts, ignore_index=True)
validation_set = pandas.concat(validation_parts, ignore_index=True)

In [None]:
# @title Common Functions

# Widen pandas output for better inspection: every column, full cell width,
# and no wrapping of wide frames onto multiple lines.
_PANDAS_DISPLAY_OPTIONS = {
    "display.max_columns": None,  # Show all columns
    "display.expand_frame_repr": False,  # Prevent line wrapping
    "display.max_colwidth": None,  # Show full column width
}
for _option_name, _option_value in _PANDAS_DISPLAY_OPTIONS.items():
    pandas.set_option(_option_name, _option_value)

# Dataset inspection helper


def show_dataset_info(dataset):
    """Print the key identifiers of a Vertex AI multimodal dataset.

    Emits one labelled line each for the resource name, display name,
    metadata schema URI, and BigQuery table of `dataset`.
    """
    labelled_attributes = (
        ("  Resource name: ", "resource_name"),
        ("  Display name: ", "display_name"),
        ("  Schema URI:   ", "metadata_schema_uri"),
        ("  BQ Table:     ", "bigquery_table"),
    )
    for label, attribute in labelled_attributes:
        print(label, getattr(dataset, attribute))


# Helper is needed as long as tuning integration has not been rolled out yet.
def bq_table_to_jsonl_gcs(*, source_table_id: str, destination_bucket: str) -> str:
    """Exports the `request` column of a BigQuery table to a JSONL file in GCS.

    Every row's `request` value is serialized with `json.dumps` and written on
    its own line (values only, no header). The object is named
    `temp-<table name>.jsonl` and placed under `destination_bucket`.

    Args:
      source_table_id: The source BigQuery table ID, e.g. `project.dataset.table`.
      destination_bucket: The GCS destination, either `gs://bucket` or
        `gs://bucket/prefix`. A trailing slash is ignored.

    Returns:
      The GCS URI of the exported JSONL file (the object actually written,
      including any prefix).

    Raises:
      ValueError: If `destination_bucket` is not a `gs://` URI naming a bucket.
    """
    scheme = "gs://"
    if not destination_bucket.startswith(scheme):
        raise ValueError(
            f"destination_bucket must start with {scheme!r}, got {destination_bucket!r}"
        )
    # Split "gs://bucket/optional/prefix" into bucket and prefix so that the
    # object we upload and the URI we return always refer to the same path.
    bucket_name, _, prefix = (
        destination_bucket[len(scheme):].rstrip("/").partition("/")
    )
    if not bucket_name:
        raise ValueError(f"No bucket name in destination_bucket {destination_bucket!r}")

    # The last dotted component is the table name; this also handles two-part
    # `dataset.table` IDs.
    table_name = source_table_id.rsplit(".", 1)[-1]
    file_name = f"temp-{table_name}.jsonl"
    blob_name = f"{prefix}/{file_name}" if prefix else file_name

    bq_client = bq.Client(project=PROJECT_ID, location=LOCATION)
    rows = bq_client.query(f"SELECT request FROM `{source_table_id}`").result()
    # NOTE(review): if `request` is a STRING column holding JSON text (rather
    # than a JSON value returned as a Python object), json.dumps double-encodes
    # it — confirm the column type produced by assemble().
    jsonl_string = "\n".join(json.dumps(row.request) for row in rows)

    storage_client = storage.Client(project=PROJECT_ID)
    blob = storage_client.bucket(bucket_name).blob(blob_name)
    blob.upload_from_string(jsonl_string, content_type="application/jsonlines")
    return f"{scheme}{bucket_name}/{blob_name}"


def get_gcs_image(gcs_uri):
    """Download an image from Cloud Storage and return it as a PIL image.

    The image is only returned, not displayed; callers render it themselves,
    e.g. `display(get_gcs_image(uri))`. No explicit credentials are passed, so
    the storage client uses its default credential lookup.

    Args:
      gcs_uri: Full object URI, e.g. `gs://bucket/path/image.jpg`.

    Returns:
      A `PIL.Image.Image` opened from the downloaded bytes.
    """
    # The previous shell-out to `gcloud auth print-access-token` built a
    # Credentials object that was never passed to the client, so it is dropped.
    storage_client = storage.Client(project=PROJECT_ID)
    blob = storage.blob.Blob.from_string(gcs_uri, client=storage_client)
    return Image.open(io.BytesIO(blob.download_as_bytes()))


def construct_gemini_example(
    *, prompt: str = None, response: str = None, system_instructions: str = None
) -> datasets.GeminiExample:
    """Helper method to create a GeminiExample object for single-turn cases.

    Args:
      prompt: User input, sent as the `user` turn. Required.
      response: Model response to user input. Optional; when non-empty it is
        appended as a `model` turn after the `user` turn.
      system_instructions: System instructions for the model. Optional.

    Returns:
      A `datasets.GeminiExample` with the turns above and, if given, the
      system instructions.

    Raises:
      ValueError: If `prompt` is not provided.
    """
    # `prompt` is required even though it defaults to None; fail with a clear
    # message rather than passing None on to Part.from_text.
    if prompt is None:
        raise ValueError("`prompt` is required.")

    contents = [Content(role="user", parts=[Part.from_text(prompt)])]
    if response:
        contents.append(Content(role="model", parts=[Part.from_text(response)]))

    example_kwargs = {}
    if system_instructions:
        example_kwargs["system_instruction"] = Content(
            parts=[Part.from_text(system_instructions)]
        )
    return datasets.GeminiExample(contents=contents, **example_kwargs)


def construct_template(
    *,
    prompt: str = None,
    response: str = None,
    system_instructions: str = None,
    field_mapping: dict[str, str] = None,
) -> datasets.GeminiTemplateConfig:
    """Helper method to create a GeminiTemplateConfig object for single-turn cases.

    Args:
      prompt: User input; may contain `{placeholder}` fields. Required.
      response: Model response to user input. Optional.
      system_instructions: System instructions for the model. Optional.
      field_mapping: Mapping of placeholder names to dataset column names,
        e.g. `{"uri": "image_uris"}`. Optional; when omitted, placeholders are
        expected to be dataset column names.

    Returns:
      A `datasets.GeminiTemplateConfig` wrapping the constructed example.
    """
    gemini_example = construct_gemini_example(
        prompt=prompt, response=response, system_instructions=system_instructions
    )
    return datasets.GeminiTemplateConfig(
        gemini_example=gemini_example, field_mapping=field_mapping
    )

## User Journey Demo

The user journey demonstrated here contains the following steps:

1. Create Dataset
2. Assemble the dataset with a template and inspect assembly
3. Run a validation for tuning
4. Estimate Resources for tuning
5. Run tuning
6. Run batch prediction

### 1. Create a dataset from a Pandas or BigFrames DataFrame

We prepared a DataFrame `training_set` with two columns:

*   `image_uris`: GCS URIs of flower images
*   `labels`: Flower label (five flower categories, one label per image)

In [None]:
# Peek at the first training example: render the image and show its label.
first_row = training_set.iloc[0]
flower_uri = first_row["image_uris"]
flower_label = first_row["labels"]

display(get_gcs_image(flower_uri))
print(f"Image URI: {flower_uri}")
print(f"Flower label: {flower_label}")
training_set.head()

Let's create a Vertex AI multimodal dataset from the prepared DataFrame.

In [None]:
# Turn the training DataFrame into a Vertex AI multimodal dataset, then print
# its identifiers (including its BigQuery table).
flowers = datasets.MultimodalDataset.from_pandas(
    dataframe=training_set,
)
show_dataset_info(flowers)

Inspect the new Vertex AI multimodal dataset by reading its BigQuery table with BigFrames.
In the near future, this will be available directly via the SDK.

In [None]:
# `bigquery_table` is presumably a `bq://project.dataset.table` URI, while
# BigFrames expects the bare table ID. Use removeprefix, not strip: strip("bq://")
# removes any of the characters b, q, :, / from BOTH ends and can eat part of
# the project or table name.
flowers_df = bpd.read_gbq_table(
    flowers.bigquery_table.removeprefix("bq://"), use_cache=False
)
flowers_df.head()

**Other dataset creation options**

Create from a BigQuery table.

```py
my_dataset_from_bigquery = datasets.MultimodalDataset.from_bigquery(
    bigquery_uri="bq://projectId.datasetId.tableId"
)
```

Create from a BigFrames DataFrame.

```py
my_dataset_from_bigframes = datasets.MultimodalDataset.from_bigframes(
    dataframe=my_dataframe
)
```

Create from a GCS file in JSONL format for assembled input (the JSONL file contains Gemini requests, no assembly required).

```py
my_dataset = datasets.MultimodalDataset.from_gemini_request_jsonl(
    gcs_uri=gcs_uri_of_jsonl_file
)
```

List or load existing datasets.

```py
# Get the most recently created dataset
first_dataset = datasets.MultimodalDataset.list()[0]

# Load dataset based on dataset name
same_dataset = datasets.MultimodalDataset(first_dataset.name)
```

### 2. Assemble the dataset with a template and inspect assembly

To use our Flowers dataset with Gemini, let's assemble a full Gemini request referencing the images in our dataset.

We construct a template configuration by specifying the general prompt, response, and system instructions, using placeholders in curly braces. During assembly, each placeholder is replaced with the value of the dataset column it names.

In [None]:
template_config = construct_template(
    prompt="This is the image: {image_uris}",
    response="{labels}",
    # Adjacent string literals are joined verbatim, so each sentence needs a
    # trailing space; otherwise the prompt reads "...flower.These are...".
    system_instructions="You are a botanical image classifier. Analyze the provided image "
    "and determine the most accurate classification of the flower. "
    f"These are the only flower categories: {', '.join(CATEGORIES)}. "
    "Return only one category per image.",
)

Here, the template is constructed using the local helper function `construct_template()`. Alternatively, it can be explicitly constructed from a Gemini example as below.

It is also possible to specify a custom field mapping for the placeholders used in the Gemini example. The mapping's keys are placeholder names and its values are the dataset columns whose values are inserted (here `image_uris` and `labels`), so placeholders no longer need to match column names:

```py
gemini_example = datasets.GeminiExample(
    contents=[
        Content(role="user", parts=[Part.from_text("This is the image: {uri}")]),
        Content(role="model", parts=[Part.from_text("{flower}")]),
    ],
    system_instruction=Content(
        parts=[
            Part.from_text(
                "You are a botanical image classifier. Analyze the provided image "
                "and determine the most accurate classification of the flower. "
                f"These are the only flower categories: {', '.join(CATEGORIES)}. "
                "Return only one category per image."
            )
        ]
    ),
)

template_config = datasets.GeminiTemplateConfig(
    gemini_example=gemini_example,
    field_mapping={"uri": "image_uris", "flower": "labels"},
)
```

**Assemble and inspect the dataset.**

The dataset assembly creates a BQ table with the assembled examples in a single `request` column. The assembly method below returns a tuple containing a table id (`str`) referencing the assembly BQ table, and a DataFrame (`bigframes.pandas.DataFrame`) for direct inspection.
The DataFrame and the BQ table referenced by the table id contain the assembled dataset in a single column `request`.

In [None]:
# assemble() yields the assembly's BigQuery table ID plus a DataFrame holding
# the rendered requests in a single `request` column.
table_id, assembly = flowers.assemble(
    template_config=template_config,
)
assembly.head()  # Inspect assembled dataset

It is also possible to attach the template and run the assembly without passing it:

```py
flowers.attach_template_config(template_config=template_config)
_, other_assembly = flowers.assemble()
```

### 3. Run a validation for tuning

Validate a dataset for tuning.
Tuning dataset usages are: `SFT_VALIDATION`, `SFT_TRAINING`.

First we attach the `template_config` and use it implicitly for all further tasks.

In [None]:
# Attach the template once; subsequent calls on `flowers` use it implicitly.
flowers.attach_template_config(template_config=template_config)

validation = flowers.assess_tuning_validity(
    model_name="gemini-2.0-flash-001",
    dataset_usage="SFT_TRAINING",
)
validation.errors  # Check if there are validation errors

Let's validate a dataset with an incorrect `template_config`, e.g. using a `GeminiExample` that contains two consecutive `user` contents, instead of a `user` content followed by a `model` content.

In [None]:
# Deliberately broken example: two "user" turns in a row, with no "model"
# turn in between.
user_turn = Content(role="user", parts=[Part.from_text("This is the image: {image_uris}")])
repeated_user_turn = Content(role="user", parts=[Part.from_text(".")])
invalid_gemini_example = datasets.GeminiExample(
    contents=[user_turn, repeated_user_turn],
)
invalid_configuration = datasets.GeminiTemplateConfig(gemini_example=invalid_gemini_example)

# Pass the invalid template explicitly instead of relying on the attached one.
validation = flowers.assess_tuning_validity(
    model_name="gemini-2.0-flash-001",
    dataset_usage="SFT_TRAINING",
    template_config=invalid_configuration,
)
validation.errors

### 4. Estimate resources for tuning

In [None]:
# Estimate tuning resources using the template attached to `flowers`.
tuning_resources = flowers.assess_tuning_resources(
    model_name="gemini-2.0-flash-001",
)
print(tuning_resources)

### 5. Run Tuning

In the future we'll provide an integration with BQ directly:

```py
sft_tuning_job = tuning_service.train(
    source_model="gemini-2.0-flash-001",
    train_dataset=flowers,
)
```

For now we use a helper to export the assembly BQ table to a JSONL file on GCS and provide the GCS URI as training dataset reference. Please provide a Google Cloud Storage bucket for the export.

In [None]:
# The following will be removed once the tuning integration has been completed.
# Set a GCS bucket for exporting your dataset.
# Use the form `gs://bucket-name`; the bucket must already exist and be
# writable by this notebook's credentials (the export helper does not create it).
tuning_destination_bucket = "gs://my-tuning-export-bucket"  # @param {type:"string"}

In [None]:
# Assemble Gemini requests with the attached template; only the BigQuery table
# ID is needed here, not the preview DataFrame.
assembly_table_id, _ = flowers.assemble()

# Tuning currently consumes JSONL from GCS, so export the assembly there first.
train_gcs_uri = bq_table_to_jsonl_gcs(
    source_table_id=assembly_table_id,
    destination_bucket=tuning_destination_bucket,
)

tuning_job = sft.train(
    source_model="gemini-2.0-flash-001",
    train_dataset=train_gcs_uri,
)

Let's also prepare and use the validation dataset.

In [None]:
# Create a Vertex AI Multimodal dataset for the validation set
flowers_validation_dataset = datasets.MultimodalDataset.from_pandas(
    dataframe=validation_set
)

# Assemble Gemini requests. No template is attached to this dataset, so pass
# the training template explicitly. A separate variable name keeps
# `assembly_table_id` referring to the training assembly.
validation_assembly_table_id, _ = flowers_validation_dataset.assemble(
    template_config=template_config
)
# Export assembly as JSONL to GCS bucket
validation_gcs_uri = bq_table_to_jsonl_gcs(
    source_table_id=validation_assembly_table_id,
    destination_bucket=tuning_destination_bucket,
)

# Run tuning job with train and validation dataset
tuning_job = sft.train(
    source_model="gemini-2.0-flash-001",
    train_dataset=train_gcs_uri,
    validation_dataset=validation_gcs_uri,
)

### 6. Batch Prediction

In the future we'll provide an integration with Batch Prediction directly:

```py
batch_prediction_job = BatchPredictionJob.submit(
    source_model="gemini-2.0-flash-001",
    input_dataset=flowers,
    output_uri_prefix=output_uri,
)
```

For now we provide the assembly BQ URI as input dataset:

In [None]:
# Run batch prediction on the assembled `flowers` dataset (using its attached
# template). Assemble explicitly rather than reusing `assembly_table_id`, which
# an earlier cell may have pointed at the validation set's assembly.
batch_assembly_table_id, _ = flowers.assemble()
batch_prediction_job = BatchPredictionJob.submit(
    source_model="gemini-2.0-flash-001",
    input_dataset=f"bq://{batch_assembly_table_id}",
)