{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Text-to-Image Retrieval using Online Endpoints and Indexes in Azure AI Search\n",
"\n",
"This example shows how to perform text-to-image search with an Azure AI Search index and a deployed `embeddings` type model.\n",
"\n",
"### Task\n",
"The text-to-image retrieval task is to select, from a collection of images, those that are semantically related to a text query.\n",
" \n",
"### Model\n",
"Models that can perform the `embeddings` task are tagged with `embeddings`. We will use the `OpenAI-CLIP-Image-Text-Embeddings-vit-base-patch32` model in this notebook. If you don't find a model that suits your scenario or domain, you can discover and [import models from the HuggingFace hub](../../import/import_model_into_registry.ipynb) and then use them for inference. \n",
"\n",
"### Inference data\n",
"We will use the [fridgeObjects](https://automlsamplenotebookdata-adcuc7f7bqhhh8a4.b02.azurefd.net/image-classification/fridgeObjects.zip) dataset.\n",
"\n",
"\n",
"### Outline\n",
"1. Set up prerequisites\n",
"2. Prepare data for inference\n",
"3. Deploy the model to an online endpoint for real-time inference\n",
"4. Create a search service and index\n",
"5. Populate the index with image embeddings\n",
"6. Query the index with text embeddings and visualize results"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1. Set up prerequisites\n",
"* Install dependencies\n",
"* Connect to an AzureML workspace. Learn more at [set up SDK authentication](https://learn.microsoft.com/en-us/azure/machine-learning/how-to-setup-authentication?tabs=sdk). Replace `<AML_WORKSPACE_NAME>`, `<RESOURCE_GROUP>` and `<SUBSCRIPTION_ID>` below.\n",
"* Connect to `azureml` system registry"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from azure.ai.ml import MLClient\n",
"from azure.identity import (\n",
"    DefaultAzureCredential,\n",
"    InteractiveBrowserCredential,\n",
")\n",
"import time\n",
"\n",
"try:\n",
"    credential = DefaultAzureCredential()\n",
"    credential.get_token(\"https://management.azure.com/.default\")\n",
"except Exception as ex:\n",
"    credential = InteractiveBrowserCredential()\n",
"\n",
"try:\n",
"    workspace_ml_client = MLClient.from_config(credential)\n",
"    subscription_id = workspace_ml_client.subscription_id\n",
"    resource_group = workspace_ml_client.resource_group_name\n",
"    workspace_name = workspace_ml_client.workspace_name\n",
"except Exception as ex:\n",
"    print(ex)\n",
"    # Enter details of your AML workspace\n",
"    subscription_id = \"<SUBSCRIPTION_ID>\"\n",
"    resource_group = \"<RESOURCE_GROUP>\"\n",
"    workspace_name = \"<AML_WORKSPACE_NAME>\"\n",
"workspace_ml_client = MLClient(\n",
"    credential, subscription_id, resource_group, workspace_name\n",
")\n",
"\n",
"# The models are available in the AzureML system registry, \"azureml\"\n",
"registry_ml_client = MLClient(\n",
"    credential,\n",
"    subscription_id,\n",
"    resource_group,\n",
"    registry_name=\"azureml\",\n",
")\n",
"# Generating a unique timestamp that can be used for names and versions that need to be unique\n",
"timestamp = str(int(time.time()))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2. Prepare data for inference\n",
"\n",
"We will use the [fridgeObjects](https://automlsamplenotebookdata-adcuc7f7bqhhh8a4.b02.azurefd.net/image-classification/fridgeObjects.zip) dataset, which was prepared for multi-class image classification. Once extracted, the dataset directory contains four folders, one per class:\n",
"- /water_bottle\n",
"- /milk_bottle\n",
"- /carton\n",
"- /can\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import urllib.request\n",
"from zipfile import ZipFile\n",
"\n",
"# Change to a different location if you prefer\n",
"dataset_parent_dir = \"./data\"\n",
"\n",
"# Create the data folder if it doesn't exist\n",
"os.makedirs(dataset_parent_dir, exist_ok=True)\n",
"\n",
"# download data\n",
"download_url = \"https://automlsamplenotebookdata-adcuc7f7bqhhh8a4.b02.azurefd.net/image-classification/fridgeObjects.zip\"\n",
"\n",
"# Extract current dataset name from dataset url\n",
"dataset_name = os.path.split(download_url)[-1].split(\".\")[0]\n",
"# Get dataset path for later use\n",
"dataset_dir = os.path.join(dataset_parent_dir, dataset_name)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Get the data zip file path\n",
"data_file = os.path.join(dataset_parent_dir, f\"{dataset_name}.zip\")\n",
"\n",
"# Download the dataset\n",
"urllib.request.urlretrieve(download_url, filename=data_file)\n",
"\n",
"# Extract files (the handle is named zip_file to avoid shadowing the built-in zip)\n",
"with ZipFile(data_file, \"r\") as zip_file:\n",
"    print(\"extracting files...\")\n",
"    zip_file.extractall(path=dataset_parent_dir)\n",
"    print(\"done\")\n",
"# Delete the zip file\n",
"os.remove(data_file)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from IPython.display import Image\n",
"\n",
"sample_image = os.path.join(dataset_dir, \"milk_bottle\", \"99.jpg\")\n",
"Image(filename=sample_image)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3. Deploy the model to an online endpoint for real-time inference\n",
"Online endpoints give a durable REST API that can be used to integrate with applications that need to use the model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model_name = \"OpenAI-CLIP-Image-Text-Embeddings-vit-base-patch32\"\n",
"foundation_model = registry_ml_client.models.get(name=model_name, label=\"latest\")\n",
"print(\n",
" f\"\\n\\nUsing model name: {foundation_model.name}, version: {foundation_model.version}, id: {foundation_model.id} for inferencing\"\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"from azure.ai.ml.entities import (\n",
" ManagedOnlineEndpoint,\n",
" ManagedOnlineDeployment,\n",
")\n",
"\n",
"# Endpoint names need to be unique in a region, hence using timestamp to create unique endpoint name\n",
"timestamp = int(time.time())\n",
"online_endpoint_name = \"clip-embeddings-\" + str(timestamp)\n",
"# Create an online endpoint\n",
"endpoint = ManagedOnlineEndpoint(\n",
" name=online_endpoint_name,\n",
" description=\"Online endpoint for \"\n",
" + foundation_model.name\n",
" + \", for image-text-embeddings task\",\n",
" auth_mode=\"key\",\n",
")\n",
"workspace_ml_client.begin_create_or_update(endpoint).wait()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from azure.ai.ml.entities import OnlineRequestSettings, ProbeSettings\n",
"\n",
"deployment_name = \"embeddings-mlflow-deploy\"\n",
"\n",
"# Create a deployment\n",
"demo_deployment = ManagedOnlineDeployment(\n",
" name=deployment_name,\n",
" endpoint_name=online_endpoint_name,\n",
" model=foundation_model.id,\n",
" instance_type=\"Standard_NC6s_v3\", # GPU instance; use a CPU instance type like Standard_DS3_v2 for lower cost but slower inference\n",
" instance_count=1,\n",
" request_settings=OnlineRequestSettings(\n",
" max_concurrent_requests_per_instance=1,\n",
" request_timeout_ms=90000,\n",
" max_queue_wait_ms=500,\n",
" ),\n",
" liveness_probe=ProbeSettings(\n",
" failure_threshold=49,\n",
" success_threshold=1,\n",
" timeout=299,\n",
" period=180,\n",
" initial_delay=180,\n",
" ),\n",
" readiness_probe=ProbeSettings(\n",
" failure_threshold=10,\n",
" success_threshold=1,\n",
" timeout=10,\n",
" period=10,\n",
" initial_delay=10,\n",
" ),\n",
")\n",
"workspace_ml_client.online_deployments.begin_create_or_update(demo_deployment).wait()\n",
"endpoint.traffic = {deployment_name: 100}\n",
"workspace_ml_client.begin_create_or_update(endpoint).result()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 4. Create a search service and index"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Follow instructions [here](https://learn.microsoft.com/en-us/azure/search/search-create-service-portal) to create a search service using the Azure Portal. Then, run the code below to create a search index."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"SEARCH_SERVICE_NAME = \"<SEARCH SERVICE NAME>\"\n",
"SERVICE_ADMIN_KEY = \"<admin key from the search service in Azure Portal>\"\n",
"\n",
"INDEX_NAME = \"fridge-objects-index\"\n",
"API_VERSION = \"2023-07-01-Preview\"\n",
"CREATE_INDEX_REQUEST_URL = \"https://{search_service_name}.search.windows.net/indexes?api-version={api_version}\".format(\n",
" search_service_name=SEARCH_SERVICE_NAME, api_version=API_VERSION\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import requests\n",
"\n",
"create_request = {\n",
" \"name\": INDEX_NAME,\n",
" \"fields\": [\n",
" {\n",
" \"name\": \"id\",\n",
" \"type\": \"Edm.String\",\n",
" \"key\": True,\n",
" \"searchable\": True,\n",
" \"retrievable\": True,\n",
" \"filterable\": True,\n",
" },\n",
" {\n",
" \"name\": \"filename\",\n",
" \"type\": \"Edm.String\",\n",
" \"searchable\": True,\n",
" \"filterable\": True,\n",
" \"sortable\": True,\n",
" \"retrievable\": True,\n",
" },\n",
" {\n",
" \"name\": \"imageEmbeddings\",\n",
" \"type\": \"Collection(Edm.Single)\",\n",
" \"searchable\": True,\n",
" \"retrievable\": True,\n",
" \"dimensions\": 512,\n",
" \"vectorSearchConfiguration\": \"my-vector-config\",\n",
" },\n",
" ],\n",
" \"vectorSearch\": {\n",
" \"algorithmConfigurations\": [\n",
" {\n",
" \"name\": \"my-vector-config\",\n",
" \"kind\": \"hnsw\",\n",
" \"hnswParameters\": {\n",
" \"m\": 4,\n",
" \"efConstruction\": 400,\n",
" \"efSearch\": 500,\n",
" \"metric\": \"cosine\",\n",
" },\n",
" }\n",
" ]\n",
" },\n",
"}\n",
"response = requests.post(\n",
" CREATE_INDEX_REQUEST_URL,\n",
" json=create_request,\n",
" headers={\"api-key\": SERVICE_ADMIN_KEY},\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 5. Populate the index with image embeddings\n",
"\n",
"Submit requests with image data to the online endpoint to get image embeddings. Add the image embeddings to the search index."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import base64\n",
"\n",
"_REQUEST_FILE_NAME = \"request.json\"\n",
"\n",
"\n",
"def read_image(image_path):\n",
"    with open(image_path, \"rb\") as f:\n",
"        return f.read()\n",
"\n",
"\n",
"def make_request_images(image_path):\n",
"    request_json = {\n",
"        \"input_data\": {\n",
"            \"columns\": [\"image\", \"text\"],\n",
"            \"data\": [[base64.encodebytes(read_image(image_path)).decode(\"utf-8\"), \"\"]],\n",
"        }\n",
"    }\n",
"\n",
"    with open(_REQUEST_FILE_NAME, \"wt\") as f:\n",
"        json.dump(request_json, f)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ADD_DATA_REQUEST_URL = \"https://{search_service_name}.search.windows.net/indexes/{index_name}/docs/index?api-version={api_version}\".format(\n",
" search_service_name=SEARCH_SERVICE_NAME,\n",
" index_name=INDEX_NAME,\n",
" api_version=API_VERSION,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from tqdm.auto import tqdm\n",
"\n",
"image_paths = [\n",
"    os.path.join(dp, f)\n",
"    for dp, dn, filenames in os.walk(dataset_dir)\n",
"    for f in filenames\n",
"    if os.path.splitext(f)[1] == \".jpg\"\n",
"]\n",
"\n",
"for idx, image_path in enumerate(tqdm(image_paths)):\n",
"    ID = idx\n",
"    FILENAME = image_path\n",
"    MAX_RETRIES = 3\n",
"\n",
"    # write the request file for this image (make_request_images returns nothing)\n",
"    make_request_images(image_path)\n",
"\n",
"    # get embedding from endpoint, retrying up to MAX_RETRIES times\n",
"    response = None\n",
"    request_failed = False\n",
"    IMAGE_EMBEDDING = None\n",
"    for r in range(MAX_RETRIES):\n",
"        try:\n",
"            response = workspace_ml_client.online_endpoints.invoke(\n",
"                endpoint_name=online_endpoint_name,\n",
"                deployment_name=deployment_name,\n",
"                request_file=_REQUEST_FILE_NAME,\n",
"            )\n",
"            response = json.loads(response)\n",
"            IMAGE_EMBEDDING = response[0][\"image_features\"]\n",
"            break\n",
"        except Exception as e:\n",
"            print(f\"Unable to get embeddings for image {FILENAME}: {e}\")\n",
"            print(response)\n",
"            if r == MAX_RETRIES - 1:\n",
"                print(f\"attempt {r + 1} failed, reached retry limit\")\n",
"                request_failed = True\n",
"            else:\n",
"                print(f\"attempt {r + 1} failed, retrying\")\n",
"\n",
"    # add embedding to index\n",
"    if IMAGE_EMBEDDING:\n",
"        add_data_request = {\n",
"            \"value\": [\n",
"                {\n",
"                    \"id\": str(ID),\n",
"                    \"filename\": FILENAME,\n",
"                    \"imageEmbeddings\": IMAGE_EMBEDDING,\n",
"                    \"@search.action\": \"upload\",\n",
"                }\n",
"            ]\n",
"        }\n",
"        response = requests.post(\n",
"            ADD_DATA_REQUEST_URL,\n",
"            json=add_data_request,\n",
"            headers={\"api-key\": SERVICE_ADMIN_KEY},\n",
"        )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 6. Query the index with text embeddings and visualize results"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"TEXT_QUERY = \"a photo of a milk bottle\"\n",
"K = 5 # number of results to retrieve"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 6.1 Get the text embeddings for the query using the online endpoint"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def make_request_text(text_sample):\n",
"    request_json = {\n",
"        \"input_data\": {\n",
"            \"columns\": [\"image\", \"text\"],\n",
"            \"data\": [[\"\", text_sample]],\n",
"        }\n",
"    }\n",
"\n",
"    with open(_REQUEST_FILE_NAME, \"wt\") as f:\n",
"        json.dump(request_json, f)\n",
"\n",
"\n",
"make_request_text(TEXT_QUERY)\n",
"response = workspace_ml_client.online_endpoints.invoke(\n",
"    endpoint_name=online_endpoint_name,\n",
"    deployment_name=deployment_name,\n",
"    request_file=_REQUEST_FILE_NAME,\n",
")\n",
"response = json.loads(response)\n",
"QUERY_TEXT_EMBEDDING = response[0][\"text_features\"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 6.2 Send the text embeddings as a query to the search index"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"QUERY_REQUEST_URL = \"https://{search_service_name}.search.windows.net/indexes/{index_name}/docs/search?api-version={api_version}\".format(\n",
" search_service_name=SEARCH_SERVICE_NAME,\n",
" index_name=INDEX_NAME,\n",
" api_version=API_VERSION,\n",
")\n",
"\n",
"\n",
"search_request = {\n",
" \"vectors\": [{\"value\": QUERY_TEXT_EMBEDDING, \"fields\": \"imageEmbeddings\", \"k\": K}],\n",
" \"select\": \"filename\",\n",
"}\n",
"\n",
"\n",
"response = requests.post(\n",
" QUERY_REQUEST_URL, json=search_request, headers={\"api-key\": SERVICE_ADMIN_KEY}\n",
")\n",
"neighbors = json.loads(response.text)[\"value\"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 6.3 Visualize Results"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"from PIL import Image\n",
"\n",
"K1, K2 = 3, 4\n",
"\n",
"\n",
"def make_pil_image(image_path):\n",
"    pil_image = Image.open(image_path)\n",
"    return pil_image\n",
"\n",
"\n",
"_, axes = plt.subplots(nrows=K1 + 1, ncols=K2, figsize=(64, 64))\n",
"for i in range(K1 + 1):\n",
"    for j in range(K2):\n",
"        axes[i, j].axis(\"off\")\n",
"\n",
"i, j = 0, 0\n",
"\n",
"for neighbor in neighbors:\n",
"    pil_image = make_pil_image(neighbor[\"filename\"])\n",
"    axes[i, j].imshow(np.asarray(pil_image), aspect=\"auto\")\n",
"    axes[i, j].text(1, 1, \"{:.4f}\".format(neighbor[\"@search.score\"]), fontsize=32)\n",
"\n",
"    j += 1\n",
"    if j == K2:\n",
"        i += 1\n",
"        j = 0"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "rc_133",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}