supporting-blog-content/colpali/01_colpali.ipynb

{ "cells": [ { "cell_type": "markdown", "id": "3fef5a94-b06f-48ae-90e5-2f919d3352bd", "metadata": {}, "source": [ "This notebook shows how to ingest and search images using ColPali with Elasticsearch. Read our accompanying blog post on [ColPali in Elasticsearch](elastiacsearch-colpali-visual-document-search) for more context on this notebook. \n", "\n", "We will be using images from the [ViDoRe benchmark](https://huggingface.co/collections/vidore/vidore-benchmark-667173f98e70a1c0fa4db00d) as example data. \n", "\n", "The URL and API key for your Elasticsearch cluster are expected in a file `elastic.env` in this format: \n", "```\n", "ELASTIC_HOST=<cluster-url>\n", "ELASTIC_API_KEY=<api-key>\n", "```" ] }, { "cell_type": "code", "execution_count": 1, "id": "a1610e61-fbfe-4d7f-9109-601a0ccd0129", "metadata": {}, "outputs": [], "source": [ "!pip install -r requirements.txt\n", "from IPython.display import clear_output\n", "\n", "clear_output() # for less space usage." ] }, { "cell_type": "markdown", "id": "aec6865f-dc2d-4242-a568-2fbf94cf2201", "metadata": {}, "source": [ "First we load the sample data from huggingface and save it to disk." ] }, { "cell_type": "code", "execution_count": 2, "id": "baf63024-c058-4e1c-a170-6730bf2f2704", "metadata": { "ExecuteTime": { "end_time": "2025-03-02T09:16:41.680203Z", "start_time": "2025-03-02T09:14:00.648234Z" } }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "6c0aa31eaa8546478c3e48fcc206dbd3", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Saving images to disk: 0%| | 0/500 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from datasets import load_dataset\n", "from tqdm.notebook import tqdm\n", "import os\n", "\n", "DATASET_NAME = \"vidore/infovqa_test_subsampled\"\n", "DOCUMENT_DIR = \"searchlabs-colpali\"\n", "\n", "os.makedirs(DOCUMENT_DIR, exist_ok=True)\n", "dataset = load_dataset(DATASET_NAME, split=\"test\")\n", "\n", "for i, row in enumerate(tqdm(dataset, desc=\"Saving images to disk\")):\n", " image = row.get(\"image\")\n", " image_name = f\"image_{i}.jpg\"\n", " image_path = os.path.join(DOCUMENT_DIR, image_name)\n", " image.save(image_path)" ] }, { "cell_type": "markdown", "id": "da958778-42b3-438a-992c-172097d8d464", "metadata": {}, "source": [ "Here we load the ColPali model and define functions to generate vectors from images and text. 
" ] }, { "cell_type": "code", "execution_count": 3, "id": "32aad27e-afe7-4bbf-ad13-1be82c917a70", "metadata": { "ExecuteTime": { "end_time": "2025-03-02T09:16:41.681836Z", "start_time": "2025-03-02T09:14:10.977348Z" } }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "0e608222306e4cd1a1b933ed248e74f8", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import torch\n", "from PIL import Image\n", "from colpali_engine.models import ColPali, ColPaliProcessor\n", "\n", "model_name = \"vidore/colpali-v1.3\"\n", "model = ColPali.from_pretrained(\n", " \"vidore/colpali-v1.3\",\n", " torch_dtype=torch.float32,\n", " device_map=\"mps\", # \"mps\" for Apple Silicon, \"cuda\" if available, \"cpu\" otherwise\n", ").eval()\n", "\n", "col_pali_processor = ColPaliProcessor.from_pretrained(model_name)\n", "\n", "\n", "def create_col_pali_image_vectors(image_path: str) -> list:\n", " batch_images = col_pali_processor.process_images([Image.open(image_path)]).to(\n", " model.device\n", " )\n", "\n", " with torch.no_grad():\n", " return model(**batch_images).tolist()[0]\n", "\n", "\n", "def create_col_pali_query_vectors(query: str) -> list:\n", " queries = col_pali_processor.process_queries([query]).to(model.device)\n", " with torch.no_grad():\n", " return model(**queries).tolist()[0]" ] }, { "cell_type": "markdown", "id": "d12ea156-4e2b-4b84-983e-d0e63c9a6178", "metadata": {}, "source": [ "This is where we are going over all our images and creating our multi-vectors with the ColPali model. " ] }, { "cell_type": "code", "execution_count": 4, "id": "fcf55e15-6c4a-4003-b929-aab2931c2389", "metadata": { "ExecuteTime": { "end_time": "2025-03-02T09:16:41.682259Z", "start_time": "2025-03-02T09:14:22.244797Z" } }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "76b9d003d76d49d1b82b25d124deddeb", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Create ColPali Vectors: 0%| | 0/500 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Saved 500 vector entries to disk\n" ] } ], "source": [ "import os\n", "import time\n", "import pickle\n", "\n", "images = [os.path.join(DOCUMENT_DIR, f) for f in os.listdir(DOCUMENT_DIR)]\n", "file_to_multi_vectors = {}\n", "\n", "for image_path in tqdm(images, desc=\"Create ColPali Vectors\"):\n", " file_name = os.path.basename(image_path)\n", " vectors_f32 = create_col_pali_image_vectors(image_path)\n", " file_to_multi_vectors[file_name] = vectors_f32\n", "\n", "with open(\"col_pali_vectors.pkl\", \"wb\") as f:\n", " pickle.dump(file_to_multi_vectors, f)\n", "\n", "print(f\"Saved {len(file_to_multi_vectors)} vector entries to disk\")" ] }, { "cell_type": "markdown", "id": "39512c53-2679-4ae0-a92a-31cb6374b60b", "metadata": {}, "source": [ "This is the new `rank_vectors` field type, where we will be saving our ColPali vectors. 
" ] }, { "cell_type": "code", "execution_count": 5, "id": "2de5872d-b372-40fe-85c5-111b9f9fa6c8", "metadata": { "ExecuteTime": { "end_time": "2025-03-02T09:16:41.682532Z", "start_time": "2025-03-02T09:14:22.828689Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[INFO] Index 'searchlabs-colpali' already exists.\n" ] } ], "source": [ "from dotenv import load_dotenv\n", "from elasticsearch import Elasticsearch\n", "\n", "load_dotenv(\"elastic.env\")\n", "\n", "ELASTIC_API_KEY = os.getenv(\"ELASTIC_API_KEY\")\n", "ELASTIC_HOST = os.getenv(\"ELASTIC_HOST\")\n", "INDEX_NAME = \"searchlabs-colpali\"\n", "\n", "es = Elasticsearch(ELASTIC_HOST, api_key=ELASTIC_API_KEY)\n", "\n", "mappings = {\"mappings\": {\"properties\": {\"col_pali_vectors\": {\"type\": \"rank_vectors\"}}}}\n", "\n", "if not es.indices.exists(index=INDEX_NAME):\n", " print(f\"[INFO] Creating index: {INDEX_NAME}\")\n", " es.indices.create(index=INDEX_NAME, body=mappings)\n", "else:\n", " print(f\"[INFO] Index '{INDEX_NAME}' already exists.\")\n", "\n", "\n", "def index_document(es_client, index, doc_id, document, retries=10, initial_backoff=1):\n", " for attempt in range(1, retries + 1):\n", " try:\n", " return es_client.index(index=index, id=doc_id, document=document)\n", " except Exception as e:\n", " if attempt < retries:\n", " wait_time = initial_backoff * (2 ** (attempt - 1))\n", " print(f\"[WARN] Failed to index {doc_id} (attempt {attempt}): {e}\")\n", " time.sleep(wait_time)\n", " else:\n", " print(f\"Failed to index {doc_id} after {retries} attempts: {e}\")\n", " raise" ] }, { "cell_type": "markdown", "id": "697f7b77-eb8f-430b-af3d-4a618c5cf086", "metadata": {}, "source": [ "Load all images back from disk, create the vectors for them and index them into Elasticsearch. " ] }, { "cell_type": "code", "execution_count": 6, "id": "adb4b59c-b36f-44d9-bee9-d20457630330", "metadata": { "ExecuteTime": { "end_time": "2025-03-02T09:16:41.682771Z", "start_time": "2025-03-02T09:14:24.511339Z" } }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2ee853046b1b4929b0784c41221ba03f", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Index documents: 0%| | 0/500 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "with open(\"col_pali_vectors.pkl\", \"rb\") as f:\n", " file_to_multi_vectors = pickle.load(f)\n", "\n", "for file_name, vectors in tqdm(file_to_multi_vectors.items(), desc=\"Index documents\"):\n", " if es.exists(index=INDEX_NAME, id=file_name):\n", " continue\n", "\n", " index_document(\n", " es_client=es,\n", " index=INDEX_NAME,\n", " doc_id=file_name,\n", " document={\"col_pali_vectors\": vectors},\n", " )" ] }, { "cell_type": "markdown", "id": "3b556999-06b7-4856-a6bb-d72ad4929f62", "metadata": {}, "source": [ "Use the new `maxSimDotProduct` function to calculate the similarity between our query and the image vectors in Elasticsearch. 
" ] }, { "cell_type": "code", "execution_count": 7, "id": "8e322b23-b4bc-409d-9e00-2dab93f6a295", "metadata": { "ExecuteTime": { "end_time": "2025-03-02T09:16:41.683169Z", "start_time": "2025-03-02T09:14:24.521130Z" } }, "outputs": [ { "data": { "text/html": [ "<div style='display: flex; flex-wrap: wrap; align-items: flex-start;'><img src=\"searchlabs-colpali/image_104.jpg\" alt=\"image_104.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"><img src=\"searchlabs-colpali/image_3.jpg\" alt=\"image_3.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"><img src=\"searchlabs-colpali/image_2.jpg\" alt=\"image_2.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"><img src=\"searchlabs-colpali/image_12.jpg\" alt=\"image_12.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"><img src=\"searchlabs-colpali/image_92.jpg\" alt=\"image_92.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"></div>" ], "text/plain": [ "<IPython.core.display.HTML object>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from IPython.display import display, HTML\n", "import os\n", "\n", "query = \"What do companies use for recruiting?\"\n", "es_query = {\n", " \"_source\": False,\n", " \"query\": {\n", " \"script_score\": {\n", " \"query\": {\"match_all\": {}},\n", " \"script\": {\n", " \"source\": \"maxSimDotProduct(params.query_vector, 'col_pali_vectors')\",\n", " \"params\": {\"query_vector\": create_col_pali_query_vectors(query)},\n", " },\n", " }\n", " },\n", " \"size\": 5,\n", "}\n", "\n", "results = es.search(index=INDEX_NAME, body=es_query)\n", "image_ids = [hit[\"_id\"] for hit in results[\"hits\"][\"hits\"]]\n", "\n", "html = \"<div style='display: flex; flex-wrap: wrap; align-items: flex-start;'>\"\n", "for image_id in image_ids:\n", " image_path = os.path.join(DOCUMENT_DIR, image_id)\n", " html += f'<img src=\"{image_path}\" alt=\"{image_id}\" style=\"max-width:300px; height:auto; margin:10px;\">'\n", "html += \"</div>\"\n", "\n", "display(HTML(html))" ] }, { "cell_type": "code", "execution_count": null, "id": "16997bc1-ea8d-413b-a312-00f08fca1d0a", "metadata": {}, "outputs": [], "source": [ "# We kill the kernel forcefully to free up the memory from the ColPali model.\n", "print(\"Shutting down the kernel to free memory...\")\n", "import os\n", "\n", "os._exit(0)" ] } ], "metadata": { "kernelspec": { "display_name": "dependecy-test-colpali-blog", "language": "python", "name": "dependecy-test-colpali-blog" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.6" } }, "nbformat": 4, "nbformat_minor": 5 }