In [None]:
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Image Warehouse SDK demo

<table align="left">

  <td>
    <a href="https://colab.research.google.com/github/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/vision/image_warehouse_sdk.ipynb">
      <img src="https://cloud.google.com/ml-engine/images/colab-logo-32px.png" alt="Colab logo"> Run in Colab
    </a>
  </td>
  <td>
    <a href="https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/vision/image_warehouse_sdk.ipynb">
      <img src="https://cloud.google.com/ml-engine/images/github-logo-32px.png" alt="GitHub logo">
      View on GitHub
    </a>
  </td>                                                                                         
  <td>
    <a href="https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/vertex-ai-samples/main/notebooks/community/vision/image_warehouse_sdk.ipynb">
      <img src="https://lh3.googleusercontent.com/UiNooY4LUgW_oTvpsNhPpQzsstV5W8F7rYgxgGBD85cWJoLmrOzhVs_ksK_vgx40SHs7jCqkTkCk=e14-rj-sc0xffffff-h130-w32" alt="Vertex AI logo">
      Open in Vertex AI Workbench
    </a>
  </td>
</table>

**_NOTE_**: This notebook has been tested in the following environment:

* Python version = 3.10

## Overview

Learn how to build an [Image Warehouse](https://cloud.google.com/vision-ai/docs) step by step by using the SDK.

### Objective
The objective is to demonstrate how to use Image Warehouse to ingest image data and perform similarity search given a text query or an image as input. It contains examples that use the critical Warehouse APIs and a pipeline that performs the end-to-end (E2E) data ingestion and search journey. The Colab builds a Warehouse Corpus with thousands of images ingested, analyzed, and indexed, and an Index Endpoint to perform search over the images. The critical user journey (CUJ) is as follows:

* Create Corpus
* Create Data Schema
* Import Assets
* Analyze Corpus
* Create Index
* Create Index Endpoint
* Deploy Index
* Perform Search
* Cleanup


### Dataset
The dataset used in this demo is publicly accessible at [gs://cloud-samples-data/ai-platform/flowers](https://console.cloud.google.com/storage/browser/cloud-samples-data/ai-platform/flowers). It contains 3670 images of five kinds of flowers.

The metadata file is publicly accessible at [gs://cloud-samples-data/vertex-ai-vision/warehouse/demo.jsonl](https://console.cloud.google.com/storage/browser/_details/cloud-samples-data/vertex-ai-vision/warehouse/demo.jsonl). It contains the annotations for each image file. The Colab scans the metadata file and imports the images and annotations into the warehouse.

### Costs

This tutorial uses billable components of Google Cloud:

Vertex AI Vision ([Pricing](https://cloud.google.com/vision-ai/pricing))


## Installation

Install the following packages required to execute this notebook.


In [None]:
!gsutil cp gs://visionai-artifacts/visionai-0.0.6-py3-none-any.whl .
!pip install visionai-0.0.6-py3-none-any.whl --force-reinstall
!pip install ipywidgets requests

### Colab only: Uncomment the following cell to restart the kernel.

In [None]:
# Colab only: uncomment the lines below and run this cell to restart the
# kernel.
# import IPython

# app = IPython.Application.instance()
# app.kernel.do_shutdown(True)

## Before you begin

### Set up your Google Cloud project

**The following steps are required, regardless of your notebook environment.**

1. [Select or create a Google Cloud project](https://console.cloud.google.com/cloud-resource-manager). When you first create an account, you get a $300 free credit towards your compute/storage costs.

2. [Make sure that billing is enabled for your project](https://cloud.google.com/billing/docs/how-to/modify-project).


#### Set your project ID

**If you don't know your project ID**, try the following:
* Run `gcloud config list`.
* Run `gcloud projects list`.
* See the support page: [Locate the project ID](https://support.google.com/googleapi/answer/7014113)

In [None]:
PROJECT_ID = "[your-project-id]"  # @param {type:"string"}
# Set the project id
! gcloud config set project {PROJECT_ID}

### Authenticate your Google Cloud account

Depending on your Jupyter environment, you may have to manually authenticate. Follow the relevant instructions below.

**1. Vertex AI Workbench**
* Do nothing as you are already authenticated.

**2. Local JupyterLab instance, uncomment and run:**

In [None]:
# Local JupyterLab only: uncomment to authenticate through the gcloud CLI.
# ! gcloud auth login

**3. Colab, uncomment and run:**

In [None]:
# Colab only: uncomment to authenticate the current Colab user.
# from google.colab import auth
# auth.authenticate_user()

### Set Up Other Constants

In [None]:
PROJECT_NUMBER_STR = !gcloud projects describe $PROJECT_ID --format="value(projectNumber)"
PROJECT_NUMBER = int(PROJECT_NUMBER_STR[0])

# Only us-central1 is supported.
REGION = "us-central1"

CORPUS_DISPLAY_NAME = "iwh demo corpus"  # @param {type: "string"}
CORPUS_DESCRIPTION = "iwh demo corpus"  # @param {type: "string"}

# External users can only access PROD environment.
ENV = "PROD"

### Enable API

In [None]:
!gcloud services enable "visionai.googleapis.com"

### Import Libraries

In [None]:
import math
import time

import ipywidgets
import requests
from IPython.display import display
from ipywidgets import GridspecLayout
from visionai.python.gapic.visionai import visionai_v1
from visionai.python.net import channel

## Create a Warehouse client

In [None]:
# Resolve the Warehouse service endpoint for the configured environment and
# build a client that talks to it.
warehouse_endpoint = channel.get_warehouse_service_endpoint(channel.Environment[ENV])
warehouse_client = visionai_v1.WarehouseClient(
    client_options=dict(api_endpoint=warehouse_endpoint)
)

## Create a Corpus

In [None]:
# Leave CORPUS_NAME empty to create a new corpus; any other value is used
# as-is as the corpus resource name.
CORPUS_NAME = ""  # @param {type: "string"}

if CORPUS_NAME != "":
    corpus_name = CORPUS_NAME
    print("Corpus: ", corpus_name)
else:
    # An image corpus with embedding search enabled.
    search_capability = visionai_v1.SearchCapability(
        type_=visionai_v1.SearchCapability.Type.EMBEDDING_SEARCH
    )
    new_corpus = visionai_v1.Corpus(
        display_name=CORPUS_DISPLAY_NAME,
        description=CORPUS_DESCRIPTION,
        type_=visionai_v1.Corpus.Type.IMAGE,
        search_capability_setting=visionai_v1.SearchCapabilitySetting(
            search_capabilities=[search_capability]
        ),
    )
    operation = warehouse_client.create_corpus(
        visionai_v1.CreateCorpusRequest(
            parent=f"projects/{PROJECT_NUMBER}/locations/{REGION}",
            corpus=new_corpus,
        )
    )
    print("Wait for corpus operation:", operation.operation)

    # Block for up to two hours until the long-running operation finishes.
    created_corpus = operation.result(timeout=7200)
    print("Created corpus ", created_corpus)
    corpus_name = created_corpus.name
    print("Corpus created:", corpus_name)

## Create DataSchema

In [None]:
# Set SCHEMA_NAME_* to empty strings to create new schemas.
SCHEMA_NAME_WIDTH = ""  # @param {type: "string"}
SCHEMA_NAME_HEIGHT = ""  # @param {type: "string"}
SCHEMA_NAME_ASPECT = ""  # @param {type: "string"}
SCHEMA_NAME_CREATOR = ""  # @param {type: "string"}


def _get_or_create_string_schema(key, existing_name):
    """Return the resource name of the data schema for ``key``.

    Args:
        key: Annotation key the schema describes (e.g. ``"width"``).
        existing_name: Resource name of an already created schema, or ``""``
            to create a new one in the current corpus.

    Returns:
        ``existing_name`` when it is non-empty; otherwise the name of a newly
        created asset-level STRING schema with exact-match search, which is
        also printed.
    """
    if existing_name != "":
        return existing_name

    details = visionai_v1.DataSchemaDetails(
        type_=visionai_v1.DataSchemaDetails.DataType.STRING,
        granularity=visionai_v1.DataSchemaDetails.Granularity.GRANULARITY_ASSET_LEVEL,
        search_strategy=visionai_v1.DataSchemaDetails.SearchStrategy(
            search_strategy_type=visionai_v1.DataSchemaDetails.SearchStrategy.SearchStrategyType.EXACT_SEARCH
        ),
    )
    schema = warehouse_client.create_data_schema(
        visionai_v1.CreateDataSchemaRequest(
            parent=corpus_name,
            data_schema=visionai_v1.DataSchema(key=key, schema_details=details),
        )
    )
    print(schema)
    return schema.name


# One schema per annotation key; all four share the same configuration.
schema_name_width = _get_or_create_string_schema("width", SCHEMA_NAME_WIDTH)
schema_name_height = _get_or_create_string_schema("height", SCHEMA_NAME_HEIGHT)
schema_name_aspect = _get_or_create_string_schema("aspect-ratio", SCHEMA_NAME_ASPECT)
schema_name_creator = _get_or_create_string_schema("creator", SCHEMA_NAME_CREATOR)

## Import Assets

In [None]:
# Upload images into a gcs bucket and prepare the input gcs file.

# Set IMPORT_ASSET to True to import assets.
IMPORT_ASSET = True  # @param {type: "boolean"}
INPUT_GCS_FILE = "gs://cloud-samples-data/vertex-ai-vision/warehouse/demo.jsonl"  # @param {type: "string"}

if IMPORT_ASSET:
    import_lro = warehouse_client.import_assets(
        visionai_v1.ImportAssetsRequest(
            parent=f"{corpus_name}",
            assets_gcs_uri=f"{INPUT_GCS_FILE}",
        )
    )
    print("Wait for import operation: ", import_lro.operation)
    # Poll every 10 seconds until the long-running operation completes.
    while not import_lro.done():
        time.sleep(10)
    # done() is also true for an operation that finished with an error;
    # result() raises in that case so a failed import is not reported as done.
    import_lro.result()
    print("Import operation done: ", import_lro.operation)

##  Analyze Corpus

In [None]:
# Set ANALYZE_CORPUS to True to analyze all assets in the corpus
ANALYZE_CORPUS = True  # @param {type: "boolean"}

if ANALYZE_CORPUS:
    analyze_lro = warehouse_client.analyze_corpus(
        visionai_v1.AnalyzeCorpusRequest(
            name=f"{corpus_name}",
        )
    )
    print("Wait for analyze operation: ", analyze_lro.operation)
    # Poll every 10 seconds until the long-running operation completes.
    while not analyze_lro.done():
        time.sleep(10)
    # done() is also true for an operation that finished with an error;
    # result() raises in that case so a failed analysis is not reported as done.
    analyze_lro.result()
    print("Analyze operation done: ", analyze_lro.operation)

## Create and deploy Index

### Create Index

In [None]:
# Leave INDEX_NAME empty to create a new index; any other value is used as-is
# as the index resource name.
INDEX_NAME = ""  # @param {type: "string"}

if INDEX_NAME != "":
    index_name = INDEX_NAME
else:
    IMAGE_INDEX_ID = "image-index-demo"
    # The index covers every asset in the corpus.
    new_index = visionai_v1.Index(
        entire_corpus=True,
        display_name="demo index",
        description="demo index",
    )
    index_lro = warehouse_client.create_index(
        visionai_v1.CreateIndexRequest(
            parent=corpus_name,
            index_id=IMAGE_INDEX_ID,
            index=new_index,
        )
    )
    print("Wait for index operation:", index_lro.operation)

    # Block for up to three hours until the long-running operation finishes.
    created_index = index_lro.result(timeout=10800)
    print("Created index ", created_index)
    index_name = created_index.name
    print("Index created:", index_name)

### Create Index Endpoint

In [None]:
# Set ENDPOINT_NAME to empty string to create a new index endpoint
ENDPOINT_NAME = ""  # @param {type: "string"}

if ENDPOINT_NAME == "":
    ENDPOINT_ID = "search-endpoint-demo"
    endpoint_lro = warehouse_client.create_index_endpoint(
        visionai_v1.CreateIndexEndpointRequest(
            parent=f"projects/{PROJECT_NUMBER}/locations/{REGION}",
            index_endpoint_id=ENDPOINT_ID,
            index_endpoint=visionai_v1.IndexEndpoint(
                display_name="demo index endpoint",
                description="demo index endpoint",
            ),
        )
    )
    print("Wait for endpoint operation:", endpoint_lro.operation)

    # Block for up to two hours until the long-running operation finishes,
    # and reuse the single result for both the log line and the name.
    created_endpoint = endpoint_lro.result(timeout=7200)
    print("Created endpoint ", created_endpoint)
    endpoint_name = created_endpoint.name
    print("Endpoint created:", endpoint_name)
else:
    # Reuse the endpoint whose resource name was supplied.
    endpoint_name = ENDPOINT_NAME

### Deploy Index

In [None]:
# When DEPLOY_INDEX is True, the index is deployed to the index endpoint.
DEPLOY_INDEX = True  # @param {type: "boolean"}

if DEPLOY_INDEX:
    deploy_request = visionai_v1.DeployIndexRequest(
        index_endpoint=endpoint_name,
        deployed_index=visionai_v1.DeployedIndex(index=index_name),
    )
    deploy_lro = warehouse_client.deploy_index(deploy_request)
    print("Wait for deploy operation:", deploy_lro.operation)

    # Block for up to two hours until the long-running operation finishes.
    deploy_result = deploy_lro.result(timeout=7200)
    print(deploy_result)
    print("Deployed Index: ", deploy_lro.operation)

## Search

### Utility for rendering images

In [None]:
def RenderImages(cols=5, image_uris=None):
    """Download the given images and display them in a grid.

    Args:
        cols: Number of grid columns; must be positive.
        image_uris: Sequence of HTTP(S) URLs to fetch. ``None`` or an empty
            sequence renders nothing.

    Raises:
        ValueError: If ``cols`` is not positive.
        requests.HTTPError: If an image URL responds with an error status.
    """
    if cols <= 0:
        raise ValueError(f"cols must be a positive integer, got {cols!r}")
    # A search can legitimately return nothing; report it instead of crashing.
    if not image_uris:
        print("No images to render.")
        return

    rows = math.ceil(len(image_uris) / cols)
    grid = GridspecLayout(rows, cols)
    for index, uri in enumerate(image_uris):
        row, col = divmod(index, cols)
        # Bound the download time and surface HTTP errors rather than handing
        # an error page to the image widget.
        response = requests.get(uri, timeout=60)
        response.raise_for_status()
        grid[row, col] = ipywidgets.Image(value=response.content, width=200)
    display(grid)

### Search by text

In [None]:
MAX_RESULTS = 10  # @param {type: "integer"} Set to 0 to allow all results.
QUERY = "multiple purple tulips"  # @param {type: "string"}
print("endpoint_name:", endpoint_name)
results = warehouse_client.search_index_endpoint(
    visionai_v1.SearchIndexEndpointRequest(
        index_endpoint=endpoint_name,
        text_query=QUERY,
    ),
)

# Collect asset names from the results, stopping once MAX_RESULTS have been
# gathered. MAX_RESULTS == 0 means "no limit", so the cap only applies to
# positive values.
asset_names = []
for r in results:
    asset_names.append(r.asset)
    if MAX_RESULTS > 0 and len(asset_names) >= MAX_RESULTS:
        break
results_cnt = len(asset_names)

# Resolve each asset name to a signed URL that can be downloaded directly.
uris = [
    warehouse_client.generate_retrieval_url(
        visionai_v1.GenerateRetrievalUrlRequest(
            name=asset_name,
        )
    ).signed_uri
    for asset_name in asset_names
]

if uris:
    RenderImages(image_uris=uris)
else:
    print("No results found.")

### Search by image

In [None]:
IMAGE_GCS_FILE = "gs://cloud-samples-data/ai-platform/flowers/roses/14312910041_b747240d56_n.jpg"  # @#param {type: "string"} example: gs://iwh_fishfood/sample-image.jpg
MAX_RESULTS = 10  # @#param {type: "integer"} Set to 0 to allow all results.
IMAGE_FILE = "/tmp/sample-image.jpg"
!gsutil cp $IMAGE_GCS_FILE $IMAGE_FILE

with open(IMAGE_FILE, "rb") as f:
    image_content = f.read()
grid = GridspecLayout(1, 1)
grid[0, 0] = ipywidgets.Image(value=image_content, width=200)

print("Query image:")
display(grid)

results = warehouse_client.search_index_endpoint(
    visionai_v1.SearchIndexEndpointRequest(
        index_endpoint=endpoint_name,
        image_query=visionai_v1.ImageQuery(
            input_image=image_content,
        ),
    ),
)

results_cnt = 0
asset_names = []
for r in results:
    asset_names.append(r.asset)
    results_cnt += 1
    if results_cnt >= MAX_RESULTS:
        break

uris = list(
    map(
        lambda asset_name: warehouse_client.generate_retrieval_url(
            visionai_v1.GenerateRetrievalUrlRequest(
                name=asset_name,
            )
        ).signed_uri,
        asset_names,
    )
)

print("Search results:")
RenderImages(image_uris=uris)

### Adding metadata filters

In [None]:
IMAGE_GCS_FILE = "gs://cloud-samples-data/ai-platform/flowers/roses/14312910041_b747240d56_n.jpg"  # @param {type: "string"} example: gs://iwh_fishfood/sample-image.jpg
MAX_RESULTS = 10  # @param {type: "integer"} Set to 0 to allow all results.
IMAGE_FILE = "/tmp/sample-image.jpg"
!gsutil cp $IMAGE_GCS_FILE $IMAGE_FILE

with open(IMAGE_FILE, "rb") as f:
    image_content = f.read()
grid = GridspecLayout(1, 1)
grid[0, 0] = ipywidgets.Image(value=image_content, width=200)

print("Query image:")
display(grid)

aspect_ratios = ["1.3", "1.4"]  # @#param {type: "list", itemType: "string"}
aspect_ratio_criteria = visionai_v1.types.StringArray(txt_values=aspect_ratios)
aspect_ratio_filter = visionai_v1.Criteria(
    field="aspect-ratio", text_array=aspect_ratio_criteria
)

# Define creator filter criteria
creator = ["Saige Fuentes"]  # @#param {type: "list", itemType: "string"}
creator_criteria = visionai_v1.types.StringArray(txt_values=creator)
creator_filter = visionai_v1.Criteria(field="creator", text_array=creator_criteria)

criteria = [aspect_ratio_filter, creator_filter]

results = warehouse_client.search_index_endpoint(
    visionai_v1.SearchIndexEndpointRequest(
        index_endpoint=endpoint_name,
        image_query=visionai_v1.ImageQuery(
            input_image=image_content,
        ),
        criteria=criteria,
    ),
)

results_cnt = 0
asset_names = []
for r in results:
    asset_names.append(r.asset)
    results_cnt += 1
    if results_cnt >= MAX_RESULTS:
        break

uris = list(
    map(
        lambda asset_name: warehouse_client.generate_retrieval_url(
            visionai_v1.GenerateRetrievalUrlRequest(
                name=asset_name,
            )
        ).signed_uri,
        asset_names,
    )
)

print("Filtered search results:")
RenderImages(image_uris=uris)

### Adding metadata filters

In [None]:
IMAGE_GCS_FILE = "gs://cloud-samples-data/ai-platform/flowers/roses/14312910041_b747240d56_n.jpg"  # @param {type: "string"} example: gs://iwh_fishfood/sample-image.jpg
MAX_RESULTS = 10  # @param {type: "integer"} Set to 0 to allow all results.
IMAGE_FILE = "/tmp/sample-image.jpg"
!gsutil cp $IMAGE_GCS_FILE $IMAGE_FILE

with open(IMAGE_FILE, "rb") as f:
    image_content = f.read()
grid = GridspecLayout(1, 1)
grid[0, 0] = ipywidgets.Image(value=image_content, width=200)

print("Query image:")
display(grid)

aspect_ratios = ["1.3", "1.4"]  # # @param {type: "list", itemType: "string"}
aspect_ratio_criteria = visionai_v1.types.StringArray(txt_values=aspect_ratios)
aspect_ratio_filter = visionai_v1.Criteria(
    field="aspect-ratio", text_array=aspect_ratio_criteria
)

# Define creator filter criteria
creator = ["Saige Fuentes"]  # # @param {type: "list", itemType: "string"}
creator_criteria = visionai_v1.types.StringArray(txt_values=creator)
creator_filter = visionai_v1.Criteria(field="creator", text_array=creator_criteria)

criteria = [aspect_ratio_filter, creator_filter]

results = warehouse_client.search_index_endpoint(
    visionai_v1.SearchIndexEndpointRequest(
        index_endpoint=endpoint_name,
        image_query=visionai_v1.ImageQuery(
            input_image=image_content,
        ),
        criteria=criteria,
    ),
)

results_cnt = 0
asset_names = []
for r in results:
    asset_names.append(r.asset)
    results_cnt += 1
    if results_cnt >= MAX_RESULTS:
        break

uris = list(
    map(
        lambda asset_name: warehouse_client.generate_retrieval_url(
            visionai_v1.GenerateRetrievalUrlRequest(
                name=asset_name,
            )
        ).signed_uri,
        asset_names,
    )
)

print("Filtered search results:")
RenderImages(image_uris=uris)

## Cleaning up

In [None]:
# Set CLEAN_UP to True to tear down the resources used by this notebook:
# the deployed index, the index, the index endpoint, all assets and the corpus.
CLEAN_UP = False  # @param {type: "boolean"}
if CLEAN_UP:
    undeploy_lro = warehouse_client.undeploy_index(
        visionai_v1.UndeployIndexRequest(
            index_endpoint=endpoint_name,
        )
    )
    print("Wait for undeploy operation:", undeploy_lro.operation)

    print(undeploy_lro.result(timeout=7200))

    delete_index_lro = warehouse_client.delete_index(
        visionai_v1.DeleteIndexRequest(
            name=index_name,
        )
    )
    print("Wait for delete operation:", delete_index_lro.operation)
    # Actually wait for the deletion (and surface any failure) before moving
    # on to resources the index belongs to.
    delete_index_lro.result(timeout=7200)

    delete_endpoint_lro = warehouse_client.delete_index_endpoint(
        visionai_v1.DeleteIndexEndpointRequest(
            name=endpoint_name,
        )
    )
    print("Wait for delete operation:", delete_endpoint_lro.operation)
    # Same as above: block until the endpoint deletion has finished.
    delete_endpoint_lro.result(timeout=7200)

    # Delete assets in batches of up to 1000, re-listing after every batch.
    # Presumably this avoids paging through a listing while it is being
    # mutated -- TODO confirm. A batch smaller than 1000 means the corpus has
    # been emptied.
    while True:
        assets = warehouse_client.list_assets(
            visionai_v1.ListAssetsRequest(
                parent=corpus_name,
                page_size=1000,
            )
        )
        deletion_cnt = 0
        for a in assets:
            deletion_cnt += 1
            print("Deleting asset:", a.name)
            warehouse_client.delete_asset(
                visionai_v1.DeleteAssetRequest(
                    name=a.name,
                )
            )
            if deletion_cnt == 1000:
                break
        if deletion_cnt < 1000:
            break

    # Finally remove the corpus itself.
    warehouse_client.delete_corpus(
        visionai_v1.DeleteCorpusRequest(
            name=corpus_name,
        )
    )