In [None]:
# @title Copyright & License (click to expand)
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Supervised Fine Tuning with Gemini 2.0 Flash for Image Captioning

<table align="left">
  <td style="text-align: center">
    <a href="https://colab.research.google.com/github/GoogleCloudPlatform/generative-ai/blob/main/gemini/tuning/sft_gemini_on_image_captioning.ipynb">
      <img src="https://cloud.google.com/ml-engine/images/colab-logo-32px.png" alt="Google Colaboratory logo"><br> Open in Colab
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fgenerative-ai%2Fmain%2Fgemini%2Ftuning%2Fsft_gemini_on_image_captioning.ipynb">
      <img width="32px" src="https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN" alt="Google Cloud Colab Enterprise logo"><br> Open in Colab Enterprise
    </a>
  </td>    
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/generative-ai/main/gemini/tuning/sft_gemini_on_image_captioning.ipynb">
      <img src="https://lh3.googleusercontent.com/UiNooY4LUgW_oTvpsNhPpQzsstV5W8F7rYgxgGBD85cWJoLmrOzhVs_ksK_vgx40SHs7jCqkTkCk=e14-rj-sc0xffffff-h130-w32" alt="Vertex AI logo"><br> Open in Workbench
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/tuning/sft_gemini_on_image_captioning.ipynb">
      <img src="https://cloud.google.com/ml-engine/images/github-logo-32px.png" alt="GitHub logo"><br> View on GitHub
    </a>
  </td>
</table>

<div style="clear: both;"></div>

| | |
|-|-|
|Author(s) | [Deepak Moonat](https://github.com/dmoonat) |

## Overview

**Gemini** is a family of generative AI models developed by Google DeepMind that is designed for multimodal use cases. The Gemini API gives you access to the Gemini models, such as Gemini 2.0 Flash and Gemini 2.0 Pro.

This notebook demonstrates how to fine-tune the Gemini 2.0 Flash generative model using the Vertex AI Supervised Tuning feature. Supervised Tuning allows you to use your own training data to further refine the base model's capabilities towards your specific tasks.


Supervised Tuning uses labeled examples to tune a model. Each example demonstrates the output you want from the model during inference.

Keep the following in mind:
- Data: Ensure your training data is of high quality, well-labeled, and directly relevant to the target task. Low-quality data can degrade performance and introduce bias in the fine-tuned model.
- Training: Experiment with different configurations to optimize the model's performance on the target task.
- Evaluation:
  - Metric: Choose evaluation metrics that accurately reflect the success of the fine-tuned model for your specific task.
  - Evaluation Set: Use a separate set of data to evaluate the model's performance.

### Objective

In this tutorial, you will learn how to use `Vertex AI` to tune a `Gemini 2.0 Flash` model.


This tutorial uses the following Google Cloud ML services:

- `Vertex AI`


The steps performed include:

- Prepare and load the dataset
- Load the `gemini-2.0-flash-001` model
- Evaluate the model before tuning
- Tune the model.
  - This will automatically create a Vertex AI endpoint and deploy the model to it
- Evaluate the model after tuning
- Make a prediction using tuned model.

### Dataset

The dataset used in this notebook is an image captioning dataset. [Reference](https://ai.google.dev/gemma/docs/paligemma/fine-tuning-paligemma#download_the_model_checkpoint)

```
Licensed under the Creative Commons Attribution 4.0 License
```

### Costs

This tutorial uses billable components of Google Cloud:

* Vertex AI
* Cloud Storage

Learn about [Vertex AI
pricing](https://cloud.google.com/vertex-ai/pricing), [Cloud Storage
pricing](https://cloud.google.com/storage/pricing), and use the [Pricing
Calculator](https://cloud.google.com/products/calculator/)
to generate a cost estimate based on your projected usage.

### Install Gen AI SDK and other required packages

The new Google Gen AI SDK provides a unified interface to Gemini through both the Gemini Developer API and the Gemini API on Vertex AI. With a few exceptions, code that runs on one platform will run on both. This means that you can prototype an application using the Developer API and then migrate the application to Vertex AI without rewriting your code.

In [None]:
%pip install --upgrade --user --quiet google-genai google-cloud-aiplatform jsonlines rouge_score

### Restart runtime (Colab only)

To use the newly installed packages, you must restart the runtime on Google Colab.

In [None]:
# Automatically restart kernel after installs so that your environment can access the new packages
import sys

if "google.colab" in sys.modules:
    import IPython

    app = IPython.Application.instance()
    app.kernel.do_shutdown(True)

<div class="alert alert-block alert-warning">
<b>⚠️ The kernel is going to restart. Please wait until it is finished before continuing to the next step. ⚠️</b>
</div>

## Before you begin

### Set your project ID

**If you don't know your project ID**, try the following:
* Run `gcloud config list`.
* Run `gcloud projects list`.
* See the support page: [Locate the project ID](https://support.google.com/googleapi/answer/7014113)

In [None]:
PROJECT_ID = "[YOUR_PROJECT_ID]"  # @param {type:"string"}
# Set the project id
! gcloud config set project {PROJECT_ID}

#### Region

You can also change the `REGION` variable used by Vertex AI. Learn more about [Vertex AI regions](https://cloud.google.com/vertex-ai/docs/general/locations).

In [None]:
REGION = "us-central1"  # @param {type:"string"}

#### Bucket


In [None]:
BUCKET_NAME = "[YOUR_BUCKET_NAME]"  # @param {type:"string"}
BUCKET_URI = f"gs://{BUCKET_NAME}"

### Authenticate your Google Cloud account

Depending on your Jupyter environment, you may have to manually authenticate. Follow the relevant instructions below.

**1. Vertex AI Workbench**
* Do nothing as you are already authenticated.

**2. Local JupyterLab instance, uncomment and run:**

In [None]:
# ! gcloud auth login

**3. Authenticate your notebook environment**

If you are running this notebook on Google Colab, run the cell below to authenticate your environment.

In [None]:
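import sys

# Only authenticate this way when running on Google Colab.
if "google.colab" in sys.modules:
    from google.colab import auth

    auth.authenticate_user()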

**4. Service account or other**
* See how to grant Cloud Storage permissions to your service account at https://cloud.google.com/storage/docs/gsutil/commands/iam#ch-examples.

### Import libraries

In [None]:
import io
import time

# For visualization.
from PIL import Image
from google import genai

# For Google Cloud Storage service.
from google.cloud import storage

# For fine tuning Gemini model.
import google.cloud.aiplatform as aiplatform
from google.genai import types

# For data handling.
import jsonlines
import pandas as pd

# For evaluation.
from rouge_score import rouge_scorer
from tqdm import tqdm

## Initialize Vertex AI and Gen AI SDK for Python

In [None]:
aiplatform.init(project=PROJECT_ID, location=REGION)

client = genai.Client(vertexai=True, project=PROJECT_ID, location=REGION)

## Prepare Multimodal Dataset

The dataset used to tune a foundation model needs to include examples that align with the task that you want the model to perform.

Note:
- Only images and text are supported as input, and only text as output.
- Maximum of 30 images per tuning example.
- Maximum image file size: 20 MB.
- Images must be in `jpeg` or `png` format. Supported MIME types: `image/jpeg` and `image/png`.

The input is a JSONL file with one JSON object per line.
Each JSON instance has the following format (expanded for clarity):
```
{
   "contents":[
      {
         "role":"user",  # This indicates the input content
         "parts":[ # Interleaved image and text, in any order.
            {
               "fileData":{ # fileData must reference an image file in GCS. No inline data.
                  "mimeType":"image/jpeg", # The MIME type of this image
                  "fileUri":"gs://path/to/image_uri"
               }
            },
            {
               "text":"What is in this image?"
            }
         ]
      },
      {
         "role":"model", # This indicates the target content
         "parts":[ # text only
            {
               "text":"Something about this image."
            }
         ]
      } # Single turn input and response.
   ]
}
```

Example:
```
{
   "contents":[
      {
         "role":"user",
         "parts":[
            {
               "fileData":{
                  "mimeType":"image/jpeg",
                  "fileUri":"gs://bucketname/data/vision_data/task/image_description/image/1.jpeg"
               }
            },
            {
               "text":"Describe this image that captures the essence of it."
            }
         ]
      },
      {
         "role":"model",
         "parts":[
            {
               "text":"A person wearing a pink shirt and a long-sleeved shirt with a large cuff, ...."
            }
         ]
      }
   ]
}
```
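
The constraints and layout above can be checked locally before the file is uploaded. The following is a minimal sketch and not part of the original notebook: `validate_tuning_example` is a hypothetical helper, and it covers only the rules listed in this section (single user/model turn, image count, MIME types, `gs://` URIs).

```python
import json

ALLOWED_MIME_TYPES = {"image/jpeg", "image/png"}
MAX_IMAGES_PER_EXAMPLE = 30


def validate_tuning_example(line):
    """Return a list of problems found in one JSONL line (empty if none)."""
    example = json.loads(line)
    contents = example.get("contents", [])
    roles = [turn.get("role") for turn in contents]
    if roles != ["user", "model"]:
        return [f"expected roles ['user', 'model'], got {roles}"]

    problems = []
    user_parts = contents[0].get("parts", [])
    model_parts = contents[1].get("parts", [])

    # Collect every image reference in the user turn.
    images = [part["fileData"] for part in user_parts if "fileData" in part]
    if len(images) > MAX_IMAGES_PER_EXAMPLE:
        problems.append(f"{len(images)} images exceed the limit of {MAX_IMAGES_PER_EXAMPLE}")
    for image in images:
        if image.get("mimeType") not in ALLOWED_MIME_TYPES:
            problems.append(f"unsupported mimeType: {image.get('mimeType')}")
        if not str(image.get("fileUri", "")).startswith("gs://"):
            problems.append(f"fileUri is not a gs:// URI: {image.get('fileUri')}")

    # The target turn must be non-empty and contain text parts only.
    if not model_parts or any(set(part) != {"text"} for part in model_parts):
        problems.append("model turn must contain text parts only")
    return problems


sample = json.dumps(
    {
        "contents": [
            {
                "role": "user",
                "parts": [
                    {"fileData": {"mimeType": "image/jpeg", "fileUri": "gs://bucketname/1.jpeg"}},
                    {"text": "Describe this image."},
                ],
            },
            {"role": "model", "parts": [{"text": "A person wearing a pink shirt."}]},
        ]
    }
)
print(validate_tuning_example(sample))  # []
```

Running the helper over each line of a JSONL file before upload gives early feedback on malformed records, without waiting for a tuning job to reject them.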


### Data files


The data used in this notebook is available in the public Cloud Storage (GCS) bucket `gs://longcap100`.

Sample:

> {"prefix": "", "suffix": "A person wearing a pink shirt and a long-sleeved shirt with a large cuff, has their hand on a concrete ledge. The hand is on the edge of the ledge, and the thumb is on the edge of the hand. The shirt has a large cuff, and the sleeve is rolled up. The shadow of the hand is on the wall.", "image": "91.jpeg"}



- `data_train90.jsonl`: Contains training samples in json lines as shown above
- `data_val10.jsonl`: Contains validation samples in json lines as shown above
- `images`: Contains 100 images, training and validation data

To run a tuning job, you need to upload one or more datasets to a Cloud Storage bucket. You can either create a new Cloud Storage bucket or use an existing one to store dataset files. The region of the bucket doesn't matter, but we recommend that you use a bucket that's in the same Google Cloud project where you plan to tune your model.

### Create a Cloud Storage bucket

- Create a storage bucket to store intermediate artifacts such as datasets.

- Only if your bucket doesn't already exist: Run the following cell to create your Cloud Storage bucket.


In [None]:
!gsutil mb -l {REGION} -p {PROJECT_ID} {BUCKET_URI}

### Copy images to specified Bucket

In [None]:
!gsutil -m -q cp -n -r gs://longcap100/*.jpeg {BUCKET_URI}/images/

- Download the training and validation dataset jsonlines files from the bucket.

In [None]:
!gsutil -m -q cp -n -r gs://longcap100/data_train90.jsonl .

In [None]:
!gsutil -m -q cp -n -r gs://longcap100/data_val10.jsonl .

### Prepare dataset for Training and Evaluation

- Utility function to save json instances into jsonlines format

In [None]:
def save_jsonlines(file, instances):
    """
    Saves a list of json instances to a jsonlines file.
    """
    with jsonlines.open(file, mode="w") as writer:
        writer.write_all(instances)

- The function below converts the dataset into the Gemini tuning format

In [None]:
task_prompt = "Describe this image in detail that captures the essence of it."

In [None]:
def create_tuning_samples(file_path):
    """
    Creates tuning samples from a file.
    """
    with jsonlines.open(file_path) as reader:
        instances = []
        for obj in reader:
            instance = {
                "contents": [
                    {
                        "role": "user",  # This indicates the input content
                        "parts": [  # Interleaved image and text, in any order.
                            {
                                "fileData": {  # fileData must reference an image file in GCS. No inline data.
                                    "mimeType": "image/jpeg",  # The MIME type of this image
                                    "fileUri": f"{BUCKET_URI}/images/{obj['image']}",
                                }
                            },
                            {"text": task_prompt},
                        ],
                    },
                    {
                        "role": "model",  # This indicates the target content
                        "parts": [{"text": obj["suffix"]}],  # text only
                    },  # Single turn input and response.
                ]
            }
            instances.append(instance)
    return instances
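
`create_tuning_samples` hardcodes `image/jpeg`, which matches this dataset. If a dataset mixed JPEG and PNG files, the MIME type could be derived from the file name instead. The following is a minimal sketch, assuming file extensions are reliable; `mime_type_for` is a hypothetical helper and not part of the notebook.

```python
import os

# Only the two formats this notebook lists as supported for tuning.
EXTENSION_TO_MIME = {
    ".jpeg": "image/jpeg",
    ".jpg": "image/jpeg",
    ".png": "image/png",
}


def mime_type_for(file_name):
    """Map an image file name to a supported MIME type, or raise ValueError."""
    extension = os.path.splitext(file_name)[1].lower()
    try:
        return EXTENSION_TO_MIME[extension]
    except KeyError:
        raise ValueError(f"Unsupported image format: {file_name!r}") from None


print(mime_type_for("91.jpeg"))  # image/jpeg
print(mime_type_for("diagram.PNG"))  # image/png
```

Failing loudly on an unsupported extension keeps bad records out of the tuning file rather than letting them surface later as a job error.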

- Training data

In [None]:
train_file_path = "data_train90.jsonl"
train_instances = create_tuning_samples(train_file_path)
# save the training instances to jsonl file
save_jsonlines("train.jsonl", train_instances)

In [None]:
train_instances[0]

In [None]:
# save the training data to GCS bucket
!gsutil cp train.jsonl {BUCKET_URI}/train/

- Validation data

In [None]:
val_file_path = "data_val10.jsonl"
val_instances = create_tuning_samples(val_file_path)
# save the validation instances to jsonl file
save_jsonlines("val.jsonl", val_instances)

In [None]:
val_instances[0]

In [None]:
# save the validation data to GCS bucket
!gsutil cp val.jsonl {BUCKET_URI}/val/

- The code below transforms the JSONL instances into the following structure

`
[{'file_uri': '<GCS path for query image>',
 'ground_truth': '<Ground truth, image description>'},
 ..
]
`

In [None]:
data_table = []
for instance in val_instances:
    data_table.append(
        {
            "file_uri": instance["contents"][0]["parts"][0]["fileData"]["fileUri"],
            "ground_truth": instance["contents"][1]["parts"][0]["text"],
        }
    )

In [None]:
data_table[0]

- The `data_table` is converted into a dataframe with two columns, `file_uri` and `ground_truth`. The `ground_truth` values will be compared with the model-generated output

In [None]:
val_df = pd.DataFrame(data_table)
val_df

- Total `10` instances in validation data

## Visualization utils

- Function to visualize the query images stored in GCS bucket

In [None]:
# Read the bytes of an image file stored in a GCS bucket


def read_image_bytes_from_gcs(bucket_name, blob_name):
    """Reads image bytes from a GCS bucket.

    Args:
      bucket_name: The name of the GCS bucket.
      blob_name: The name of the blob (file) within the bucket.

    Returns:
      The image bytes as a bytes object.
    """

    storage_client = storage.Client()
    bucket = storage_client.bucket(bucket_name)
    blob = bucket.blob(blob_name)

    image_bytes = blob.download_as_bytes()

    return image_bytes
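
`read_image_bytes_from_gcs` takes a bucket name and a blob name, while the dataset stores full `gs://` URIs. The notebook later strips `BUCKET_URI` from the URI with `str.replace`; a more general alternative is to split any Cloud Storage URI into its two parts. The following is a minimal sketch; `split_gcs_uri` is a hypothetical helper and not part of the notebook.

```python
def split_gcs_uri(gcs_uri):
    """Split 'gs://bucket/path/to/object' into ('bucket', 'path/to/object')."""
    prefix = "gs://"
    if not gcs_uri.startswith(prefix):
        raise ValueError(f"Not a Cloud Storage URI: {gcs_uri!r}")
    # partition splits on the first "/" only, so nested paths stay intact.
    bucket_name, _, blob_name = gcs_uri[len(prefix):].partition("/")
    if not bucket_name or not blob_name:
        raise ValueError(f"URI must include a bucket and an object name: {gcs_uri!r}")
    return bucket_name, blob_name


bucket_name, blob_name = split_gcs_uri("gs://my-bucket/images/91.jpeg")
print(bucket_name, blob_name)  # my-bucket images/91.jpeg
```

The two returned values can be passed directly to `read_image_bytes_from_gcs(bucket_name, blob_name)`.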

## Evaluation Pre-Tuning

- Assign `gemini-2.0-flash-001` as base_model


In [None]:
base_model = "gemini-2.0-flash-001"

### Generation config

- Each call that you send to a model includes parameter values that control how the model generates a response. The model can generate different results for different parameter values
- <strong>Experiment</strong> with different parameter values to get the best values for the task

Refer to the following [link](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/prompts/adjust-parameter-values) for understanding different parameters

**Prompt** is a natural language request submitted to a language model to receive a response back

Some best practices include
  - Clearly communicate what content or information is most important
  - Structure the prompt:
    - Define the role if using one. For example: "You are an experienced UX designer at a top tech company."
    - Include context and input data
    - Provide the instructions to the model
    - Add example(s) if you are using them

Refer to the following [link](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/prompts/prompt-design-strategies) for prompt design strategies.

### Task

***Task prompt:***

`
"<image>, Describe this image that captures the essence of it. "
`

***Query Image (image)***


In [None]:
query_image_uri = val_instances[0]["contents"][0]["parts"][0]["fileData"]["fileUri"]
blob_name = query_image_uri.replace(f"{BUCKET_URI}/", "")
img = read_image_bytes_from_gcs(BUCKET_NAME, blob_name)

# Display image bytes using pil python library
image = Image.open(io.BytesIO(img))
resized_img = image.resize((300, 300))
display(resized_img)

- Test on single instance

In [None]:
response = client.models.generate_content(
    model=base_model,
    contents=[
        types.Part.from_uri(file_uri=str(query_image_uri), mime_type="image/jpeg"),
        "Describe this image that captures the essence of it.",
    ],
    # Optional config
    config={
        "temperature": 0.0,
    },
)

print(response.text.strip())

- Ground truth

In [None]:
val_instances[0]["contents"][1]["parts"][0]["text"]

- Change prompt to get detailed description for the provided image

In [None]:
response = client.models.generate_content(
    model=base_model,
    contents=[
        types.Part.from_uri(file_uri=str(query_image_uri), mime_type="image/jpeg"),
        "Describe this image in detail that captures the essence of it.",
    ],
    # Optional config
    config={
        "temperature": 0.0,
    },
)

print(response.text.strip())

## Evaluation before model tuning

- Evaluate the Gemini model on the validation dataset before tuning it on the training dataset.

In [None]:
def get_prediction(query_image_uri, base_model):
    """Gets the prediction for a given image.

    Args:
      query_image_uri: The Cloud Storage URI of the query image.
      base_model: The name of the model (or tuned model endpoint) to use for prediction.

    Returns:
      A string containing the prediction.
    """
    response = client.models.generate_content(
        model=base_model,
        contents=[
            types.Part.from_uri(file_uri=str(query_image_uri), mime_type="image/jpeg"),
            task_prompt,
        ],
        # Optional config
        config={
            "temperature": 0.0,
        },
    )

    return response.text.strip()

In [None]:
def run_eval(val_df, model=base_model):
    """Runs evaluation on the validation dataset.

    Args:
      val_df: The validation dataframe.
      model: The model to use for evaluation.

    Returns:
      A list of predictions on val_df.
    """
    predictions = []
    for i, row in tqdm(val_df.iterrows(), total=val_df.shape[0]):
        try:
            prediction = get_prediction(row["file_uri"], model)
        except Exception:
            # Retry once after a pause, for example after a quota or transient error.
            time.sleep(30)
            prediction = get_prediction(row["file_uri"], model)
        predictions.append(prediction)
        time.sleep(1)
    return predictions
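
`run_eval` retries a failed call once after a fixed 30-second pause. One common alternative is a bounded retry with exponential backoff. The sketch below is not from the original notebook: `call_with_retries` and its parameters are hypothetical, and the `sleep` argument is injectable only so the example can run without waiting.

```python
import time


def call_with_retries(func, *args, max_attempts=3, base_delay=1.0, sleep=time.sleep):
    """Call func(*args), retrying on any Exception with exponential backoff."""
    for attempt in range(1, max_attempts + 1):
        try:
            return func(*args)
        except Exception:
            if attempt == max_attempts:
                raise  # Out of attempts: surface the last error.
            # Wait base_delay, then 2x, then 4x, and so on.
            sleep(base_delay * 2 ** (attempt - 1))


# Demo with a stand-in function that fails twice, then succeeds.
calls = {"count": 0}


def flaky(value):
    calls["count"] += 1
    if calls["count"] < 3:
        raise RuntimeError("transient error")
    return value.upper()


delays = []
result = call_with_retries(flaky, "caption", sleep=delays.append)
print(result, delays)  # CAPTION [1.0, 2.0]
```

Inside `run_eval`, the call could be wrapped as `call_with_retries(get_prediction, row["file_uri"], model)`, with `base_delay` chosen to suit the quota of the project.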

- Generate predictions from the base model for every instance in the validation dataset.


<div class="alert alert-block alert-warning">
<b>⚠️ It will take ~1 min for the model to generate predictions on the provided validation dataset. ⚠️</b>
</div>

In [None]:
%%time
predictions = run_eval(val_df, model=base_model)

In [None]:
len(predictions)

In [None]:
val_df.loc[:, "basePredictions"] = predictions

In [None]:
val_df

### Evaluation metric

The type of metrics used for evaluation depends on the task that you are evaluating. The following table shows the supported tasks and the metrics used to evaluate each task:

| Task               | Metric(s)                        |
|--------------------|----------------------------------|
| Classification     | Micro-F1, Macro-F1, Per class F1 |
| Summarization      | ROUGE-L                          |
| Question Answering | Exact Match                      |
| Text Generation    | BLEU, ROUGE-L                    |


For this task, we'll use the ROUGE metric.

- **Recall-Oriented Understudy for Gisting Evaluation (ROUGE)**: A metric used to evaluate the quality of automatic summaries of text. It works by comparing a generated summary to a set of reference summaries created by humans.

Now you can take the candidate and reference to evaluate the performance. In this case, ROUGE will give you:

- `rouge-1`, which measures unigram overlap
- `rouge-2`, which measures bigram overlap
- `rouge-l`, which measures the longest common subsequence

- *Recall vs. Precision*

    **Recall** measures how much of the information in the reference text is captured in the generated text.

    **Precision** measures how much of the generated text is relevant to the reference text.
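
To make these definitions concrete, the following sketch computes n-gram overlap and the longest common subsequence by hand. It is a simplified illustration, not the `rouge_score` implementation: it tokenizes on whitespace and applies no stemming, so its numbers can differ from the library's. The function names and the example sentences are made up for this illustration.

```python
from collections import Counter


def rouge_n(reference, candidate, n=1):
    """Return (precision, recall, f1) from clipped n-gram overlap."""

    def ngrams(text):
        tokens = text.lower().split()
        return Counter(tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1))

    ref_counts, cand_counts = ngrams(reference), ngrams(candidate)
    # Counter intersection keeps the minimum count of each shared n-gram.
    overlap = sum((ref_counts & cand_counts).values())
    precision = overlap / max(sum(cand_counts.values()), 1)
    recall = overlap / max(sum(ref_counts.values()), 1)
    f1 = 0.0 if overlap == 0 else 2 * precision * recall / (precision + recall)
    return precision, recall, f1


def lcs_length(reference, candidate):
    """Length of the longest common token subsequence (the basis of ROUGE-L)."""
    a, b = reference.lower().split(), candidate.lower().split()
    table = [[0] * (len(b) + 1) for _ in range(len(a) + 1)]
    for i, token_a in enumerate(a, start=1):
        for j, token_b in enumerate(b, start=1):
            if token_a == token_b:
                table[i][j] = table[i - 1][j - 1] + 1
            else:
                table[i][j] = max(table[i - 1][j], table[i][j - 1])
    return table[len(a)][len(b)]


reference = "a hand rests on a concrete ledge"
candidate = "a hand on a ledge"
print(rouge_n(reference, candidate, n=1))
print(rouge_n(reference, candidate, n=2))
print(lcs_length(reference, candidate))  # 5
```

In this example the unigram precision is 1.0 because every candidate word appears in the reference, while the unigram recall is 5/7 because two of the seven reference words ("rests" and "concrete") are missing from the candidate.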

- Initialize the `RougeScorer` object

In [None]:
scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)

- Define function to calculate rouge score

In [None]:
def get_rouge_score(groundTruth, prediction):
    """Function to compute rouge score.

    Args:
      groundTruth: The ground truth text.
      prediction: The predicted text.
    Returns:
      The rouge score.
    """
    scores = scorer.score(target=groundTruth, prediction=prediction)
    return scores

- Single instance evaluation

In [None]:
get_rouge_score(val_df.loc[0, "ground_truth"], val_df.loc[0, "basePredictions"])

In [None]:
def calculate_metrics(val_df, prediction_col="basePredictions"):
    """Function to compute rouge scores for all instances in the validation dataset.
    Args:
      val_df: The validation dataframe.
      prediction_col: The column name of the predictions.
    Returns:
      A dataframe containing the rouge scores.
    """
    records = []
    for row, instance in val_df.iterrows():
        scores = get_rouge_score(instance["ground_truth"], instance[prediction_col])
        records.append(
            {
                "rouge1_precision": scores.get("rouge1").precision,
                "rouge1_recall": scores.get("rouge1").recall,
                "rouge1_fmeasure": scores.get("rouge1").fmeasure,
                "rouge2_precision": scores.get("rouge2").precision,
                "rouge2_recall": scores.get("rouge2").recall,
                "rouge2_fmeasure": scores.get("rouge2").fmeasure,
                "rougeL_precision": scores.get("rougeL").precision,
                "rougeL_recall": scores.get("rougeL").recall,
                "rougeL_fmeasure": scores.get("rougeL").fmeasure,
            }
        )
    metrics = pd.DataFrame(records)
    return metrics

In [None]:
evaluation_df_stats = calculate_metrics(val_df, prediction_col="basePredictions")
evaluation_df_stats

In [None]:
print("Mean rougeL_precision is", evaluation_df_stats.rougeL_precision.mean())
print("Mean rougeL_recall is", evaluation_df_stats.rougeL_recall.mean())
print("Mean rougeL_fmeasure is", evaluation_df_stats.rougeL_fmeasure.mean())

## Fine-tune the model

You can create a supervised fine-tuning job by using the Google Gen AI SDK for Python.


When you run a supervised fine-tuning job, the model learns additional parameters that help it encode the necessary information to perform the desired task or learn the desired behavior. These parameters are used during inference. The output of the tuning job is a new model that combines the newly learned parameters with the original model.

**Tuning Job parameters**

- `base_model`: Specifies the base Gemini model version you want to fine-tune.
- `training_dataset`: Path to your training data in JSONL format.


 *Optional parameters*
 - `validation_dataset`: If provided, this data is used to evaluate the model during tuning.
 - `tuned_model_display_name`: Display name for the tuned model.

 *Hyperparameters*
 - `epoch_count`: The number of training epochs to run.
 - `learning_rate_multiplier`: A value to scale the learning rate during training.
 - `adapter_size`: Gemini 2.0 Flash supports adapter sizes [1, 2, 4, 8]; the default value is 4.


**Note: The default hyperparameter settings are based on rigorous testing and are recommended for initial use. You can customize these parameters to address specific performance requirements.**

- Check out the [documentation](https://cloud.google.com/vertex-ai/generative-ai/docs/models/gemini-use-supervised-tuning#tuning_hyperparameters) to learn more.
- [Gen AI SDK for tuning job](https://googleapis.github.io/python-genai/genai.html#genai.types.CreateTuningJobConfig)
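
The tuning cell in this notebook passes the adapter size as the string `"ADAPTER_SIZE_EIGHT"`. As a small convenience, the sketch below maps the integer sizes listed above to such strings. Only `ADAPTER_SIZE_EIGHT` appears in this notebook; the other three names are assumed to follow the same pattern and should be checked against the SDK reference linked above. `adapter_size_name` is a hypothetical helper.

```python
# Assumption: names other than ADAPTER_SIZE_EIGHT follow the same pattern.
ADAPTER_SIZE_NAMES = {
    1: "ADAPTER_SIZE_ONE",
    2: "ADAPTER_SIZE_TWO",
    4: "ADAPTER_SIZE_FOUR",
    8: "ADAPTER_SIZE_EIGHT",
}


def adapter_size_name(size):
    """Return the enum-style name for a supported adapter size."""
    if size not in ADAPTER_SIZE_NAMES:
        supported = sorted(ADAPTER_SIZE_NAMES)
        raise ValueError(f"adapter size must be one of {supported}, got {size}")
    return ADAPTER_SIZE_NAMES[size]


print(adapter_size_name(8))  # ADAPTER_SIZE_EIGHT
```

Validating the size up front turns a typo such as `3` into an immediate local error instead of a rejected tuning request.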

In [None]:
tuned_model_display_name = "[DISPLAY NAME FOR TUNED MODEL]"  # @param {type:"string"}

training_dataset = {
    "gcs_uri": f"{BUCKET_URI}/train/train.jsonl",
}

validation_dataset = types.TuningValidationDataset(
    gcs_uri=f"{BUCKET_URI}/val/val.jsonl"
)


sft_tuning_job = client.tunings.tune(
    base_model=base_model,
    training_dataset=training_dataset,
    config=types.CreateTuningJobConfig(
        adapter_size="ADAPTER_SIZE_EIGHT",
        epoch_count=1,  # set to one to keep time and cost low
        tuned_model_display_name=tuned_model_display_name,
        validation_dataset=validation_dataset,
    ),
)

In [None]:
# Get the tuning job info.
tuning_job = client.tunings.get(name=sft_tuning_job.name)
tuning_job

**Note: Tuning time depends on several factors, such as training data size, number of epochs, learning rate multiplier, etc.**

<div class="alert alert-block alert-warning">
<b>⚠️ It will take 30-40 mins for the model tuning job to complete on the provided dataset and set configurations/hyperparameters. ⚠️</b>
</div>

In [None]:
%%time
# Wait for job completion

running_states = [
    "JOB_STATE_PENDING",
    "JOB_STATE_RUNNING",
]

while tuning_job.state.name in running_states:
    tuning_job = client.tunings.get(name=sft_tuning_job.name)
    time.sleep(10)
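
The loop above exits as soon as the job leaves the pending or running states, but it does not distinguish success from failure and has no upper bound on waiting time. The following is a minimal sketch of a polling helper with a timeout and an explicit success check. It is not part of the original notebook: `wait_for_job` and `get_state` are hypothetical names, and `JOB_STATE_SUCCEEDED` is assumed to be the name of the success state.

```python
import time


def wait_for_job(get_state, running_states, poll_seconds=10, timeout_seconds=3600,
                 sleep=time.sleep):
    """Poll get_state() until it returns a state outside running_states."""
    waited = 0
    state = get_state()
    while state in running_states:
        if waited >= timeout_seconds:
            raise TimeoutError(f"job still in {state} after {waited} seconds")
        sleep(poll_seconds)
        waited += poll_seconds
        state = get_state()
    return state


# Demo with canned states instead of a real tuning job.
states = iter(["JOB_STATE_PENDING", "JOB_STATE_RUNNING", "JOB_STATE_SUCCEEDED"])
final_state = wait_for_job(
    lambda: next(states),
    running_states={"JOB_STATE_PENDING", "JOB_STATE_RUNNING"},
    sleep=lambda seconds: None,  # skip real waiting in the demo
)
# Assumption: JOB_STATE_SUCCEEDED is the success state name.
if final_state != "JOB_STATE_SUCCEEDED":
    raise RuntimeError(f"tuning job ended in state {final_state}")
print(final_state)  # JOB_STATE_SUCCEEDED
```

With the objects defined in this notebook, `get_state` could be `lambda: client.tunings.get(name=sft_tuning_job.name).state.name`.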

## Evaluation Post-tuning

- Evaluate the tuned Gemini model on the validation dataset.

In [None]:
tuned_model = tuning_job.tuned_model.endpoint
tuning_experiment_name = tuning_job.experiment

print("Tuned model experiment", tuning_experiment_name)
print("Tuned model endpoint resource name:", tuned_model)

- Get a prediction from tuned model

In [None]:
response = client.models.generate_content(
    model=tuned_model,
    contents=[
        types.Part.from_uri(file_uri=str(query_image_uri), mime_type="image/jpeg"),
        task_prompt,
    ],
    # Optional config
    config={
        "temperature": 0,
    },
)

print(response.text.strip())

- Evaluate the tuned model on entire validation set

<div class="alert alert-block alert-warning">
<b>⚠️ It will take ~1 min for the model to generate predictions on the provided validation dataset. ⚠️</b>
</div>

In [None]:
%%time
predictions_tuned = run_eval(val_df, model=tuned_model)

In [None]:
val_df.loc[:, "tunedPredictions"] = predictions_tuned

In [None]:
evaluation_df_post_tuning_stats = calculate_metrics(
    val_df, prediction_col="tunedPredictions"
)
evaluation_df_post_tuning_stats

- Improvement

In [None]:
evaluation_df_post_tuning_stats.rougeL_precision.mean()

In [None]:
improvement = round(
    (
        (
            evaluation_df_post_tuning_stats.rougeL_precision.mean()
            - evaluation_df_stats.rougeL_precision.mean()
        )
        / evaluation_df_stats.rougeL_precision.mean()
    )
    * 100,
    2,
)
print(
    f"Model tuning has improved the rougeL_precision by {improvement}% (result might differ based on each tuning iteration)"
)
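
The same relative-change formula applies to every metric column, not only `rougeL_precision`. The following is a minimal sketch using plain dictionaries of mean scores; `relative_improvement` is a hypothetical helper, and the numbers are made up for illustration rather than taken from a run of this notebook.

```python
def relative_improvement(before, after):
    """Percentage change from before to after, rounded to 2 decimals."""
    if before == 0:
        raise ValueError("baseline score is zero; relative change is undefined")
    return round((after - before) / before * 100, 2)


# Illustrative values only, not measured results.
base_means = {"rougeL_precision": 0.20, "rougeL_recall": 0.40}
tuned_means = {"rougeL_precision": 0.30, "rougeL_recall": 0.50}

for metric, base_value in base_means.items():
    print(metric, relative_improvement(base_value, tuned_means[metric]))
```

With the dataframes in this notebook, the two dictionaries could be replaced by `evaluation_df_stats.mean()` and `evaluation_df_post_tuning_stats.mean()`, which return one mean per metric column.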

In [None]:
# Save predictions (to_csv returns None when a path is given, so nothing is assigned)
val_df.to_csv("validation_pred.csv", index=False)

## Conclusion

Performance could be further improved:
- By adding more training samples. In general, improve your training data quality and/or quantity towards getting a more diverse and comprehensive dataset for your task
- By tuning the hyperparameters, such as epochs, learning rate multiplier or adapter size
  - To find the optimal number of epochs for your dataset, we recommend experimenting with different values. While increasing epochs can lead to better performance, it's important to be mindful of overfitting, especially with smaller datasets. If you see signs of overfitting, reducing the number of epochs can help mitigate the issue
- You may try different prompt structures/formats and opt for the one with better performance

## Cleaning up

To clean up all Google Cloud resources used in this project, you can [delete the Google Cloud
project](https://cloud.google.com/resource-manager/docs/creating-managing-projects#shutting_down_projects) you used for the tutorial.


Otherwise, you can delete the individual resources you created in this tutorial.

Refer to these [instructions](https://cloud.google.com/vertex-ai/docs/tutorials/image-classification-custom/cleanup#delete_resources) to delete the resources from the console.

In [None]:
# Delete Experiment.
delete_experiments = True
if delete_experiments:
    experiments_list = aiplatform.Experiment.list()
    for experiment in experiments_list:
        if experiment.resource_name == tuning_experiment_name:
            print(experiment.resource_name)
            experiment.delete()
            break

print("***" * 10)

# Delete Endpoint.
delete_endpoint = True
# If force is set to True, all deployed models on this
# Endpoint will be first undeployed.
if delete_endpoint:
    for endpoint in aiplatform.Endpoint.list():
        if endpoint.resource_name == tuned_model:
            print(endpoint.resource_name)
            endpoint.delete(force=True)
            break

print("***" * 10)

# Delete Cloud Storage Bucket.
delete_bucket = True
if delete_bucket:
    ! gsutil -m rm -r $BUCKET_URI