In [None]:
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Vertex AI Model Garden TFVision With Image Classification

<table><tbody><tr>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fvertex-ai-samples%2Fmain%2Fnotebooks%2Fcommunity%2Fmodel_garden%2Fmodel_garden_tfvision_image_classification.ipynb">
      <img alt="Google Cloud Colab Enterprise logo" src="https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN" width="32px"><br> Run in Colab Enterprise
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_tfvision_image_classification.ipynb">
      <img alt="GitHub logo" src="https://cloud.google.com/ml-engine/images/github-logo-32px.png" width="32px"><br> View on GitHub
    </a>
  </td>
</tr></tbody></table>

## Overview

This notebook demonstrates how to use [TFVision](https://github.com/tensorflow/models/blob/master/official/vision/MODEL_GARDEN.md) in Vertex AI Model Garden.

### Objective

* Train new models
  * Convert input data to training formats
  * Create [hyperparameter tuning jobs](https://cloud.google.com/vertex-ai/docs/training/hyperparameter-tuning-overview) to train new models
  * Find and export best models

* Test trained models
  * Upload models to model registry
  * Deploy uploaded models
  * Run predictions

* Cleanup resources

### Costs

This tutorial uses billable components of Google Cloud:

* Vertex AI
* Cloud Storage

Learn about [Vertex AI pricing](https://cloud.google.com/vertex-ai/pricing), [Cloud Storage pricing](https://cloud.google.com/storage/pricing), and use the [Pricing Calculator](https://cloud.google.com/products/calculator/) to generate a cost estimate based on your projected usage.


## Before you begin

In [None]:
# @title Setup Google Cloud project

# @markdown 1. [Make sure that billing is enabled for your project](https://cloud.google.com/billing/docs/how-to/modify-project).

# @markdown 2. **[Optional]** [Create a Cloud Storage bucket](https://cloud.google.com/storage/docs/creating-buckets) for storing experiment outputs. Set the BUCKET_URI for the experiment environment. The specified Cloud Storage bucket (`BUCKET_URI`) should be located in the same region as where the notebook was launched. Note that a multi-region bucket (e.g. "us") is not considered a match for a single region covered by the multi-region range (e.g. "us-central1"). If not set, a unique GCS bucket will be created instead.

# Leaving the default "gs://" (or an empty value) makes the code below create a new bucket.
BUCKET_URI = "gs://"  # @param {type:"string"}

# @markdown 3. **[Optional]** Set region. If not set, the region will be set automatically according to Colab Enterprise environment.

# An empty value falls back to the GOOGLE_CLOUD_REGION environment variable below.
REGION = ""  # @param {type:"string"}

# @markdown 4. If you want to run predictions with A100 80GB or H100 GPUs, we recommend using the regions listed below. **NOTE:** Make sure you have associated quota in selected regions. Click the links to see your current quota for each GPU type: [Nvidia A100 80GB](https://console.cloud.google.com/iam-admin/quotas?metric=aiplatform.googleapis.com%2Fcustom_model_serving_nvidia_a100_80gb_gpus), [Nvidia H100 80GB](https://console.cloud.google.com/iam-admin/quotas?metric=aiplatform.googleapis.com%2Fcustom_model_serving_nvidia_h100_gpus). You can request quota by following the instructions at ["Request a higher quota"](https://cloud.google.com/docs/quota/view-manage#requesting_higher_quota).

# @markdown > | Machine Type | Accelerator Type | Recommended Regions |
# @markdown | ----------- | ----------- | ----------- |
# @markdown | a2-ultragpu-1g | 1 NVIDIA_A100_80GB | us-central1, us-east4, europe-west4, asia-southeast1 |
# @markdown | a3-highgpu-2g | 2 NVIDIA_H100_80GB | us-west1, asia-southeast1, europe-west4 |
# @markdown | a3-highgpu-4g | 4 NVIDIA_H100_80GB | us-west1, asia-southeast1, europe-west4 |
# @markdown | a3-highgpu-8g | 8 NVIDIA_H100_80GB | us-central1, europe-west4, us-west1, asia-southeast1 |

# Clone the samples repository; the `common_util` helpers used throughout this
# notebook are imported from this checkout.
! git clone https://github.com/GoogleCloudPlatform/vertex-ai-samples.git

import base64
import datetime
import importlib
import io
import json
import os
import subprocess
import uuid
from typing import Any, Dict, List, Union

import yaml
from google.cloud import aiplatform
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value

common_util = importlib.import_module(
    "vertex-ai-samples.community-content.vertex_model_garden.model_oss.notebook_util.common_util"
)

models, endpoints = {}, {}


# Get the default cloud project id.
PROJECT_ID = os.environ["GOOGLE_CLOUD_PROJECT"]

# Get the default region for launching jobs.
if not REGION:
    if not os.environ.get("GOOGLE_CLOUD_REGION"):
        raise ValueError(
            "REGION must be set. See"
            " https://cloud.google.com/vertex-ai/docs/general/locations for"
            " available cloud locations."
        )
    REGION = os.environ["GOOGLE_CLOUD_REGION"]

# Enable the Vertex AI API and Compute Engine API, if not already.
print("Enabling Vertex AI API and Compute Engine API.")
! gcloud services enable aiplatform.googleapis.com compute.googleapis.com

# Cloud Storage bucket for storing the experiment artifacts.
# A unique GCS bucket will be created for the purpose of this notebook. If you
# prefer using your own GCS bucket, change the value yourself below.
now = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
BUCKET_NAME = "/".join(BUCKET_URI.split("/")[:3])

if BUCKET_URI is None or BUCKET_URI.strip() == "" or BUCKET_URI == "gs://":
    BUCKET_URI = f"gs://{PROJECT_ID}-tmp-{now}-{str(uuid.uuid4())[:4]}"
    BUCKET_NAME = "/".join(BUCKET_URI.split("/")[:3])
    ! gsutil mb -l {REGION} {BUCKET_URI}
else:
    assert BUCKET_URI.startswith("gs://"), "BUCKET_URI must start with `gs://`."
    shell_output = ! gsutil ls -Lb {BUCKET_NAME} | grep "Location constraint:" | sed "s/Location constraint://"
    bucket_region = shell_output[0].strip().lower()
    if bucket_region != REGION:
        raise ValueError(
            "Bucket region %s is different from notebook region %s"
            % (bucket_region, REGION)
        )
print(f"Using this GCS Bucket: {BUCKET_URI}")

STAGING_BUCKET = os.path.join(BUCKET_URI, "temporal")
MODEL_BUCKET = os.path.join(BUCKET_URI, "tfvision_image_classification")


# Initialize Vertex AI API.
print("Initializing Vertex AI API.")
aiplatform.init(project=PROJECT_ID, location=REGION, staging_bucket=STAGING_BUCKET)

# Gets the default SERVICE_ACCOUNT.
shell_output = ! gcloud projects describe $PROJECT_ID
project_number = shell_output[-1].split(":")[1].strip().replace("'", "")
SERVICE_ACCOUNT = f"{project_number}-compute@developer.gserviceaccount.com"
print("Using this default Service Account:", SERVICE_ACCOUNT)


# Provision permissions to the SERVICE_ACCOUNT with the GCS bucket
! gsutil iam ch serviceAccount:{SERVICE_ACCOUNT}:roles/storage.admin $BUCKET_NAME

! gcloud config set project $PROJECT_ID
! gcloud projects add-iam-policy-binding --no-user-output-enabled {PROJECT_ID} --member=serviceAccount:{SERVICE_ACCOUNT} --role="roles/storage.admin"
! gcloud projects add-iam-policy-binding --no-user-output-enabled {PROJECT_ID} --member=serviceAccount:{SERVICE_ACCOUNT} --role="roles/aiplatform.user"

CONFIG_DIR = os.path.join(BUCKET_URI, "config")
CHECKPOINT_BUCKET = os.path.join(BUCKET_URI, "ckpt")

# Only regions prefixed by "us", "asia", or "europe" are supported.
REGION_PREFIX = REGION.split("-")[0]
assert REGION_PREFIX in (
    "us",
    "europe",
    "asia",
), f'{REGION} is not supported. It must be prefixed by "us", "asia", or "europe".'


def upload_config_to_gcs(url: str) -> None:
    """Downloads a config file and copies it into `CONFIG_DIR` on Cloud Storage.

    The file is first saved in the current working directory under its base
    name and then copied to the bucket with `gsutil`.

    Args:
      url: URL of the config file to download.
    """
    filename = os.path.basename(url)
    destination = os.path.join(CONFIG_DIR, filename)
    print("Copy", url, "to", destination)
    # The `$name` placeholders in the shell lines below refer to the local
    # variables above, so these names must stay in sync with the commands.
    ! wget "$url" -O "$filename"
    ! gsutil cp "$filename" "$destination"


# Stage the experiment config files under CONFIG_DIR, in the order listed.
for config_url in (
    "https://raw.githubusercontent.com/tensorflow/models/master/official/vision/configs/experiments/image_classification/imagenet_resnet50_gpu.yaml",
    "https://raw.githubusercontent.com/tensorflow/models/master/official/vision/configs/experiments/image_classification/imagenet_resnetrs50_i160_gpu.yaml",
    "https://raw.githubusercontent.com/tensorflow/models/master/official/projects/maxvit/configs/experiments/maxvit_base_imagenet_gpu.yaml",
):
    upload_config_to_gcs(config_url)

# Objective identifier used in job names and passed to the containers.
OBJECTIVE = "icn"

# Key read from each trial's evaluation results when selecting the best trial.
EVALUATION_METRIC = "accuracy"

## Training

This section trains a model with the following steps:
1. Prepare data by converting the input data into training format.
2. Run hyperparameter tuning jobs to train new models.
3. Find and export best models.

In [None]:
# @title Prepare input data for training

# @markdown This section converts input data to training format, and splits to train/test/validation dataset with given split ratios and number of shards.

# @markdown Prepare data in the format as described [here](https://cloud.google.com/vertex-ai/docs/image-data/classification/prepare-data), and then convert them to the training formats as below:
# @markdown * `input_file_path`: The input file path for preparing data. Sample input file: `gs://cloud-samples-data/ai-platform/flowers/flowers.csv`
# @markdown * `input_file_type`: The input file type, can be "csv" or "jsonl".
# @markdown * `num_classes`: Number of classes in the dataset.
# @markdown * `split_ratio`: The proportion of data to split into train/validation/test, e.g. "0.8,0.1,0.1".
# @markdown * `num_shard`: The number of shards for train/validation/test, e.g. "10,10,10".

# This job converts the input data to the training format, with the given split
# ratios and number of shards for train/test/validation.

# NOTE: `hpt` is not used in this cell; it is consumed by the hyperparameter
# tuning cell further down.
from google.cloud.aiplatform import hyperparameter_tuning as hpt

# Data converter constants.
DATA_CONVERTER_JOB_PREFIX = "data_converter"
DATA_CONVERTER_CONTAINER = f"{REGION_PREFIX}-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/data-converter"
DATA_CONVERTER_MACHINE_TYPE = "n1-highmem-8"

data_converter_job_name = common_util.get_job_name_with_datetime(
    DATA_CONVERTER_JOB_PREFIX + "_" + OBJECTIVE
)

input_file_path = "gs://cloud-samples-data/ai-platform/flowers/flowers.csv"  # @param {type:"string"} {isTemplate:true}
input_file_type = "csv"  # @param ["csv", "jsonl"]
# `num_classes` is not passed to the converter; it is forwarded to the training
# container by the hyperparameter tuning cell.
num_classes = 5  # @param {type:"integer"}
split_ratio = "0.8,0.1,0.1"  # @param {type:"string"}
num_shard = "10,10,10"  # @param {type:"string"}
# The converted data is written to a folder named after the job inside the bucket.
data_converter_output_dir = os.path.join(BUCKET_URI, data_converter_job_name)

# Single CPU-only worker running the data converter container.
worker_pool_specs = [
    {
        "machine_spec": {
            "machine_type": DATA_CONVERTER_MACHINE_TYPE,
        },
        "replica_count": 1,
        "container_spec": {
            "image_uri": DATA_CONVERTER_CONTAINER,
            "command": [],
            "args": [
                "--input_file_path=%s" % input_file_path,
                "--input_file_type=%s" % input_file_type,
                "--objective=%s" % OBJECTIVE,
                "--num_shard=%s" % num_shard,
                "--split_ratio=%s" % split_ratio,
                "--output_dir=%s" % data_converter_output_dir,
            ],
        },
    }
]

data_converter_custom_job = aiplatform.CustomJob(
    display_name=data_converter_job_name,
    project=PROJECT_ID,
    worker_pool_specs=worker_pool_specs,
    staging_bucket=STAGING_BUCKET,
)

data_converter_custom_job.run()

# File patterns / paths inside the converter output directory that the training
# and prediction cells read.
input_train_data_path = os.path.join(data_converter_output_dir, "train.tfrecord*")
input_validation_data_path = os.path.join(data_converter_output_dir, "val.tfrecord*")
label_map_path = os.path.join(data_converter_output_dir, "label_map.yaml")
print("input_train_data_path for training: ", input_train_data_path)
print("input_validation_data_path for training: ", input_validation_data_path)
print("label_map_path for prediction: ", label_map_path)

In [None]:
# @title Create and run Vertex AI custom job with hyperparameter tuning

# @markdown This section uses the Vertex AI SDK to create and run the hyperparameter tuning job with Vertex AI Model Garden Training Dockers.

# @markdown Select one of the following experiments:
# @markdown * [tfhub/EfficientNetV2](https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imageclassification-efficientnet): `Efficientnetv2-m`
# @markdown * [tfvision/ViT](https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imageclassification-vit): `ViT-ti16`, `ViT-s16`, `ViT-b16`, `ViT-l16`
# @markdown * [Proprietary/MaxViT](https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imageclassification-proprietary-maxvit): `MaxViT`

# Input train and validation datasets can be found from the section above
# `Prepare input data for training`.
# Set prepared datasets if they already exist.
# input_train_data_path = ''
# input_validation_data_path = ''

# Training constants.
TRAINING_JOB_PREFIX = "train"
TRAIN_CONTAINER_URI = f"{REGION_PREFIX}-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/tfvision-oss"
TRAIN_MACHINE_TYPE = "g2-standard-4"
TRAIN_ACCELERATOR_TYPE = "NVIDIA_L4"
TRAIN_NUM_GPU = 1

experiment = "Efficientnetv2-m"  # @param ["Efficientnetv2-m","ViT-ti16","ViT-s16","ViT-b16","ViT-l16", "MaxViT"]

train_job_name = common_util.get_job_name_with_datetime(
    TRAINING_JOB_PREFIX + "_" + OBJECTIVE
)
# Training outputs are written to a folder named after the job inside the bucket.
model_dir = os.path.join(BUCKET_URI, train_job_name)

# The arguments here are mainly for test purposes. Kindly update them
# to get better performances.
# `num_classes` comes from the data preparation cell; define it manually if
# that cell was skipped.
common_args = {
    "input_train_data_path": input_train_data_path,
    "input_validation_data_path": input_validation_data_path,
    "objective": OBJECTIVE,
    "model_dir": model_dir,
    "num_classes": num_classes,
    "global_batch_size": 4,
    "prefetch_buffer_size": 32,
    "train_steps": 2000,
    "input_size": "224,224",
}

# Container arguments per experiment: the shared defaults from `common_args`
# followed by the experiment-specific settings (later keys override earlier ones).
experiment_container_args_dict = {
    "Efficientnetv2-m": {
        **common_args,
        "experiment": "hub_model",
    },
    "ViT-ti16": {
        **common_args,
        "experiment": "deit_imagenet_pretrain",
        "model_name": "vit-ti16",
        "init_checkpoint": "https://storage.googleapis.com/tf_model_garden/vision/vit/vit-deit-imagenet-ti16.tar.gz",
        "input_size": "224,224",
    },
    "ViT-s16": {
        **common_args,
        "experiment": "deit_imagenet_pretrain",
        "model_name": "vit-s16",
        "init_checkpoint": "https://storage.googleapis.com/tf_model_garden/vision/vit/vit-deit-imagenet-s16.tar.gz",
        "input_size": "224,224",
    },
    "ViT-b16": {
        **common_args,
        "experiment": "deit_imagenet_pretrain",
        "model_name": "vit-b16",
        "init_checkpoint": "https://storage.googleapis.com/tf_model_garden/vision/vit/vit-deit-imagenet-b16.tar.gz",
        "input_size": "224,224",
    },
    "ViT-l16": {
        **common_args,
        "experiment": "deit_imagenet_pretrain",
        "model_name": "vit-l16",
        "init_checkpoint": "https://storage.googleapis.com/tf_model_garden/vision/vit/vit-deit-imagenet-l16.tar.gz",
        "input_size": "224,224",
    },
    "MaxViT": {
        **common_args,
        "experiment": "maxvit_imagenet",
        "config_file": os.path.join(CONFIG_DIR, "maxvit_base_imagenet_gpu.yaml"),
    },
}
# Arguments of the experiment selected in the form above.
experiment_container_args = experiment_container_args_dict[experiment]


def upload_checkpoint_to_gcs(checkpoint_url):
    filename = os.path.basename(checkpoint_url)
    checkpoint_name = filename.replace(".tar.gz", "")
    print("Download checkpoint from", checkpoint_url, "and store to", CHECKPOINT_BUCKET)
    ! wget $checkpoint_url -O $filename
    ! mkdir -p $checkpoint_name
    ! tar -xvzf $filename -C $checkpoint_name

    # Search for relative path to the checkpoint.
    checkpoint_path = None
    for root, dirs, files in os.walk(checkpoint_name):
        for file in files:
            if file.endswith(".index"):
                checkpoint_path = os.path.join(root, os.path.splitext(file)[0])
                checkpoint_path = os.path.relpath(checkpoint_path, checkpoint_name)
                break

    ! gsutil cp -r $checkpoint_name $CHECKPOINT_BUCKET/
    checkpoint_uri = os.path.join(CHECKPOINT_BUCKET, checkpoint_name, checkpoint_path)
    print("Checkpoint uploaded to", checkpoint_uri)
    return checkpoint_uri


# When the selected experiment defines an initial checkpoint, stage it in the
# bucket and point the container at the uploaded copy.
init_checkpoint = experiment_container_args.get("init_checkpoint")
if init_checkpoint:
    experiment_container_args["init_checkpoint"] = upload_checkpoint_to_gcs(
        init_checkpoint
    )

# MaxViT is trained with a dedicated container image.
if experiment == "MaxViT":
    TRAIN_CONTAINER_URI = f"{REGION_PREFIX}-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/tfvision-oss-v2"

# One GPU worker; the container receives the fixed flags first and then one
# `--key=value` flag per experiment argument.
worker_pool_specs = [
    dict(
        machine_spec=dict(
            machine_type=TRAIN_MACHINE_TYPE,
            accelerator_type=TRAIN_ACCELERATOR_TYPE,
            # Each training job uses TRAIN_NUM_GPU GPUs.
            accelerator_count=TRAIN_NUM_GPU,
        ),
        replica_count=1,
        container_spec=dict(
            image_uri=TRAIN_CONTAINER_URI,
            args=[
                "--mode=train_and_eval",
                "--params_override=runtime.num_gpus=%d" % TRAIN_NUM_GPU,
                *(
                    f"--{arg_name}={arg_value}"
                    for arg_name, arg_value in experiment_container_args.items()
                ),
            ],
        ),
    )
]

# `hpt` is imported here as well so that this cell does not depend on the data
# preparation cell having been executed (that cell can be skipped when prepared
# datasets already exist).
from google.cloud.aiplatform import hyperparameter_tuning as hpt

# Metric reported by the training container that the tuning job maximizes.
metric_spec = {"model_performance": "maximize"}


LEARNING_RATES = [5e-4, 1e-3]
# Models will be trained with each learning rate separately and max trial count is the number of learning rates.
MAX_TRIAL_COUNT = len(LEARNING_RATES)
parameter_spec = {
    "learning_rate": hpt.DiscreteParameterSpec(values=LEARNING_RATES, scale="linear"),
}

print(worker_pool_specs, metric_spec, parameter_spec)

# Verify the training accelerator quota before submitting the job.
common_util.check_quota(
    project_id=PROJECT_ID,
    region=REGION,
    accelerator_type=TRAIN_ACCELERATOR_TYPE,
    accelerator_count=1,
    is_for_training=True,
)


# Labels attached to the finetuning job.
versioned_model_id = experiment.lower().replace("_", "-")
labels = {
    "mg-source": "notebook",
    "mg-notebook-name": "model_garden_tfvision_image_classification.ipynb".split(".")[
        0
    ],
    "mg-tune": "publishers-google-models-tfvision",
}
labels["versioned-mg-tune"] = "-".join((labels["mg-tune"], versioned_model_id))

# The custom job describes a single trial; the tuning job below runs it once
# per learning rate.
train_custom_job = aiplatform.CustomJob(
    display_name=train_job_name,
    project=PROJECT_ID,
    worker_pool_specs=worker_pool_specs,
    staging_bucket=STAGING_BUCKET,
    labels=labels,
)

# All trials are allowed to run in parallel.
train_hpt_job = aiplatform.HyperparameterTuningJob(
    display_name=train_job_name,
    custom_job=train_custom_job,
    metric_spec=metric_spec,
    parameter_spec=parameter_spec,
    max_trial_count=MAX_TRIAL_COUNT,
    parallel_trial_count=MAX_TRIAL_COUNT,
    project=PROJECT_ID,
    search_algorithm=None,
)

train_hpt_job.run()

print("experiment is: ", experiment)
print("model_dir is: ", model_dir)

In [None]:
# @title Export best models as TF Saved Model format

# @markdown This section exports the best model.

# Export models from TF checkpoints to TF saved model format.
# model_dir is from the section above.

# Export constants.
# Prefix of the export job's display name.
EXPORT_JOB_PREFIX = "export"
# Export container image, taken from the registry matching the region prefix.
EXPORT_CONTAINER_URI = f"{REGION_PREFIX}-docker.pkg.dev/vertex-ai-restricted/vertex-vision-model-garden-dockers/tfvision-model-export"
# Machine type of the (CPU-only) export worker.
EXPORT_MACHINE_TYPE = "n1-highmem-8"


def get_best_trial(model_dir, max_trial_count, evaluation_metric):
    """Finds the trial with the highest value of `evaluation_metric`.

    For every trial `trial_<i>` (1-based) under `model_dir`, the file
    `best_ckpt/info.json` is copied from Cloud Storage and `evaluation_metric`
    is read from it. Trials whose file cannot be copied are reported and
    skipped.

    Args:
      model_dir: Cloud Storage directory that contains the `trial_<i>` folders.
      max_trial_count: Number of trials to inspect.
      evaluation_metric: Key of the metric in `info.json` used for the comparison.

    Returns:
      A tuple `(best_trial_dir, best_trial_evaluation_results)` with the
      directory of the best trial and the content of its `info.json`. When
      `max_trial_count` is 0 the tuple is `("", {})`.

    Raises:
      FileNotFoundError: If trials were requested but no `info.json` could be
        copied for any of them.
    """
    import tempfile  # Local import: only needed by this helper.

    best_trial_dir = ""
    best_trial_evaluation_results = {}
    best_trial_evaluation_filepath = ""
    # -inf instead of -1 so that any numeric metric value can win.
    best_performance = float("-inf")
    found_results = False

    # Every trial gets its own local file in a scratch directory. Reading a
    # shared "info.json" from the working directory would silently reuse the
    # previous trial's (or a previous run's) results whenever a copy fails.
    with tempfile.TemporaryDirectory() as scratch_dir:
        for current_trial in range(1, max_trial_count + 1):
            current_trial_dir = os.path.join(model_dir, "trial_" + str(current_trial))
            current_trial_evaluation_filepath = os.path.join(
                current_trial_dir, "best_ckpt", "info.json"
            )
            local_filepath = os.path.join(
                scratch_dir, "trial_%d_info.json" % current_trial
            )
            try:
                subprocess.check_output(
                    ["gsutil", "cp", current_trial_evaluation_filepath, local_filepath],
                    stderr=subprocess.STDOUT,
                )
            except subprocess.CalledProcessError as error:
                details = error.output
                if isinstance(details, bytes):
                    details = details.decode("utf-8", errors="replace")
                print(
                    "Skipping trial %d, could not copy %s: %s"
                    % (current_trial, current_trial_evaluation_filepath, details)
                )
                continue

            with open(local_filepath, "r") as f:
                eval_metric_results = json.load(f)
            found_results = True
            current_performance = eval_metric_results[evaluation_metric]
            if current_performance > best_performance:
                best_performance = current_performance
                best_trial_dir = current_trial_dir
                best_trial_evaluation_results = eval_metric_results
                best_trial_evaluation_filepath = current_trial_evaluation_filepath

    if max_trial_count > 0 and not found_results:
        raise FileNotFoundError(
            "No evaluation results (best_ckpt/info.json) found for any trial in %s."
            % model_dir
        )

    # Report the evaluation file of the best trial (not of the last trial visited).
    print("best_trial_evaluation_filepath: ", best_trial_evaluation_filepath)
    return best_trial_dir, best_trial_evaluation_results


# Select the best trial of the tuning job; its checkpoint and parameters are
# the inputs of the export job below.
best_trial_dir, best_trial_evaluation_results = get_best_trial(
    model_dir=model_dir,
    max_trial_count=MAX_TRIAL_COUNT,
    evaluation_metric=EVALUATION_METRIC,
)
print("best_trial_dir: ", best_trial_dir)
print("best_trial_evaluation_results: ", best_trial_evaluation_results)

# Single CPU-only worker running the export container.
worker_pool_specs = [
    dict(
        machine_spec=dict(machine_type=EXPORT_MACHINE_TYPE),
        replica_count=1,
        container_spec=dict(
            image_uri=EXPORT_CONTAINER_URI,
            command=[],
            args=[
                f"--objective={OBJECTIVE}",
                f"--input_image_size={experiment_container_args['input_size']}",
                f"--experiment={experiment_container_args['experiment']}",
                f"--config_file={best_trial_dir}/params.yaml",
                f"--checkpoint_path={best_trial_dir}/best_ckpt",
                f"--export_dir={model_dir}/best_model",
            ],
        ),
    )
]

model_export_name = common_util.get_job_name_with_datetime(
    f"{EXPORT_JOB_PREFIX}_{OBJECTIVE}"
)
model_export_custom_job = aiplatform.CustomJob(
    display_name=model_export_name,
    project=PROJECT_ID,
    worker_pool_specs=worker_pool_specs,
    staging_bucket=STAGING_BUCKET,
)

model_export_custom_job.run()

print("best model is saved to: ", os.path.join(model_dir, "best_model"))

## Deployment

In [None]:
# @title Upload and deploy models

# @markdown This section uploads models to the model registry and deploys them for online prediction. This example uses the best model exported in the "Training" section.

# Serving configuration.
PREDICTION_CONTAINER_URI = f"{REGION_PREFIX}-docker.pkg.dev/vertex-ai-restricted/prediction/tf_opt-gpu.2-11:latest"
SERVING_CONTAINER_ARGS = ["--allow_precompilation", "--allow_compression"]
PREDICTION_ACCELERATOR_TYPE = "NVIDIA_L4"
PREDICTION_MACHINE_TYPE = "g2-standard-12"
UPLOAD_JOB_PREFIX = "upload"
DEPLOY_JOB_PREFIX = "deploy"

# The export job wrote the saved model below `<model_dir>/best_model`.
trained_model_dir = os.path.join(model_dir, "best_model/saved_model")
upload_job_name = common_util.get_job_name_with_datetime(
    f"{UPLOAD_JOB_PREFIX}_{OBJECTIVE}"
)

# NOTE(review): MODEL_ID is the same for every experiment; confirm that this is
# intended for the ViT and MaxViT experiments as well.
serving_env = dict(
    MODEL_ID="tensorflow-hub-efficientnetv2",
    DEPLOY_SOURCE="notebook",
)

# Map the selected experiment to its Model Garden publisher model.
if experiment == "Efficientnetv2-m":
    publisher_model_id = "imageclassification-efficientnet"
elif experiment in ("ViT-ti16", "ViT-s16", "ViT-b16", "ViT-l16"):
    publisher_model_id = "imageclassification-vit"
elif experiment == "MaxViT":
    publisher_model_id = "imageclassification-maxvit"
else:
    raise ValueError(f"Unknown experiment: {experiment}")

# Register the exported model; it is tracked in `models` for the clean-up cell.
models["model_icn"] = aiplatform.Model.upload(
    display_name=upload_job_name,
    artifact_uri=trained_model_dir,
    serving_container_image_uri=PREDICTION_CONTAINER_URI,
    serving_container_args=SERVING_CONTAINER_ARGS,
    serving_container_environment_variables=serving_env,
    model_garden_source_model_name=f"publishers/google/models/{publisher_model_id}",
)

models["model_icn"].wait()

print("The uploaded model name is: ", upload_job_name)

deploy_model_name = common_util.get_job_name_with_datetime(
    f"{DEPLOY_JOB_PREFIX}_{OBJECTIVE}"
)
print("The deployed job name is: ", deploy_model_name)

# Verify the serving accelerator quota before deploying.
common_util.check_quota(
    project_id=PROJECT_ID,
    region=REGION,
    accelerator_type=PREDICTION_ACCELERATOR_TYPE,
    accelerator_count=1,
    is_for_training=False,
)

# Deploy on a single replica; the endpoint is tracked in `endpoints` for the
# clean-up cell.
endpoints["endpoint_icn"] = models["model_icn"].deploy(
    deployed_model_display_name=deploy_model_name,
    machine_type=PREDICTION_MACHINE_TYPE,
    traffic_split={"0": 100},
    accelerator_type=PREDICTION_ACCELERATOR_TYPE,
    accelerator_count=1,
    min_replica_count=1,
    max_replica_count=1,
    system_labels=dict(
        NOTEBOOK_NAME="model_garden_tfvision_image_classification.ipynb"
    ),
)

endpoint_id = endpoints["endpoint_icn"].name
print("endpoint id is: ", endpoint_id)

## Predict

In [None]:
# @title Run predictions

# @markdown Once deployment succeeds, you can send an image to the endpoint for online prediction.

# @markdown `test_filepath`: Cloud Storage URI of the test image file. The URI should start with "gs://".

# endpoint_id was generated in the section above (`Upload and deploy models`).
# It is read again from `endpoints` so that the value matches the endpoint
# registered by that section.
endpoint_id = endpoints["endpoint_icn"].name

test_filepath = "gs://cloud-samples-data/ai-platform/flowers/roses/9423755543_edb35141a3_n.jpg"  # @param {type:"string"} {isTemplate:true}


def get_label_map(label_map_yaml_filepath: str) -> Dict[int, str]:
    """Loads the class id to label mapping from a label map YAML file.

    The file is copied with `gsutil` into the current working directory (under
    its base name) and the value stored under its `label_map` key is returned.

    Args:
      label_map_yaml_filepath: Path of the label map YAML file, in any form
        accepted by `gsutil cp`.

    Returns:
      A dictionary of class id to label mapping.
    """
    local_copy = os.path.basename(label_map_yaml_filepath)
    copy_command = ["gsutil", "cp", label_map_yaml_filepath, local_copy]
    subprocess.check_output(copy_command, stderr=subprocess.STDOUT)

    with open(local_copy, "rb") as label_map_file:
        parsed_yaml = yaml.safe_load(label_map_file.read())
    return parsed_yaml["label_map"]


def get_prediction_instances(test_filepath: str, new_width: int = -1) -> Any:
    """Generate instance from image path to pass to Vertex AI Endpoint for prediction.

    Args:
      test_filepath: A string of test image path.
      new_width: An integer of new image width. When positive, the image is
        resized to this width (the height is scaled by the same factor) and
        re-encoded as JPEG. When zero or negative, the file is copied with
        `gsutil` into the current working directory and its bytes are sent
        unchanged.

    Returns:
      A list with a single instance holding the base64-encoded image under
      `encoded_image.b64`.
    """
    if new_width <= 0:
        test_file = os.path.basename(test_filepath)
        subprocess.check_output(
            ["gsutil", "cp", test_filepath, test_file], stderr=subprocess.STDOUT
        )
        with open(test_file, "rb") as input_file:
            encoded_string = base64.b64encode(input_file.read()).decode("utf-8")
    else:
        img = common_util.load_img(test_filepath)
        width, height = img.size
        print("original input image size: ", width, " , ", height)
        new_height = int(height * new_width / width)
        new_img = img.resize((new_width, new_height))
        print("resized input image size: ", new_width, " , ", new_height)
        buffered = io.BytesIO()
        new_img.save(buffered, format="JPEG")
        encoded_string = base64.b64encode(buffered.getvalue()).decode("utf-8")

    # Build the request for both branches. Doing this inside the `else` branch
    # only would make the function return None for the unresized image.
    instances = [
        {
            "encoded_image": {"b64": encoded_string},
        }
    ]
    return instances


# Build the request from the test image. With `new_width=1000` the image is
# always resized to a width of 1000 pixels (the height is scaled by the same
# factor), whatever its original size.
instances = get_prediction_instances(test_filepath, new_width=1000)

# The label map file was generated from the section above (`Prepare input data for training`).
label_map = get_label_map(label_map_path)


def predict_custom_trained_model(
    project: str,
    endpoint_id: str,
    instances: Union[Dict, List[Dict]],
    location: str = "us-central1",
):
    """Sends an online prediction request to a Vertex AI endpoint.

    Args:
      project: Project that owns the endpoint.
      endpoint_id: Id of the endpoint to query.
      instances: The instance(s) to send to the endpoint.
      location: Region of the endpoint; also selects the regional API host.

    Returns:
      A tuple of the predictions in the response and the id of the deployed
      model that served them.
    """
    # The service is reached through a regional API endpoint. A new client is
    # created on every call.
    prediction_client = aiplatform.gapic.PredictionServiceClient(
        client_options={"api_endpoint": f"{location}-aiplatform.googleapis.com"}
    )
    # No prediction parameters are sent: an empty dict converted to a protobuf Value.
    empty_parameters = json_format.ParseDict({}, Value())
    endpoint_path = prediction_client.endpoint_path(
        project=project, location=location, endpoint=endpoint_id
    )
    prediction_response = prediction_client.predict(
        endpoint=endpoint_path, instances=instances, parameters=empty_parameters
    )
    return prediction_response.predictions, prediction_response.deployed_model_id


# Query the deployed endpoint with the prepared instance.
predictions, _ = predict_custom_trained_model(
    project=PROJECT_ID,
    endpoint_id=endpoint_id,
    instances=instances,
    location=REGION,
)

# Pick the highest probability and its position (the first one on ties).
probs = dict(predictions[0])["probs"]
max_index, max_prob = max(enumerate(probs), key=lambda indexed_prob: indexed_prob[1])
print("The test image: ", test_filepath)
print("max_prob: ", max_prob, ", for label: ", label_map[max_index])

# Show the image that was classified.
img = common_util.load_img(test_filepath)
common_util.display_image(img)

## Clean up resources

In [None]:
# @title Clean up training jobs, models, endpoints and buckets

try:
    # Delete custom and hpt jobs.
    if data_converter_custom_job.list(
        filter=f'display_name="{data_converter_job_name}"'
    ):
        data_converter_custom_job.delete()
    if train_hpt_job.list(filter=f'display_name="{train_job_name}"'):
        train_hpt_job.delete()
    if model_export_custom_job.list(filter=f'display_name="{model_export_name}"'):
        model_export_custom_job.delete()
except Exception as e:
    print(e)

# @markdown  Delete the experiment models and endpoints to recycle the resources
# @markdown  and avoid unnecessary continuous charges that may incur.

# Undeploy model and delete endpoint.
for endpoint in endpoints.values():
    endpoint.delete(force=True)

# Delete models.
for model in models.values():
    model.delete()

delete_bucket = False  # @param {type:"boolean"}
if delete_bucket:
    ! gsutil -m rm -r $BUCKET_NAME