# Scalable late interaction vectors in Elasticsearch: Token Pooling #

In this notebook, we will look at how to scale search with late interaction models. We focus on token pooling, a technique that reduces the number of vectors in each late interaction multi-vector by clustering similar token vectors together. This technique can, of course, be combined with the other techniques discussed in the previous notebooks. 

This notebook builds on part 1, where we downloaded the images, created ColPali vectors and saved them to disk. Please run part 1 before trying the techniques in this notebook.  

Also check out our accompanying blog post on [Scaling Late Interaction Models](TODO) for more context on this notebook. 

In [1]:
import numpy as np


def to_bit_vectors(embeddings: list) -> list:
    """Binarize each embedding and encode it as a hex string.

    Every value greater than zero becomes a 1 bit and every other value a
    0 bit. The bits are packed eight per byte by ``np.packbits`` (zero-padded
    at the end), and the packed bytes are returned as a lowercase hex string.
    The output holds one string per input embedding, in input order.
    """
    encoded = []
    for embedding in embeddings:
        positive_mask = np.asarray(embedding) > 0
        packed_bytes = np.packbits(positive_mask)
        encoded.append(packed_bytes.tobytes().hex())
    return encoded

We will be using the `HierarchicalTokenPooler` from [colpali-engine](https://github.com/illuin-tech/colpali?tab=readme-ov-file#token-pooling) to reduce the number of vectors per document.  
The authors recommend `pool_factor=3` for most cases, but you should always test how it impacts relevancy on your own dataset. 

In [2]:
import torch
from colpali_engine.compression.token_pooling import HierarchicalTokenPooler

# pool_factor=3 is the value the colpali-engine authors suggest for most
# cases; measure its effect on relevancy for your own data before adopting it.
POOL_FACTOR = 3
pooler = HierarchicalTokenPooler(pool_factor=POOL_FACTOR)


def pool_vectors(embedding: list) -> list:
    """Token-pool a single document's multi-vector embedding.

    The vectors are wrapped in a batch of size one for
    ``pooler.pool_embeddings``, and that leading batch axis is squeezed away
    again afterwards. The pooled vectors come back as a nested Python list.
    """
    return (
        pooler.pool_embeddings(torch.tensor(embedding).unsqueeze(0))
        .squeeze(0)
        .tolist()
    )

In [3]:
import os
from dotenv import load_dotenv
from elasticsearch import Elasticsearch

# Connection settings are read from elastic.env (or the process environment).
load_dotenv("elastic.env")

ELASTIC_API_KEY = os.getenv("ELASTIC_API_KEY")
ELASTIC_HOST = os.getenv("ELASTIC_HOST")
INDEX_NAME = "searchlabs-colpali-token-pooling"

es = Elasticsearch(ELASTIC_HOST, api_key=ELASTIC_API_KEY)

# The pooled ColPali multi-vectors are stored as a rank_vectors field with
# "bit" elements (they are binarized by to_bit_vectors before indexing).
pooled_field = {"type": "rank_vectors", "element_type": "bit"}
mappings = {"mappings": {"properties": {"pooled_col_pali_vectors": pooled_field}}}

index_exists = es.indices.exists(index=INDEX_NAME)
if index_exists:
    print(f"[INFO] Index '{INDEX_NAME}' already exists.")
else:
    print(f"[INFO] Creating index: {INDEX_NAME}")
    es.indices.create(index=INDEX_NAME, body=mappings)


def index_document(es_client, index, doc_id, document, retries=10, initial_backoff=1):
    """Index a single document, retrying with exponential backoff on failure.

    Args:
        es_client: Client whose ``index(index=..., id=..., document=...)``
            method stores the document.
        index: Name of the target index.
        doc_id: Id to store the document under.
        document: Document body to store.
        retries: Total number of attempts; must be at least 1.
        initial_backoff: Seconds to wait after the first failed attempt.
            The wait doubles after every further failure.

    Returns:
        The response of the first successful ``es_client.index`` call.

    Raises:
        ValueError: If ``retries`` is less than 1. Without this check the
            loop would run zero times and silently index nothing.
        Exception: The error from the final attempt, re-raised once all
            attempts have failed.
    """
    # `time` is only needed on the retry path. Import it here so the helper
    # does not depend on an import that no cell in the notebook performs.
    import time

    if retries < 1:
        raise ValueError(f"retries must be >= 1, got {retries}")

    for attempt in range(1, retries + 1):
        try:
            return es_client.index(index=index, id=doc_id, document=document)
        except Exception as e:
            if attempt == retries:
                print(f"Failed to index {doc_id} after {retries} attempts: {e}")
                raise
            # Exponential backoff: initial_backoff, 2x, 4x, ...
            wait_time = initial_backoff * (2 ** (attempt - 1))
            print(f"[WARN] Failed to index {doc_id} (attempt {attempt}): {e}")
            time.sleep(wait_time)

[INFO] Creating index: searchlabs-colpali-token-pooling


In [None]:
from concurrent.futures import ThreadPoolExecutor
from tqdm.notebook import tqdm
import pickle


def process_file(file_name, vectors):
    """Pool, binarize and index one document's ColPali vectors.

    ``file_name`` is used as the document id. Documents that already exist
    in the index under that id are skipped, so re-running the indexing cell
    only processes documents that are still missing.
    """
    already_indexed = es.exists(index=INDEX_NAME, id=file_name)
    if not already_indexed:
        document = {"pooled_col_pali_vectors": to_bit_vectors(pool_vectors(vectors))}
        index_document(
            es_client=es,
            index=INDEX_NAME,
            doc_id=file_name,
            document=document,
        )


# SECURITY NOTE: pickle.load can execute arbitrary code. Only load a
# col_pali_vectors.pkl file that you generated yourself (part 1 of this series).
with open("col_pali_vectors.pkl", "rb") as vectors_file:
    file_to_multi_vectors = pickle.load(vectors_file)

# Fan the documents out over 10 worker threads. executor.map pairs each file
# name with its vectors and calls process_file(file_name, vectors).
with ThreadPoolExecutor(max_workers=10) as executor:
    pending = executor.map(
        process_file,
        file_to_multi_vectors.keys(),
        file_to_multi_vectors.values(),
    )
    # Draining the result iterator drives the progress bar and re-raises
    # any exception thrown in a worker.
    for _ in tqdm(pending, total=len(file_to_multi_vectors), desc="Indexing documents"):
        pass

print(f"Completed indexing {len(file_to_multi_vectors)} documents")

Indexing documents:   0%|          | 0/500 [00:00<?, ?it/s]

In [None]:
import torch
from PIL import Image
from colpali_engine.models import ColPali, ColPaliProcessor

model_name = "vidore/colpali-v1.3"

# Pick the best available accelerator instead of hard-coding one. Apple
# Silicon (MPS) is tried first, then CUDA, and CPU is the fallback.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model = ColPali.from_pretrained(
    model_name,
    torch_dtype=torch.float32,
    device_map=device,
).eval()

