# Scalable late interaction vectors in Elasticsearch: Average Vectors #

In this notebook, we will look at how to scale search with late interaction models. We take the average vector over our late interaction multi-vectors and use Elasticsearch's vector search capabilities to achieve scalable search over billions of vectors.
   
This notebook builds on part 1, where we downloaded the images, created ColPali vectors, and saved them to disk. Please run that notebook before trying the techniques in this one.

Also check out our accompanying blog post on [Scaling Late Interaction Models](TODO) for more context on this notebook. 

This is the key part of this notebook. We use `to_avg_vector(vectors)` to convert our 2D array of vectors into a single, unit-length vector that holds the "average meaning" of the document.

In [1]:
import numpy as np


def to_bit_vectors(embeddings: list) -> list:
    return [
        np.packbits(np.where(np.array(embedding) > 0, 1, 0))
        .astype(np.int8)
        .tobytes()
        .hex()
        for embedding in embeddings
    ]


def to_avg_vector(vectors):
    vectors_array = np.array(vectors)

    avg_vector = np.mean(vectors_array, axis=0)

    norm = np.linalg.norm(avg_vector)
    if norm > 0:
        normalized_avg_vector = avg_vector / norm
    else:
        normalized_avg_vector = avg_vector

    return normalized_avg_vector.tolist()
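
To make the averaging step concrete, here is a small sketch with made-up 4-dimensional vectors (real ColPali vectors have 128 dimensions). It repeats the same mean-then-normalize logic so that it runs on its own; the toy values and variable names are assumptions for illustration only.

```python
import numpy as np

# Toy multi-vector: 3 token/patch vectors with 4 dimensions each (made-up values).
toy_vectors = [
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 0.0, 0.0],
    [1.0, 1.0, 0.0, 0.0],
]

# The mean over the token axis collapses the (3, 4) array into one (4,) vector.
mean_vector = np.mean(np.array(toy_vectors), axis=0)

# L2-normalize so the result has unit length (skipped for an all-zero vector).
norm = np.linalg.norm(mean_vector)
avg_vector = mean_vector / norm if norm > 0 else mean_vector

print(mean_vector.shape)
print(np.linalg.norm(avg_vector))
```

The normalization matters because the `dot_product` similarity in Elasticsearch expects float vectors of unit length.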

Here we define the mapping for our Elasticsearch index. Note that we use the `dense_vector` field type for our average vector.  
This lets us leverage the highly optimized HNSW index structures in Elasticsearch, with which we can scale our search to billions of documents.

In [2]:
import os
import time
from dotenv import load_dotenv
from elasticsearch import Elasticsearch

load_dotenv("elastic.env")

ELASTIC_API_KEY = os.getenv("ELASTIC_API_KEY")
ELASTIC_HOST = os.getenv("ELASTIC_HOST")
INDEX_NAME = "searchlabs-colpali-average-vector"

es = Elasticsearch(ELASTIC_HOST, api_key=ELASTIC_API_KEY)

mappings = {
    "mappings": {
        "properties": {
            "avg_vector": {
                "type": "dense_vector",
                "dims": 128,
                "index": True,
                "similarity": "dot_product",
            },
            "col_pali_vectors": {"type": "rank_vectors", "element_type": "bit"},
        }
    }
}

if not es.indices.exists(index=INDEX_NAME):
    print(f"[INFO] Creating index: {INDEX_NAME}")
    es.indices.create(index=INDEX_NAME, body=mappings)
else:
    print(f"[INFO] Index '{INDEX_NAME}' already exists.")


def index_document(es_client, index, doc_id, document, retries=10, initial_backoff=1):
    for attempt in range(1, retries + 1):
        try:
            return es_client.index(index=index, id=doc_id, document=document)
        except Exception as e:
            if attempt < retries:
                wait_time = initial_backoff * (2 ** (attempt - 1))
                print(f"[WARN] Failed to index {doc_id} (attempt {attempt}): {e}")
                time.sleep(wait_time)
            else:
                print(f"Failed to index {doc_id} after {retries} attempts: {e}")
                raise

[INFO] Creating index: searchlabs-colpali-average-vector
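
The retry helper above waits exponentially longer after each failed attempt. The sketch below only evaluates the same formula with the default arguments to show the resulting schedule; it does not call Elasticsearch, and the variable names are chosen for illustration.

```python
# Same formula as index_document: wait = initial_backoff * 2 ** (attempt - 1).
retries = 10
initial_backoff = 1

# The final attempt re-raises instead of sleeping, so only retries - 1 waits occur.
wait_times = [initial_backoff * (2 ** (attempt - 1)) for attempt in range(1, retries)]

print(wait_times)       # seconds slept after each failed attempt
print(sum(wait_times))  # worst-case total sleep time in seconds
```

With the defaults, a document that keeps failing is retried for several minutes before the exception is raised, so lower `retries` or `initial_backoff` if you want failures to surface faster.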


Here we loop over all our vectors and convert them into average vectors.  
We still store the complete set of ColPali multi-vectors, compressed to bit vectors, because we can use them for reranking. More on that later.
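
As a rough illustration of what the bit conversion does to a single embedding, the sketch below applies the same steps as `to_bit_vectors` to one randomly generated 128-dimensional vector. The random input is an assumption standing in for a real ColPali embedding.

```python
import numpy as np

# One made-up 128-dimensional float embedding (stand-in for a ColPali vector).
rng = np.random.default_rng(0)
embedding = rng.standard_normal(128)

# Keep only the sign of each dimension: positive -> 1, otherwise -> 0.
bits = np.where(embedding > 0, 1, 0)

# Pack 8 bits into each byte, then hex-encode the bytes for indexing.
packed = np.packbits(bits)
hex_string = packed.astype(np.int8).tobytes().hex()

print(len(bits), len(packed), len(hex_string))
```

Each 128-dimensional vector shrinks to 16 bytes (32 hex characters), which keeps the stored multi-vectors small at the cost of discarding the magnitudes.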

In [3]:
from concurrent.futures import ThreadPoolExecutor
from tqdm.notebook import tqdm
import pickle


def process_file(file_name, vectors):
    if es.exists(index=INDEX_NAME, id=file_name):
        return

    avg_vector = to_avg_vector(vectors)

    bit_vectors = to_bit_vectors(vectors)

    index_document(
        es_client=es,
        index=INDEX_NAME,
        doc_id=file_name,
        document={"avg_vector": avg_vector, "col_pali_vectors": bit_vectors},
    )


with open("col_pali_vectors.pkl", "rb") as f:
    file_to_multi_vectors = pickle.load(f)

with ThreadPoolExecutor(max_workers=10) as executor:
    list(
        tqdm(
            executor.map(
                lambda item: process_file(*item), file_to_multi_vectors.items()
            ),
            total=len(file_to_multi_vectors),
            desc="Indexing documents",
        )
    )

print(f"Completed indexing {len(file_to_multi_vectors)} documents")

Completed indexing 500 documents


In [4]:
import torch
from PIL import Image
from colpali_engine.models import ColPali, ColPaliProcessor

model_name = "vidore/colpali-v1.3"
model = ColPali.from_pretrained(
    model_name,
    torch_dtype=torch.float32,
    device_map="mps",  # "mps" for Apple Silicon, "cuda" if available, "cpu" otherwise
).eval()

col_pali_processor = ColPaliProcessor.from_pretrained(model_name)


def create_col_pali_query_vectors(query: str) -> list:
    queries = col_pali_processor.process_queries([query]).to(model.device)
    with torch.no_grad():
        return model(**queries).tolist()[0]

Again, we create our query vector by running the ColPali model and calculating the average vector. We can now use Elasticsearch's `knn` query to compare this single vector to the average vectors in our index.  
Notice that the document titled *Social Media for Recruitment*, which we found in our previous examples, is now in 5th position. The next cell shows how we can handle this.

In [5]:
from IPython.display import display, HTML
import os
import json

DOCUMENT_DIR = "searchlabs-colpali"

query = "What do companies use for recruiting?"
query_vector = to_avg_vector(create_col_pali_query_vectors(query))
es_query = {
    "_source": False,
    "knn": {
        "field": "avg_vector",
        "query_vector": query_vector,
        "k": 10,
        "num_candidates": 100,
    },
    "size": 5,
}

results = es.search(index=INDEX_NAME, body=es_query)
image_ids = [hit["_id"] for hit in results["hits"]["hits"]]

html = "<div style='display: flex; flex-wrap: wrap; align-items: flex-start;'>"
for image_id in image_ids:
    image_path = os.path.join(DOCUMENT_DIR, image_id)
    html += f'<img src="{image_path}" alt="{image_id}" style="max-width:300px; height:auto; margin:10px;">'
html += "</div>"

display(HTML(html))

In the cell above, we saw that the document *Social Media for Recruitment* is no longer ranked as the most relevant result.

To handle this, we run our KNN vector search as a **first stage retrieval** step. This is the `knn` retriever in the example below.  
We then run a **second stage** rescoring step on this smaller set of documents, using the higher fidelity ColPali multi-vectors. This is the `rescore` section in the example below.
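
To show why the two stages can rank documents differently, here is a minimal NumPy sketch of both scores. It assumes the usual late interaction (MaxSim) definition: for each query vector, take the best dot product against any document vector, then sum. It is not Elasticsearch's implementation; the helper names `max_sim` and `unit_mean` are hypothetical, the inputs are random, and it uses float document vectors rather than the bit vectors stored in the index.

```python
import numpy as np

rng = np.random.default_rng(42)

# Made-up shapes: 5 query token vectors and 20 document patch vectors, 128 dims.
query_vectors = rng.standard_normal((5, 128))
doc_vectors = rng.standard_normal((20, 128))


def max_sim(query: np.ndarray, doc: np.ndarray) -> float:
    """Sum over query vectors of the best dot product against any doc vector."""
    similarities = query @ doc.T  # shape: (num_query_vectors, num_doc_vectors)
    return float(similarities.max(axis=1).sum())


def unit_mean(vectors: np.ndarray) -> np.ndarray:
    """Mean over the token axis, scaled to unit length."""
    mean = vectors.mean(axis=0)
    return mean / np.linalg.norm(mean)


# First stage: a single dot product between two pooled vectors (cheap, coarse).
first_stage_score = float(unit_mean(query_vectors) @ unit_mean(doc_vectors))

# Second stage: token-level late interaction score (costlier, finer-grained).
second_stage_score = max_sim(query_vectors, doc_vectors)

print(first_stage_score, second_stage_score)
```

The first stage compares one vector per document, so it scales with an HNSW index. The second stage compares every query vector with every document vector, which is why it is applied only to the small candidate window.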

In [6]:
query = "What do companies use for recruiting?"
col_pali_vector = create_col_pali_query_vectors(query)
avg_vector = to_avg_vector(col_pali_vector)
es_query = {
    "_source": False,
    "retriever": {
        "rescorer": {
            "retriever": {
                "knn": {
                    "field": "avg_vector",
                    "query_vector": avg_vector,
                    "k": 10,
                    "num_candidates": 100,
                }
            },
            "rescore": {
                "window_size": 10,
                "query": {
                    "rescore_query": {
                        "script_score": {
                            "query": {"match_all": {}},
                            "script": {
                                "source": "maxSimDotProduct(params.query_vector, 'col_pali_vectors')",
                                "params": {"query_vector": col_pali_vector},
                            },
                        }
                    }
                },
            },
        }
    },
    "size": 5,
}

results = es.search(index=INDEX_NAME, body=es_query)
image_ids = [hit["_id"] for hit in results["hits"]["hits"]]

html = "<div style='display: flex; flex-wrap: wrap; align-items: flex-start;'>"
for image_id in image_ids:
    image_path = os.path.join(DOCUMENT_DIR, image_id)
    html += f'<img src="{image_path}" alt="{image_id}" style="max-width:300px; height:auto; margin:10px;">'
html += "</div>"

display(HTML(html))

In [None]:
# We kill the kernel forcefully to free up the memory from the ColPali model.
print("Shutting down the kernel to free memory...")
import os

os._exit(0)