sdk/python/foundation-models/system/inference/image-text-embeddings/text-to-image-retrieval.ipynb

{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Text-to-Image Retrieval using Online Endpoints and Indexes in Azure AI Search\n", "\n", "This example shows how to perform text-to-image search with an Azure AI Search index and a deployed `embeddings` type model.\n", "\n", "### Task\n", "The text-to-image retrieval task is to select from a collection of images those that are semantically related to a text query.\n", " \n", "### Model\n", "Models that can perform the `embeddings` task are tagged with `embeddings`. We will use the `OpenAI-CLIP-Image-Text-Embeddings-vit-base-patch32` model in this notebook. If you don't find a model that suits your scenario or domain, you can discover and [import models from HuggingFace hub](../../import/import_model_into_registry.ipynb) and then use them for inference. \n", "\n", "### Inference data\n", "We will use the [fridgeObjects](https://automlsamplenotebookdata-adcuc7f7bqhhh8a4.b02.azurefd.net/image-classification/fridgeObjects.zip) dataset.\n", "\n", "\n", "### Outline\n", "1. Set up prerequisites\n", "2. Prepare data for inference\n", "3. Deploy the model to an online endpoint for real-time inference\n", "4. Create a search service and index\n", "5. Populate the index with image embeddings\n", "6. Query the index with text embeddings and visualize results" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1. Set up prerequisites\n", "* Install dependencies\n", "* Connect to AzureML Workspace. Learn more at [set up SDK authentication](https://learn.microsoft.com/en-us/azure/machine-learning/how-to-setup-authentication?tabs=sdk). 
Replace `<AML_WORKSPACE_NAME>`, `<RESOURCE_GROUP>` and `<SUBSCRIPTION_ID>` below.\n", "* Connect to the `azureml` system registry" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from azure.ai.ml import MLClient\n", "from azure.identity import (\n", " DefaultAzureCredential,\n", " InteractiveBrowserCredential,\n", ")\n", "import time\n", "\n", "try:\n", " credential = DefaultAzureCredential()\n", " credential.get_token(\"https://management.azure.com/.default\")\n", "except Exception as ex:\n", " credential = InteractiveBrowserCredential()\n", "\n", "try:\n", " workspace_ml_client = MLClient.from_config(credential)\n", " subscription_id = workspace_ml_client.subscription_id\n", " resource_group = workspace_ml_client.resource_group_name\n", " workspace_name = workspace_ml_client.workspace_name\n", "except Exception as ex:\n", " print(ex)\n", " # Enter details of your AML workspace\n", " subscription_id = \"<SUBSCRIPTION_ID>\"\n", " resource_group = \"<RESOURCE_GROUP>\"\n", " workspace_name = \"<AML_WORKSPACE_NAME>\"\n", "workspace_ml_client = MLClient(\n", " credential, subscription_id, resource_group, workspace_name\n", ")\n", "\n", "# The models are available in the AzureML system registry, \"azureml\"\n", "registry_ml_client = MLClient(\n", " credential,\n", " subscription_id,\n", " resource_group,\n", " registry_name=\"azureml\",\n", ")\n", "# Generating a unique timestamp that can be used for names and versions that need to be unique\n", "timestamp = str(int(time.time()))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2. Prepare data for inference\n", "\n", "We will use the [fridgeObjects](https://automlsamplenotebookdata-adcuc7f7bqhhh8a4.b02.azurefd.net/image-classification/fridgeObjects.zip) dataset, which was originally prepared for multi-class image classification. The fridge object dataset is stored in a directory. 
There are four different folders inside:\n", "- /water_bottle\n", "- /milk_bottle\n", "- /carton\n", "- /can\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "import urllib.request\n", "from zipfile import ZipFile\n", "\n", "# Change to a different location if you prefer\n", "dataset_parent_dir = \"./data\"\n", "\n", "# create data folder if it doesn't exist.\n", "os.makedirs(dataset_parent_dir, exist_ok=True)\n", "\n", "# download data\n", "download_url = \"https://automlsamplenotebookdata-adcuc7f7bqhhh8a4.b02.azurefd.net/image-classification/fridgeObjects.zip\"\n", "\n", "# Extract current dataset name from dataset url\n", "dataset_name = os.path.split(download_url)[-1].split(\".\")[0]\n", "# Get dataset path for later use\n", "dataset_dir = os.path.join(dataset_parent_dir, dataset_name)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Get the data zip file path\n", "data_file = os.path.join(dataset_parent_dir, f\"{dataset_name}.zip\")\n", "\n", "# Download the dataset\n", "urllib.request.urlretrieve(download_url, filename=data_file)\n", "\n", "# extract files\n", "with ZipFile(data_file, \"r\") as zip_file:\n", " print(\"extracting files...\")\n", " zip_file.extractall(path=dataset_parent_dir)\n", " print(\"done\")\n", "# delete zip file\n", "os.remove(data_file)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from IPython.display import Image\n", "\n", "sample_image = os.path.join(dataset_dir, \"milk_bottle\", \"99.jpg\")\n", "Image(filename=sample_image)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3. Deploy the model to an online endpoint for real-time inference\n", "Online endpoints provide a durable REST API that can be used to integrate with applications that need to use the model." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model_name = \"OpenAI-CLIP-Image-Text-Embeddings-vit-base-patch32\"\n", "foundation_model = registry_ml_client.models.get(name=model_name, label=\"latest\")\n", "print(\n", " f\"\\n\\nUsing model name: {foundation_model.name}, version: {foundation_model.version}, id: {foundation_model.id} for inferencing\"\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import time\n", "from azure.ai.ml.entities import (\n", " ManagedOnlineEndpoint,\n", " ManagedOnlineDeployment,\n", ")\n", "\n", "# Endpoint names need to be unique in a region, hence using timestamp to create unique endpoint name\n", "timestamp = int(time.time())\n", "online_endpoint_name = \"clip-embeddings-\" + str(timestamp)\n", "# Create an online endpoint\n", "endpoint = ManagedOnlineEndpoint(\n", " name=online_endpoint_name,\n", " description=\"Online endpoint for \"\n", " + foundation_model.name\n", " + \", for image-text-embeddings task\",\n", " auth_mode=\"key\",\n", ")\n", "workspace_ml_client.begin_create_or_update(endpoint).wait()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from azure.ai.ml.entities import OnlineRequestSettings, ProbeSettings\n", "\n", "deployment_name = \"embeddings-mlflow-deploy\"\n", "\n", "# Create a deployment\n", "demo_deployment = ManagedOnlineDeployment(\n", " name=deployment_name,\n", " endpoint_name=online_endpoint_name,\n", " model=foundation_model.id,\n", " instance_type=\"Standard_NC6s_v3\", # GPU instance type; a CPU instance type such as Standard_DS3_v2 costs less but gives slower inference\n", " instance_count=1,\n", " request_settings=OnlineRequestSettings(\n", " max_concurrent_requests_per_instance=1,\n", " request_timeout_ms=90000,\n", " max_queue_wait_ms=500,\n", " ),\n", " liveness_probe=ProbeSettings(\n", " failure_threshold=49,\n", " success_threshold=1,\n", " timeout=299,\n", " 
period=180,\n", " initial_delay=180,\n", " ),\n", " readiness_probe=ProbeSettings(\n", " failure_threshold=10,\n", " success_threshold=1,\n", " timeout=10,\n", " period=10,\n", " initial_delay=10,\n", " ),\n", ")\n", "workspace_ml_client.online_deployments.begin_create_or_update(demo_deployment).wait()\n", "endpoint.traffic = {deployment_name: 100}\n", "workspace_ml_client.begin_create_or_update(endpoint).result()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 4. Create a search service and index" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Follow instructions [here](https://learn.microsoft.com/en-us/azure/search/search-create-service-portal) to create a search service using the Azure Portal. Then, run the code below to create a search index." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "SEARCH_SERVICE_NAME = \"<SEARCH SERVICE NAME>\"\n", "SERVICE_ADMIN_KEY = \"<admin key from the search service in Azure Portal>\"\n", "\n", "INDEX_NAME = \"fridge-objects-index\"\n", "API_VERSION = \"2023-07-01-Preview\"\n", "CREATE_INDEX_REQUEST_URL = \"https://{search_service_name}.search.windows.net/indexes?api-version={api_version}\".format(\n", " search_service_name=SEARCH_SERVICE_NAME, api_version=API_VERSION\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import requests\n", "\n", "create_request = {\n", " \"name\": INDEX_NAME,\n", " \"fields\": [\n", " {\n", " \"name\": \"id\",\n", " \"type\": \"Edm.String\",\n", " \"key\": True,\n", " \"searchable\": True,\n", " \"retrievable\": True,\n", " \"filterable\": True,\n", " },\n", " {\n", " \"name\": \"filename\",\n", " \"type\": \"Edm.String\",\n", " \"searchable\": True,\n", " \"filterable\": True,\n", " \"sortable\": True,\n", " \"retrievable\": True,\n", " },\n", " {\n", " \"name\": \"imageEmbeddings\",\n", " \"type\": \"Collection(Edm.Single)\",\n", " \"searchable\": True,\n", " 
\"retrievable\": True,\n", " \"dimensions\": 512,\n", " \"vectorSearchConfiguration\": \"my-vector-config\",\n", " },\n", " ],\n", " \"vectorSearch\": {\n", " \"algorithmConfigurations\": [\n", " {\n", " \"name\": \"my-vector-config\",\n", " \"kind\": \"hnsw\",\n", " \"hnswParameters\": {\n", " \"m\": 4,\n", " \"efConstruction\": 400,\n", " \"efSearch\": 500,\n", " \"metric\": \"cosine\",\n", " },\n", " }\n", " ]\n", " },\n", "}\n", "response = requests.post(\n", " CREATE_INDEX_REQUEST_URL,\n", " json=create_request,\n", " headers={\"api-key\": SERVICE_ADMIN_KEY},\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 5. Populate the index with image embeddings\n", "\n", "Submit requests with image data to the online endpoint to get image embeddings. Add the image embeddings to the search index." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import json\n", "import base64\n", "\n", "_REQUEST_FILE_NAME = \"request.json\"\n", "\n", "\n", "def read_image(image_path):\n", " with open(image_path, \"rb\") as f:\n", " return f.read()\n", "\n", "\n", "def make_request_images(image_path):\n", " request_json = {\n", " \"input_data\": {\n", " \"columns\": [\"image\", \"text\"],\n", " \"data\": [[base64.encodebytes(read_image(image_path)).decode(\"utf-8\"), \"\"]],\n", " }\n", " }\n", "\n", " with open(_REQUEST_FILE_NAME, \"wt\") as f:\n", " json.dump(request_json, f)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ADD_DATA_REQUEST_URL = \"https://{search_service_name}.search.windows.net/indexes/{index_name}/docs/index?api-version={api_version}\".format(\n", " search_service_name=SEARCH_SERVICE_NAME,\n", " index_name=INDEX_NAME,\n", " api_version=API_VERSION,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from tqdm.auto import tqdm\n", "\n", "image_paths = [\n", " os.path.join(dp, f)\n", " for dp, dn, 
filenames in os.walk(dataset_dir)\n", " for f in filenames\n", " if os.path.splitext(f)[1] == \".jpg\"\n", "]\n", "\n", "for idx, image_path in enumerate(tqdm(image_paths)):\n", " ID = idx\n", " FILENAME = image_path\n", " MAX_RETRIES = 3\n", "\n", " # write the request payload to _REQUEST_FILE_NAME, then get the embedding from the endpoint\n", " make_request_images(image_path)\n", "\n", " response = None\n", " request_failed = False\n", " IMAGE_EMBEDDING = None\n", " for r in range(MAX_RETRIES):\n", " try:\n", " response = workspace_ml_client.online_endpoints.invoke(\n", " endpoint_name=online_endpoint_name,\n", " deployment_name=deployment_name,\n", " request_file=_REQUEST_FILE_NAME,\n", " )\n", " response = json.loads(response)\n", " IMAGE_EMBEDDING = response[0][\"image_features\"]\n", " break\n", " except Exception as e:\n", " print(f\"Unable to get embeddings for image {FILENAME}: {e}\")\n", " print(response)\n", " if r == MAX_RETRIES - 1:\n", " print(f\"attempt {r + 1} failed, reached retry limit\")\n", " request_failed = True\n", " else:\n", " print(f\"attempt {r + 1} failed, retrying\")\n", "\n", " # add embedding to index\n", " if IMAGE_EMBEDDING:\n", " add_data_request = {\n", " \"value\": [\n", " {\n", " \"id\": str(ID),\n", " \"filename\": FILENAME,\n", " \"imageEmbeddings\": IMAGE_EMBEDDING,\n", " \"@search.action\": \"upload\",\n", " }\n", " ]\n", " }\n", " response = requests.post(\n", " ADD_DATA_REQUEST_URL,\n", " json=add_data_request,\n", " headers={\"api-key\": SERVICE_ADMIN_KEY},\n", " )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 6. 
Query the index with text embeddings and visualize results" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "TEXT_QUERY = \"a photo of a milk bottle\"\n", "K = 5 # number of results to retrieve" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 6.1 Get the text embeddings for the query using the online endpoint" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def make_request_text(text_sample):\n", " request_json = {\n", " \"input_data\": {\n", " \"columns\": [\"image\", \"text\"],\n", " \"data\": [[\"\", text_sample]],\n", " }\n", " }\n", "\n", " with open(_REQUEST_FILE_NAME, \"wt\") as f:\n", " json.dump(request_json, f)\n", "\n", "\n", "make_request_text(TEXT_QUERY)\n", "response = workspace_ml_client.online_endpoints.invoke(\n", " endpoint_name=online_endpoint_name,\n", " deployment_name=deployment_name,\n", " request_file=_REQUEST_FILE_NAME,\n", ")\n", "response = json.loads(response)\n", "QUERY_TEXT_EMBEDDING = response[0][\"text_features\"]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 6.2 Send the text embeddings as a query to the search index" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "QUERY_REQUEST_URL = \"https://{search_service_name}.search.windows.net/indexes/{index_name}/docs/search?api-version={api_version}\".format(\n", " search_service_name=SEARCH_SERVICE_NAME,\n", " index_name=INDEX_NAME,\n", " api_version=API_VERSION,\n", ")\n", "\n", "\n", "search_request = {\n", " \"vectors\": [{\"value\": QUERY_TEXT_EMBEDDING, \"fields\": \"imageEmbeddings\", \"k\": K}],\n", " \"select\": \"filename\",\n", "}\n", "\n", "\n", "response = requests.post(\n", " QUERY_REQUEST_URL, json=search_request, headers={\"api-key\": SERVICE_ADMIN_KEY}\n", ")\n", "neighbors = json.loads(response.text)[\"value\"]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 6.3 Visualize 
Results" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", "from PIL import Image\n", "\n", "K1, K2 = 3, 4\n", "\n", "\n", "def make_pil_image(image_path):\n", " pil_image = Image.open(image_path)\n", " return pil_image\n", "\n", "\n", "_, axes = plt.subplots(nrows=K1 + 1, ncols=K2, figsize=(64, 64))\n", "for i in range(K1 + 1):\n", " for j in range(K2):\n", " axes[i, j].axis(\"off\")\n", "\n", "i, j = 0, 0\n", "\n", "for neighbor in neighbors:\n", " pil_image = make_pil_image(neighbor[\"filename\"])\n", " axes[i, j].imshow(np.asarray(pil_image), aspect=\"auto\")\n", " axes[i, j].text(1, 1, \"{:.4f}\".format(neighbor[\"@search.score\"]), fontsize=32)\n", "\n", " j += 1\n", " if j == K2:\n", " i += 1\n", " j = 0" ] } ], "metadata": { "kernelspec": { "display_name": "rc_133", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.10" } }, "nbformat": 4, "nbformat_minor": 2 }