col_pali_processor = ColPaliProcessor.from_pretrained(model_name)


def create_col_pali_query_vectors(query: str) -> list:
    """Embed a text query with the ColPali model.

    The query is processed as a batch of one on the model's device, and the
    forward pass runs without gradient tracking. The embedding of that single
    query is returned as a nested Python list (presumably one vector per
    query token; verify against the model output).
    """
    batch = col_pali_processor.process_queries([query]).to(model.device)
    with torch.no_grad():
        embeddings = model(**batch)
    return embeddings.tolist()[0]

In [None]:
from IPython.display import display, HTML
import os
import json

DOCUMENT_DIR = "searchlabs-colpali"

query = "What do companies use for recruiting?"
query_vector = create_col_pali_query_vectors(query)

# Every document (match_all) is scored by the maxSimDotProduct script
# function, which compares the query vectors with the pooled bit vectors.
script = {
    "source": "maxSimDotProduct(params.query_vector, 'pooled_col_pali_vectors')",
    "params": {"query_vector": query_vector},
}
es_query = {
    "_source": False,
    "query": {"script_score": {"query": {"match_all": {}}, "script": script}},
    "size": 5,
}

results = es.search(index=INDEX_NAME, body=es_query)
image_ids = [hit["_id"] for hit in results["hits"]["hits"]]

# Document ids are the image file names, so each hit maps to a file in DOCUMENT_DIR.
image_tags = "".join(
    f'<img src="{os.path.join(DOCUMENT_DIR, image_id)}" alt="{image_id}" '
    'style="max-width:300px; height:auto; margin:10px;">'
    for image_id in image_ids
)
html = (
    "<div style='display: flex; flex-wrap: wrap; align-items: flex-start;'>"
    + image_tags
    + "</div>"
)

display(HTML(html))

In [None]:
# We kill the kernel forcefully to free up the memory from the ColPali model.
import os
import sys

print("Shutting down the kernel to free memory...")
# os._exit() ends the process immediately without flushing stdio buffers,
# so flush explicitly or the message above may never be shown.
sys.stdout.flush()
os._exit(0)