## Text-to-Image Retrieval using Online Endpoints and Indexes in Azure AI Search

This example shows how to perform text-to-image search with an Azure AI Search index and a deployed `embeddings` type model.

### Task
The text-to-image retrieval task is to select from a collection of images those that are semantically related to a text query.
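
As a rough illustration of the task, the sketch below ranks a handful of images against a text query by cosine similarity between embeddings. The three-dimensional vectors, the file names, and the helpers `cosine_similarity` and `rank_images` are made up for illustration; the real embeddings in this notebook come from the deployed model, and the ranking is done by the search index rather than by hand.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def rank_images(query_embedding, image_embeddings, k):
    """Return the k (name, similarity) pairs most similar to the query."""
    scored = [
        (name, cosine_similarity(query_embedding, embedding))
        for name, embedding in image_embeddings.items()
    ]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]


# Toy embeddings (hypothetical values, not produced by any model)
toy_image_embeddings = {
    "milk_bottle.jpg": [0.9, 0.1, 0.0],
    "can.jpg": [0.1, 0.9, 0.1],
    "carton.jpg": [0.2, 0.2, 0.9],
}
toy_query_embedding = [1.0, 0.2, 0.0]

top_matches = rank_images(toy_query_embedding, toy_image_embeddings, k=2)
print(top_matches)
```

A brute-force scan like this is fine for a few hundred images; a vector index exists to avoid comparing the query against every stored embedding.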
 
### Model
Models that can perform the `embeddings` task are tagged with `embeddings`. We will use the `OpenAI-CLIP-Image-Text-Embeddings-vit-base-patch32` model in this notebook. If you don't find a model that suits your scenario or domain, you can discover and [import models from HuggingFace hub](../../import/import_model_into_registry.ipynb) and then use them for inference. 

### Inference data
We will use the [fridgeObjects](https://automlsamplenotebookdata-adcuc7f7bqhhh8a4.b02.azurefd.net/image-classification/fridgeObjects.zip) dataset.


### Outline
1. Setup pre-requisites
2. Prepare data for inference
3. Deploy the model to an online endpoint for real time inference
4. Create a search service and index
5. Populate the index with image embeddings
6. Query the index with text embeddings and visualize results

### 1. Setup pre-requisites
* Install dependencies
* Connect to AzureML Workspace. Learn more at [set up SDK authentication](https://learn.microsoft.com/en-us/azure/machine-learning/how-to-setup-authentication?tabs=sdk). Replace `<AML_WORKSPACE_NAME>`, `<RESOURCE_GROUP>` and `<SUBSCRIPTION_ID>` below.
* Connect to `azureml` system registry

In [None]:
from azure.ai.ml import MLClient
from azure.identity import (
    DefaultAzureCredential,
    InteractiveBrowserCredential,
)
import time

try:
    credential = DefaultAzureCredential()
    credential.get_token("https://management.azure.com/.default")
except Exception as ex:
    credential = InteractiveBrowserCredential()

try:
    workspace_ml_client = MLClient.from_config(credential)
    subscription_id = workspace_ml_client.subscription_id
    resource_group = workspace_ml_client.resource_group_name
    workspace_name = workspace_ml_client.workspace_name
except Exception as ex:
    print(ex)
    # Enter details of your AML workspace
    subscription_id = "<SUBSCRIPTION_ID>"
    resource_group = "<RESOURCE_GROUP>"
    workspace_name = "<AML_WORKSPACE_NAME>"
workspace_ml_client = MLClient(
    credential, subscription_id, resource_group, workspace_name
)

# The models are available in the AzureML system registry, "azureml"
registry_ml_client = MLClient(
    credential,
    subscription_id,
    resource_group,
    registry_name="azureml",
)
# Generating a unique timestamp that can be used for names and versions that need to be unique
timestamp = str(int(time.time()))

### 2. Prepare data for inference

We will use the [fridgeObjects](https://automlsamplenotebookdata-adcuc7f7bqhhh8a4.b02.azurefd.net/image-classification/fridgeObjects.zip) dataset, which was originally built for multi-class image classification; here its images serve as the collection to search. The dataset is stored in a directory containing four folders, one per class:
- /water_bottle
- /milk_bottle
- /carton
- /can
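
Once the archive has been downloaded and extracted by the cells that follow, the folder layout above can be sanity-checked with a small helper. This is a minimal sketch: `count_images_by_class` is a hypothetical name, and the path `./data/fridgeObjects` assumes the default `dataset_parent_dir` used in this notebook.

```python
import os


def count_images_by_class(root_dir, extension=".jpg"):
    """Map each sub-folder of root_dir to the number of image files it holds."""
    counts = {}
    for class_name in sorted(os.listdir(root_dir)):
        class_dir = os.path.join(root_dir, class_name)
        if not os.path.isdir(class_dir):
            continue  # skip stray files next to the class folders
        counts[class_name] = sum(
            1 for f in os.listdir(class_dir) if f.lower().endswith(extension)
        )
    return counts


# Only runs the check if the dataset has already been extracted
if os.path.isdir("./data/fridgeObjects"):
    print(count_images_by_class("./data/fridgeObjects"))
```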


In [None]:
import os
import urllib.request
from zipfile import ZipFile

# Change to a different location if you prefer
dataset_parent_dir = "./data"

# create data folder if it doesn't exist.
os.makedirs(dataset_parent_dir, exist_ok=True)

# download data
download_url = "https://automlsamplenotebookdata-adcuc7f7bqhhh8a4.b02.azurefd.net/image-classification/fridgeObjects.zip"

# Extract current dataset name from dataset url
dataset_name = os.path.split(download_url)[-1].split(".")[0]
# Get dataset path for later use
dataset_dir = os.path.join(dataset_parent_dir, dataset_name)

In [None]:
# Get the data zip file path
data_file = os.path.join(dataset_parent_dir, f"{dataset_name}.zip")

# Download the dataset
urllib.request.urlretrieve(download_url, filename=data_file)

# extract files
with ZipFile(data_file, "r") as zip_file:
    print("extracting files...")
    zip_file.extractall(path=dataset_parent_dir)
    print("done")
# delete zip file
os.remove(data_file)

In [None]:
from IPython.display import Image

sample_image = os.path.join(dataset_dir, "milk_bottle", "99.jpg")
Image(filename=sample_image)

### 3. Deploy the model to an online endpoint for real time inference
Online endpoints give a durable REST API that can be used to integrate with applications that need to use the model.
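
To make the REST aspect concrete, here is a sketch of how an application outside this notebook might build a scoring request with only the standard library. The URL and key below are placeholders, and `build_scoring_request` is a hypothetical helper; in practice the scoring URI and key would be read from the endpoint after it is created (for example via `online_endpoints.get(...)` and `online_endpoints.get_keys(...)` on the workspace client). The sketch assumes key-based auth, where the key is sent as a bearer token, and uses the optional `azureml-model-deployment` header to target one deployment. The request is only constructed here, not sent.

```python
import json
import urllib.request


def build_scoring_request(scoring_uri, api_key, payload, deployment_name=None):
    """Build (but do not send) a POST request for an online endpoint."""
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    if deployment_name:
        # Route the call to a specific deployment instead of the traffic split
        headers["azureml-model-deployment"] = deployment_name
    return urllib.request.Request(
        scoring_uri,
        data=json.dumps(payload).encode("utf-8"),
        headers=headers,
        method="POST",
    )


# Placeholder values, to be replaced with the real scoring URI and key
example_request = build_scoring_request(
    scoring_uri="https://example-endpoint.example-region.inference.ml.azure.com/score",
    api_key="<ENDPOINT_KEY>",
    payload={
        "input_data": {
            "columns": ["image", "text"],
            "data": [["", "a photo of a milk bottle"]],
        }
    },
    deployment_name="embeddings-mlflow-deploy",
)

# To actually send it:
# with urllib.request.urlopen(example_request) as resp:
#     result = json.loads(resp.read())
```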

In [None]:
model_name = "OpenAI-CLIP-Image-Text-Embeddings-vit-base-patch32"
foundation_model = registry_ml_client.models.get(name=model_name, label="latest")
print(
    f"\n\nUsing model name: {foundation_model.name}, version: {foundation_model.version}, id: {foundation_model.id} for inferencing"
)

In [None]:
import time
from azure.ai.ml.entities import (
    ManagedOnlineEndpoint,
    ManagedOnlineDeployment,
)

# Endpoint names need to be unique in a region, hence using timestamp to create unique endpoint name
timestamp = int(time.time())
online_endpoint_name = "clip-embeddings-" + str(timestamp)
# Create an online endpoint
endpoint = ManagedOnlineEndpoint(
    name=online_endpoint_name,
    description="Online endpoint for "
    + foundation_model.name
    + ", for image-text-embeddings task",
    auth_mode="key",
)
workspace_ml_client.begin_create_or_update(endpoint).wait()

In [None]:
from azure.ai.ml.entities import OnlineRequestSettings, ProbeSettings

deployment_name = "embeddings-mlflow-deploy"

# Create a deployment
demo_deployment = ManagedOnlineDeployment(
    name=deployment_name,
    endpoint_name=online_endpoint_name,
    model=foundation_model.id,
    instance_type="Standard_NC6s_v3",  # Use a CPU instance type like Standard_DS3_v2 for lower cost but slower inference
    instance_count=1,
    request_settings=OnlineRequestSettings(
        max_concurrent_requests_per_instance=1,
        request_timeout_ms=90000,
        max_queue_wait_ms=500,
    ),
    liveness_probe=ProbeSettings(
        failure_threshold=49,
        success_threshold=1,
        timeout=299,
        period=180,
        initial_delay=180,
    ),
    readiness_probe=ProbeSettings(
        failure_threshold=10,
        success_threshold=1,
        timeout=10,
        period=10,
        initial_delay=10,
    ),
)
workspace_ml_client.online_deployments.begin_create_or_update(demo_deployment).wait()
endpoint.traffic = {deployment_name: 100}
workspace_ml_client.begin_create_or_update(endpoint).result()

### 4. Create a search service and index

Follow instructions [here](https://learn.microsoft.com/en-us/azure/search/search-create-service-portal) to create a search service using the Azure Portal. Then, run the code below to create a search index.

In [None]:
SEARCH_SERVICE_NAME = "<SEARCH SERVICE NAME>"
SERVICE_ADMIN_KEY = "<admin key from the search service in Azure Portal>"

INDEX_NAME = "fridge-objects-index"
API_VERSION = "2023-07-01-Preview"
CREATE_INDEX_REQUEST_URL = "https://{search_service_name}.search.windows.net/indexes?api-version={api_version}".format(
    search_service_name=SEARCH_SERVICE_NAME, api_version=API_VERSION
)

In [None]:
import requests

create_request = {
    "name": INDEX_NAME,
    "fields": [
        {
            "name": "id",
            "type": "Edm.String",
            "key": True,
            "searchable": True,
            "retrievable": True,
            "filterable": True,
        },
        {
            "name": "filename",
            "type": "Edm.String",
            "searchable": True,
            "filterable": True,
            "sortable": True,
            "retrievable": True,
        },
        {
            "name": "imageEmbeddings",
            "type": "Collection(Edm.Single)",
            "searchable": True,
            "retrievable": True,
            "dimensions": 512,
            "vectorSearchConfiguration": "my-vector-config",
        },
    ],
    "vectorSearch": {
        "algorithmConfigurations": [
            {
                "name": "my-vector-config",
                "kind": "hnsw",
                "hnswParameters": {
                    "m": 4,
                    "efConstruction": 400,
                    "efSearch": 500,
                    "metric": "cosine",
                },
            }
        ]
    },
}
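response = requests.post(
    CREATE_INDEX_REQUEST_URL,
    json=create_request,
    headers={"api-key": SERVICE_ADMIN_KEY},
)
# Surface the outcome: 201 means the index was created; otherwise the body describes the error
print(response.status_code, response.text)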
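
The `dimensions` value in the index definition has to match the length of the vectors uploaded later (512 for the model used here). The sketch below shows one way to check a vector against an index definition before sending it. The helper names `vector_dimensions` and `check_embedding` are hypothetical, and `toy_index` is a trimmed stand-in for `create_request`.

```python
def vector_dimensions(index_definition, field_name):
    """Look up the declared vector length of a field in an index definition."""
    for field in index_definition["fields"]:
        if field["name"] == field_name:
            return field["dimensions"]
    raise KeyError(f"field {field_name!r} not found in index definition")


def check_embedding(index_definition, field_name, embedding):
    """Raise ValueError if the embedding length differs from the index field."""
    expected = vector_dimensions(index_definition, field_name)
    if len(embedding) != expected:
        raise ValueError(
            f"{field_name} expects {expected} values, got {len(embedding)}"
        )
    return True


# Trimmed stand-in for the index definition used in this notebook
toy_index = {
    "fields": [
        {"name": "id", "type": "Edm.String"},
        {
            "name": "imageEmbeddings",
            "type": "Collection(Edm.Single)",
            "dimensions": 512,
        },
    ]
}

print(check_embedding(toy_index, "imageEmbeddings", [0.0] * 512))
```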

### 5. Populate the index with image embeddings

Submit requests with image data to the online endpoint to get image embeddings. Add the image embeddings to the search index.

In [None]:
import json
import base64

_REQUEST_FILE_NAME = "request.json"


def read_image(image_path):
    with open(image_path, "rb") as f:
        return f.read()


def make_request_images(image_path):
    request_json = {
        "input_data": {
            "columns": ["image", "text"],
            "data": [[base64.encodebytes(read_image(image_path)).decode("utf-8"), ""]],
        }
    }

    with open(_REQUEST_FILE_NAME, "wt") as f:
        json.dump(request_json, f)
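
For reference, the payload can also be built in memory. `base64.encodebytes` inserts a newline after every 76 output characters, while `base64.b64encode` produces a single line; both decode back to the same bytes. The sketch below uses `b64encode` and random-looking stand-in bytes instead of a real image. `build_image_payload` is a hypothetical helper, and whether the endpoint treats the two encodings identically is an assumption to verify.

```python
import base64
import json


def build_image_payload(image_bytes):
    """Return the request body for one image, with the text column left empty."""
    encoded = base64.b64encode(image_bytes).decode("utf-8")
    return {
        "input_data": {
            "columns": ["image", "text"],
            "data": [[encoded, ""]],
        }
    }


# Stand-in bytes; a real call would pass the contents of a .jpg file
fake_image_bytes = bytes(range(256)) * 2

payload = build_image_payload(fake_image_bytes)
serialized = json.dumps(payload)

# Round trip: the receiver can recover the original bytes from the JSON body
decoded = base64.b64decode(json.loads(serialized)["input_data"]["data"][0][0])
print(decoded == fake_image_bytes)
```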

In [None]:
ADD_DATA_REQUEST_URL = "https://{search_service_name}.search.windows.net/indexes/{index_name}/docs/index?api-version={api_version}".format(
    search_service_name=SEARCH_SERVICE_NAME,
    index_name=INDEX_NAME,
    api_version=API_VERSION,
)

In [None]:
from tqdm.auto import tqdm

image_paths = [
    os.path.join(dp, f)
    for dp, dn, filenames in os.walk(dataset_dir)
    for f in filenames
    if os.path.splitext(f)[1] == ".jpg"
]

for idx, image_path in enumerate(tqdm(image_paths)):
    ID = idx
    FILENAME = image_path
    MAX_RETRIES = 3

    # write the request payload for this image to _REQUEST_FILE_NAME
    make_request_images(image_path)

    response = None
    request_failed = False
    IMAGE_EMBEDDING = None
    for r in range(MAX_RETRIES):
        try:
            response = workspace_ml_client.online_endpoints.invoke(
                endpoint_name=online_endpoint_name,
                deployment_name=deployment_name,
                request_file=_REQUEST_FILE_NAME,
            )
            response = json.loads(response)
            IMAGE_EMBEDDING = response[0]["image_features"]
            break
        except Exception as e:
            print(f"Unable to get embeddings for image {FILENAME}: {e}")
            print(response)
            if r == MAX_RETRIES - 1:
                print(f"attempt {r + 1} failed, reached retry limit")
                request_failed = True
            else:
                print(f"attempt {r + 1} failed, retrying")

    # add embedding to index
    if IMAGE_EMBEDDING:
        add_data_request = {
            "value": [
                {
                    "id": str(ID),
                    "filename": FILENAME,
                    "imageEmbeddings": IMAGE_EMBEDDING,
                    "@search.action": "upload",
                }
            ]
        }
        response = requests.post(
            ADD_DATA_REQUEST_URL,
            json=add_data_request,
            headers={"api-key": SERVICE_ADMIN_KEY},
        )
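
The loop above retries a failed request immediately. A common alternative is to wait a little longer before each retry (exponential backoff), which gives a busy endpoint time to recover. The sketch below is one possible shape for that. `call_with_backoff` and the `flaky` demo function are hypothetical, and the `sleep` parameter exists only so the demo can record the delays instead of actually waiting.

```python
import time


def call_with_backoff(func, max_retries=3, base_delay_s=1.0, sleep=time.sleep):
    """Call func(), retrying on any exception with exponentially growing delays."""
    if max_retries < 1:
        raise ValueError("max_retries must be at least 1")
    last_error = None
    for attempt in range(max_retries):
        try:
            return func()
        except Exception as error:
            last_error = error
            if attempt < max_retries - 1:
                # Delays of base, 2*base, 4*base, ... between attempts
                sleep(base_delay_s * 2**attempt)
    raise last_error


# Demo: a function that fails twice, then succeeds
calls = {"count": 0}


def flaky():
    calls["count"] += 1
    if calls["count"] < 3:
        raise RuntimeError("temporary failure")
    return "ok"


recorded_delays = []
result = call_with_backoff(flaky, sleep=recorded_delays.append)
print(result, recorded_delays)
```

In the indexing loop, `func` would wrap the `online_endpoints.invoke(...)` call for one image.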

### 6. Query the index with text embeddings and visualize results

In [None]:
TEXT_QUERY = "a photo of a milk bottle"
K = 5  # number of results to retrieve

#### 6.1 Get the text embeddings for the query using the online endpoint

In [None]:
def make_request_text(text_sample):
    request_json = {
        "input_data": {
            "columns": ["image", "text"],
            "data": [["", text_sample]],
        }
    }

    with open(_REQUEST_FILE_NAME, "wt") as f:
        json.dump(request_json, f)


make_request_text(TEXT_QUERY)
response = workspace_ml_client.online_endpoints.invoke(
    endpoint_name=online_endpoint_name,
    deployment_name=deployment_name,
    request_file=_REQUEST_FILE_NAME,
)
response = json.loads(response)
QUERY_TEXT_EMBEDDING = response[0]["text_features"]

#### 6.2 Send the text embeddings as a query to the search index

In [None]:
QUERY_REQUEST_URL = "https://{search_service_name}.search.windows.net/indexes/{index_name}/docs/search?api-version={api_version}".format(
    search_service_name=SEARCH_SERVICE_NAME,
    index_name=INDEX_NAME,
    api_version=API_VERSION,
)


search_request = {
    "vectors": [{"value": QUERY_TEXT_EMBEDDING, "fields": "imageEmbeddings", "k": K}],
    "select": "filename",
}


response = requests.post(
    QUERY_REQUEST_URL, json=search_request, headers={"api-key": SERVICE_ADMIN_KEY}
)
neighbors = json.loads(response.text)["value"]
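
Each entry in `neighbors` carries an `@search.score`. For the `cosine` metric, the Azure AI Search ranking documentation describes this score as `1 / (1 + cosine distance)`, so it can be mapped back to a cosine similarity. The sketch below assumes that relationship holds for the API version used here. `score_to_cosine_similarity` is a hypothetical helper and the sample results are invented.

```python
def score_to_cosine_similarity(score):
    """Invert score = 1 / (1 + (1 - cosine_similarity)) for the cosine metric."""
    return 2.0 - 1.0 / score


# Invented results in the same shape as the search response's "value" list
sample_neighbors = [
    {"filename": "./data/fridgeObjects/milk_bottle/1.jpg", "@search.score": 0.62},
    {"filename": "./data/fridgeObjects/carton/7.jpg", "@search.score": 0.58},
]

for neighbor in sample_neighbors:
    similarity = score_to_cosine_similarity(neighbor["@search.score"])
    print(f'{neighbor["filename"]}: cosine similarity {similarity:.4f}')
```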

#### 6.3 Visualize Results

In [None]:
import matplotlib.pyplot as plt
import numpy as np

from PIL import Image

K1, K2 = 3, 4


def make_pil_image(image_path):
    pil_image = Image.open(image_path)
    return pil_image


_, axes = plt.subplots(nrows=K1 + 1, ncols=K2, figsize=(64, 64))
for i in range(K1 + 1):
    for j in range(K2):
        axes[i, j].axis("off")

i, j = 0, 0

for neighbor in neighbors:
    pil_image = make_pil_image(neighbor["filename"])
    axes[i, j].imshow(np.asarray(pil_image), aspect="auto")
    axes[i, j].text(1, 1, "{:.4f}".format(neighbor["@search.score"]), fontsize=32)

    j += 1
    if j == K2:
        i += 1
        j = 0
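
The manual `i`, `j` bookkeeping above fills the grid row by row. The same mapping from a result's position to a grid cell can be written with `divmod`, as in this small sketch (`grid_position` is a hypothetical helper; the column count of 4 mirrors `K2`).

```python
def grid_position(index, n_cols):
    """Return the (row, column) cell for the index-th item in a row-major grid."""
    return divmod(index, n_cols)


# Cells used by five results in a grid with four columns
positions = [grid_position(i, 4) for i in range(5)]
print(positions)
```

With `K1 + 1` rows and `K2` columns the grid holds 16 images, so `K` can be raised up to 16 without changing the layout.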