supporting-blog-content/colpali/01_colpali.ipynb
{
"cells": [
{
"cell_type": "markdown",
"id": "3fef5a94-b06f-48ae-90e5-2f919d3352bd",
"metadata": {},
"source": [
"This notebook shows how to ingest and search images using ColPali with Elasticsearch. Read our accompanying blog post on [ColPali in Elasticsearch](elastiacsearch-colpali-visual-document-search) for more context on this notebook. \n",
"\n",
"We will be using images from the [ViDoRe benchmark](https://huggingface.co/collections/vidore/vidore-benchmark-667173f98e70a1c0fa4db00d) as example data. \n",
"\n",
"The URL and API key for your Elasticsearch cluster are expected in a file `elastic.env` in this format: \n",
"```\n",
"ELASTIC_HOST=<cluster-url>\n",
"ELASTIC_API_KEY=<api-key>\n",
"```"
]
},
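{
"cell_type": "markdown",
"id": "7a3c1f20-9b4e-4c5d-8e2f-1a2b3c4d5e6f",
"metadata": {},
"source": [
"Optional sanity check (a small sketch, not required by the rest of the notebook): confirm that `elastic.env` exists and defines both variables before running the heavier steps below. The credentials themselves are loaded later with `python-dotenv`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8b4d2e31-0c5f-4d6e-9f30-2b3c4d5e6f70",
"metadata": {},
"outputs": [],
"source": [
"# Optional sketch: fail early if elastic.env is missing or incomplete.\n",
"from pathlib import Path\n",
"\n",
"env_file = Path(\"elastic.env\")\n",
"assert env_file.exists(), \"Create elastic.env next to this notebook first.\"\n",
"env_text = env_file.read_text()\n",
"for key in (\"ELASTIC_HOST\", \"ELASTIC_API_KEY\"):\n",
"    assert key in env_text, f\"{key} is missing from elastic.env\"\n",
"print(\"elastic.env looks good.\")"
]
},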
{
"cell_type": "code",
"execution_count": 1,
"id": "a1610e61-fbfe-4d7f-9109-601a0ccd0129",
"metadata": {},
"outputs": [],
"source": [
"!pip install -r requirements.txt\n",
"from IPython.display import clear_output\n",
"\n",
"clear_output() # for less space usage."
]
},
{
"cell_type": "markdown",
"id": "aec6865f-dc2d-4242-a568-2fbf94cf2201",
"metadata": {},
"source": [
"First we load the sample data from huggingface and save it to disk."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "baf63024-c058-4e1c-a170-6730bf2f2704",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-02T09:16:41.680203Z",
"start_time": "2025-03-02T09:14:00.648234Z"
}
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6c0aa31eaa8546478c3e48fcc206dbd3",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Saving images to disk: 0%| | 0/500 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from datasets import load_dataset\n",
"from tqdm.notebook import tqdm\n",
"import os\n",
"\n",
"DATASET_NAME = \"vidore/infovqa_test_subsampled\"\n",
"DOCUMENT_DIR = \"searchlabs-colpali\"\n",
"\n",
"os.makedirs(DOCUMENT_DIR, exist_ok=True)\n",
"dataset = load_dataset(DATASET_NAME, split=\"test\")\n",
"\n",
"for i, row in enumerate(tqdm(dataset, desc=\"Saving images to disk\")):\n",
" image = row.get(\"image\")\n",
" image_name = f\"image_{i}.jpg\"\n",
" image_path = os.path.join(DOCUMENT_DIR, image_name)\n",
" image.save(image_path)"
]
},
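{
"cell_type": "markdown",
"id": "9c5e3f42-1d60-4e7f-a041-3c4d5e6f7081",
"metadata": {},
"source": [
"As a quick optional check (a sketch; it assumes the previous cell completed, so `image_0.jpg` exists), we can preview one of the saved images."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ad6f4053-2e71-4f80-b152-4d5e6f708192",
"metadata": {},
"outputs": [],
"source": [
"# Optional sketch: preview one of the images we just saved to disk.\n",
"from PIL import Image\n",
"from IPython.display import display\n",
"\n",
"display(Image.open(os.path.join(DOCUMENT_DIR, \"image_0.jpg\")))"
]
},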
{
"cell_type": "markdown",
"id": "da958778-42b3-438a-992c-172097d8d464",
"metadata": {},
"source": [
"Here we load the ColPali model and define functions to generate vectors from images and text. "
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "32aad27e-afe7-4bbf-ad13-1be82c917a70",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-02T09:16:41.681836Z",
"start_time": "2025-03-02T09:14:10.977348Z"
}
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0e608222306e4cd1a1b933ed248e74f8",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import torch\n",
"from PIL import Image\n",
"from colpali_engine.models import ColPali, ColPaliProcessor\n",
"\n",
"model_name = \"vidore/colpali-v1.3\"\n",
"model = ColPali.from_pretrained(\n",
" \"vidore/colpali-v1.3\",\n",
" torch_dtype=torch.float32,\n",
" device_map=\"mps\", # \"mps\" for Apple Silicon, \"cuda\" if available, \"cpu\" otherwise\n",
").eval()\n",
"\n",
"col_pali_processor = ColPaliProcessor.from_pretrained(model_name)\n",
"\n",
"\n",
"def create_col_pali_image_vectors(image_path: str) -> list:\n",
" batch_images = col_pali_processor.process_images([Image.open(image_path)]).to(\n",
" model.device\n",
" )\n",
"\n",
" with torch.no_grad():\n",
" return model(**batch_images).tolist()[0]\n",
"\n",
"\n",
"def create_col_pali_query_vectors(query: str) -> list:\n",
" queries = col_pali_processor.process_queries([query]).to(model.device)\n",
" with torch.no_grad():\n",
" return model(**queries).tolist()[0]"
]
},
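{
"cell_type": "markdown",
"id": "be705164-3f82-4091-8263-5e6f708192a3",
"metadata": {},
"source": [
"Optional sanity check (a sketch): ColPali returns a multi-vector per input, i.e. one embedding per query token or image patch, rather than a single dense vector. For `colpali-v1.3` each embedding should have 128 dimensions."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cf816275-4093-41a2-9374-6f708192a3b4",
"metadata": {},
"outputs": [],
"source": [
"# Optional sketch: confirm we get a list of embeddings (one per token),\n",
"# each a fixed-size float vector.\n",
"sample_query_vectors = create_col_pali_query_vectors(\"test query\")\n",
"print(f\"{len(sample_query_vectors)} vectors of {len(sample_query_vectors[0])} dimensions\")"
]
},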
{
"cell_type": "markdown",
"id": "d12ea156-4e2b-4b84-983e-d0e63c9a6178",
"metadata": {},
"source": [
"This is where we are going over all our images and creating our multi-vectors with the ColPali model. "
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "fcf55e15-6c4a-4003-b929-aab2931c2389",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-02T09:16:41.682259Z",
"start_time": "2025-03-02T09:14:22.244797Z"
}
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "76b9d003d76d49d1b82b25d124deddeb",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Create ColPali Vectors: 0%| | 0/500 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Saved 500 vector entries to disk\n"
]
}
],
"source": [
"import os\n",
"import time\n",
"import pickle\n",
"\n",
"images = [os.path.join(DOCUMENT_DIR, f) for f in os.listdir(DOCUMENT_DIR)]\n",
"file_to_multi_vectors = {}\n",
"\n",
"for image_path in tqdm(images, desc=\"Create ColPali Vectors\"):\n",
" file_name = os.path.basename(image_path)\n",
" vectors_f32 = create_col_pali_image_vectors(image_path)\n",
" file_to_multi_vectors[file_name] = vectors_f32\n",
"\n",
"with open(\"col_pali_vectors.pkl\", \"wb\") as f:\n",
" pickle.dump(file_to_multi_vectors, f)\n",
"\n",
"print(f\"Saved {len(file_to_multi_vectors)} vector entries to disk\")"
]
},
{
"cell_type": "markdown",
"id": "39512c53-2679-4ae0-a92a-31cb6374b60b",
"metadata": {},
"source": [
"This is the new `rank_vectors` field type, where we will be saving our ColPali vectors. "
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "2de5872d-b372-40fe-85c5-111b9f9fa6c8",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-02T09:16:41.682532Z",
"start_time": "2025-03-02T09:14:22.828689Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[INFO] Index 'searchlabs-colpali' already exists.\n"
]
}
],
"source": [
"from dotenv import load_dotenv\n",
"from elasticsearch import Elasticsearch\n",
"\n",
"load_dotenv(\"elastic.env\")\n",
"\n",
"ELASTIC_API_KEY = os.getenv(\"ELASTIC_API_KEY\")\n",
"ELASTIC_HOST = os.getenv(\"ELASTIC_HOST\")\n",
"INDEX_NAME = \"searchlabs-colpali\"\n",
"\n",
"es = Elasticsearch(ELASTIC_HOST, api_key=ELASTIC_API_KEY)\n",
"\n",
"mappings = {\"mappings\": {\"properties\": {\"col_pali_vectors\": {\"type\": \"rank_vectors\"}}}}\n",
"\n",
"if not es.indices.exists(index=INDEX_NAME):\n",
" print(f\"[INFO] Creating index: {INDEX_NAME}\")\n",
" es.indices.create(index=INDEX_NAME, body=mappings)\n",
"else:\n",
" print(f\"[INFO] Index '{INDEX_NAME}' already exists.\")\n",
"\n",
"\n",
"def index_document(es_client, index, doc_id, document, retries=10, initial_backoff=1):\n",
" for attempt in range(1, retries + 1):\n",
" try:\n",
" return es_client.index(index=index, id=doc_id, document=document)\n",
" except Exception as e:\n",
" if attempt < retries:\n",
" wait_time = initial_backoff * (2 ** (attempt - 1))\n",
" print(f\"[WARN] Failed to index {doc_id} (attempt {attempt}): {e}\")\n",
" time.sleep(wait_time)\n",
" else:\n",
" print(f\"Failed to index {doc_id} after {retries} attempts: {e}\")\n",
" raise"
]
},
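{
"cell_type": "markdown",
"id": "d0927386-51a4-42b3-a485-708192a3b4c5",
"metadata": {},
"source": [
"Optionally (a sketch using the standard `get_mapping` API of the Python client), we can verify that `col_pali_vectors` was indeed mapped as `rank_vectors`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e1a38497-62b5-43c4-b596-8192a3b4c5d6",
"metadata": {},
"outputs": [],
"source": [
"# Optional sketch: confirm the field is mapped as rank_vectors.\n",
"mapping = es.indices.get_mapping(index=INDEX_NAME)\n",
"print(mapping[INDEX_NAME][\"mappings\"][\"properties\"][\"col_pali_vectors\"])"
]
},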
{
"cell_type": "markdown",
"id": "697f7b77-eb8f-430b-af3d-4a618c5cf086",
"metadata": {},
"source": [
"Load all images back from disk, create the vectors for them and index them into Elasticsearch. "
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "adb4b59c-b36f-44d9-bee9-d20457630330",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-02T09:16:41.682771Z",
"start_time": "2025-03-02T09:14:24.511339Z"
}
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2ee853046b1b4929b0784c41221ba03f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Index documents: 0%| | 0/500 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"with open(\"col_pali_vectors.pkl\", \"rb\") as f:\n",
" file_to_multi_vectors = pickle.load(f)\n",
"\n",
"for file_name, vectors in tqdm(file_to_multi_vectors.items(), desc=\"Index documents\"):\n",
" if es.exists(index=INDEX_NAME, id=file_name):\n",
" continue\n",
"\n",
" index_document(\n",
" es_client=es,\n",
" index=INDEX_NAME,\n",
" doc_id=file_name,\n",
" document={\"col_pali_vectors\": vectors},\n",
" )"
]
},
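{
"cell_type": "markdown",
"id": "f2b495a8-73c6-44d5-86a7-92a3b4c5d6e7",
"metadata": {},
"source": [
"Optional check (a sketch): refresh the index so newly indexed documents become searchable, then confirm the document count matches what we indexed."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "03c5a6b9-84d7-45e6-97b8-a3b4c5d6e7f8",
"metadata": {},
"outputs": [],
"source": [
"# Optional sketch: refresh and count the indexed documents.\n",
"es.indices.refresh(index=INDEX_NAME)\n",
"print(es.count(index=INDEX_NAME)[\"count\"], \"documents indexed\")"
]
},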
{
"cell_type": "markdown",
"id": "3b556999-06b7-4856-a6bb-d72ad4929f62",
"metadata": {},
"source": [
"Use the new `maxSimDotProduct` function to calculate the similarity between our query and the image vectors in Elasticsearch. "
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "8e322b23-b4bc-409d-9e00-2dab93f6a295",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-02T09:16:41.683169Z",
"start_time": "2025-03-02T09:14:24.521130Z"
}
},
"outputs": [
{
"data": {
"text/html": [
"<div style='display: flex; flex-wrap: wrap; align-items: flex-start;'><img src=\"searchlabs-colpali/image_104.jpg\" alt=\"image_104.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"><img src=\"searchlabs-colpali/image_3.jpg\" alt=\"image_3.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"><img src=\"searchlabs-colpali/image_2.jpg\" alt=\"image_2.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"><img src=\"searchlabs-colpali/image_12.jpg\" alt=\"image_12.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"><img src=\"searchlabs-colpali/image_92.jpg\" alt=\"image_92.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"></div>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from IPython.display import display, HTML\n",
"import os\n",
"\n",
"query = \"What do companies use for recruiting?\"\n",
"es_query = {\n",
" \"_source\": False,\n",
" \"query\": {\n",
" \"script_score\": {\n",
" \"query\": {\"match_all\": {}},\n",
" \"script\": {\n",
" \"source\": \"maxSimDotProduct(params.query_vector, 'col_pali_vectors')\",\n",
" \"params\": {\"query_vector\": create_col_pali_query_vectors(query)},\n",
" },\n",
" }\n",
" },\n",
" \"size\": 5,\n",
"}\n",
"\n",
"results = es.search(index=INDEX_NAME, body=es_query)\n",
"image_ids = [hit[\"_id\"] for hit in results[\"hits\"][\"hits\"]]\n",
"\n",
"html = \"<div style='display: flex; flex-wrap: wrap; align-items: flex-start;'>\"\n",
"for image_id in image_ids:\n",
" image_path = os.path.join(DOCUMENT_DIR, image_id)\n",
" html += f'<img src=\"{image_path}\" alt=\"{image_id}\" style=\"max-width:300px; height:auto; margin:10px;\">'\n",
"html += \"</div>\"\n",
"\n",
"display(HTML(html))"
]
},
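{
"cell_type": "markdown",
"id": "14d6b7ca-95e8-46f7-a8c9-b4c5d6e7f809",
"metadata": {},
"source": [
"The cell above only renders the matching images. As an optional complement (a sketch reusing the `results` variable from the previous cell), we can also print the MaxSim score for each hit."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "25e7c8db-a6f9-4708-b9da-c5d6e7f8091a",
"metadata": {},
"outputs": [],
"source": [
"# Optional sketch: print the MaxSim score for each hit.\n",
"for hit in results[\"hits\"][\"hits\"]:\n",
"    print(f\"{hit['_id']}: {hit['_score']:.2f}\")"
]
},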
{
"cell_type": "code",
"execution_count": null,
"id": "16997bc1-ea8d-413b-a312-00f08fca1d0a",
"metadata": {},
"outputs": [],
"source": [
"# We kill the kernel forcefully to free up the memory from the ColPali model.\n",
"print(\"Shutting down the kernel to free memory...\")\n",
"import os\n",
"\n",
"os._exit(0)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "dependecy-test-colpali-blog",
"language": "python",
"name": "dependecy-test-colpali-blog"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}