<a href="https://colab.research.google.com/github/llermaly/elasticsearch-labs/blob/supporting-blog-content%2Fhow-to-use-jina-v2-embeddings/supporting-blog-content/how-to-use-jina-v2-embeddings.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


# Introduction

In this notebook, we extend the [Jina late chunking implementation example](https://github.com/jina-ai/late-chunking/blob/main/examples.ipynb) to index the chunks and their embeddings into Elasticsearch and to run queries against them.

The Jina part of the implementation is kept unchanged.

This is supporting material for the following blog post:
https://www.elastic.co/search-labs/blog/how-to-use-jina-v2-embeddings


# [Late Chunking](https://jina.ai/news/late-chunking-in-long-context-embedding-models)

This notebook explains how late chunking can be implemented. First, install the requirements:


In [None]:
%pip install transformers==4.43.4

Then we load the model we want to use for the embeddings. We choose `jinaai/jina-embeddings-v2-base-en`, but any other model that supports mean pooling works. However, models with a large maximum context length are preferred, because late chunking encodes the whole text in a single forward pass before splitting it into chunks.
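Mean pooling averages the per-token output vectors into one vector. Because the same averaging can be applied to any contiguous sub-range of tokens, it is what makes chunked pooling possible. A minimal NumPy sketch, assuming a `(batch, tokens, dim)` array and a 0/1 attention mask; `masked_mean_pool` is a hypothetical helper, not part of the Jina or `transformers` API:

```python
import numpy as np


def masked_mean_pool(token_embeddings: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    """Hypothetical helper: average token embeddings per sequence, ignoring padding.

    token_embeddings: shape (batch, tokens, dim)
    attention_mask:   shape (batch, tokens), 1 for real tokens and 0 for padding
    """
    # Add a trailing axis so the mask broadcasts over the embedding dimension
    mask = attention_mask[..., np.newaxis].astype(token_embeddings.dtype)
    summed = (token_embeddings * mask).sum(axis=1)
    # Guard against division by zero for an all-padding sequence
    counts = np.clip(mask.sum(axis=1), 1e-9, None)
    return summed / counts


# One sequence of three tokens; the last one is padding and is ignored
toy_embeddings = np.array([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])
toy_mask = np.array([[1, 1, 0]])
print(masked_mean_pool(toy_embeddings, toy_mask))  # [[2. 3.]]
```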


In [None]:
from transformers import AutoModel
from transformers import AutoTokenizer

# load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    "jinaai/jina-embeddings-v2-base-en", trust_remote_code=True
)
model = AutoModel.from_pretrained(
    "jinaai/jina-embeddings-v2-base-en", trust_remote_code=True
)

Now we define the text we want to encode and split it into chunks. The `chunk_by_sentences` function also returns the span annotations, which specify the start and end token positions of each chunk; these are needed for chunked pooling.
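To make span annotations concrete, here is a toy sketch on a hand-written token list. The tokens and the helper `toy_span_annotations` are hypothetical; real tokens come from the Jina tokenizer. It follows the same convention as `chunk_by_sentences`: spans are half-open `[start, end)` token ranges, the first span starts at index 1 to skip `[CLS]`, and each boundary is the position of a sentence-ending `.` token. Unlike the real function, it skips the check that avoids splitting on decimal points such as "3.85".

```python
# Hypothetical toy token sequence; a real one comes from the Jina tokenizer
tokens = ["[CLS]", "Berlin", "is", "big", ".", "It", "is", "old", ".", "[SEP]"]


def toy_span_annotations(tokens: list[str]) -> list[tuple[int, int]]:
    """Hypothetical helper: return half-open (start, end) token ranges, one per sentence."""
    # Positions of the sentence-ending "." tokens act as chunk boundaries
    boundaries = [i for i, token in enumerate(tokens) if token == "."]
    # The first chunk starts right after [CLS]; each later chunk starts at the previous boundary
    starts = [1] + boundaries[:-1]
    return list(zip(starts, boundaries))


spans = toy_span_annotations(tokens)
print(spans)  # [(1, 4), (4, 8)]
for start, end in spans:
    print(tokens[start:end])
```

The toy reproduces a quirk of this convention: the `.` that ends one sentence becomes the first token of the next span, and the final `.` belongs to no span.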


In [None]:
def chunk_by_sentences(input_text: str, tokenizer: callable):
    """
    Split the input text into sentences using the tokenizer
    :param input_text: The text snippet to split into sentences
    :param tokenizer: The tokenizer to use
    :return: A tuple containing the list of text chunks and their corresponding token spans
    """
    inputs = tokenizer(input_text, return_tensors="pt", return_offsets_mapping=True)
    punctuation_mark_id = tokenizer.convert_tokens_to_ids(".")
    sep_id = tokenizer.convert_tokens_to_ids("[SEP]")
    token_offsets = inputs["offset_mapping"][0]
    token_ids = inputs["input_ids"][0]
    chunk_positions = [
        (i, int(start + 1))
        for i, (token_id, (start, end)) in enumerate(zip(token_ids, token_offsets))
        if token_id == punctuation_mark_id
        and (
            token_offsets[i + 1][0] - token_offsets[i][1] > 0
            or token_ids[i + 1] == sep_id
        )
    ]
    chunks = [
        input_text[x[1] : y[1]]
        for x, y in zip([(1, 0)] + chunk_positions[:-1], chunk_positions)
    ]
    span_annotations = [
        (x[0], y[0]) for (x, y) in zip([(1, 0)] + chunk_positions[:-1], chunk_positions)
    ]
    return chunks, span_annotations

Now let's try to segment a toy example.


In [None]:
input_text = "Berlin is the capital and largest city of Germany, both by area and by population. Its more than 3.85 million inhabitants make it the European Union's most populous city, as measured by population within city limits. The city is also one of the states of Germany, and is the third smallest state in the country in terms of area."

# determine chunks
chunks, span_annotations = chunk_by_sentences(input_text, tokenizer)
print('Chunks:\n- "' + '"\n- "'.join(chunks) + '"')

Now we encode the chunks with both the traditional method and the context-sensitive late chunking method:
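The two methods differ only in the order of encoding and pooling. The sketch below uses a hypothetical `toy_encoder` as a stand-in for the transformer: each output vector adds the mean of the whole input, loosely imitating how self-attention lets every token see its context. It is not the Jina model, only an illustration of why late-chunked embeddings carry information from outside their own chunk.

```python
import numpy as np


def toy_encoder(token_vectors: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for a transformer: each output token adds the mean
    of all input tokens, so outputs depend on the surrounding context."""
    return token_vectors + token_vectors.mean(axis=0, keepdims=True)


def span_mean(token_embeddings: np.ndarray, start: int, end: int) -> np.ndarray:
    # Mean pooling over the half-open token range [start, end)
    return token_embeddings[start:end].mean(axis=0)


# Six toy tokens with 2-dim vectors, split into two chunks of three tokens each
token_vectors = np.array(
    [[1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [0.0, 1.0]]
)
toy_spans = [(0, 3), (3, 6)]

# Traditional chunking: encode each chunk on its own, then pool it
traditional = [span_mean(toy_encoder(token_vectors[s:e]), 0, e - s) for s, e in toy_spans]

# Late chunking: encode the full sequence once, then pool each span
contextual = toy_encoder(token_vectors)
late = [span_mean(contextual, s, e) for s, e in toy_spans]

print(traditional)  # values [2, 0] and [0, 2]: no cross-chunk information
print(late)         # values [1.5, 0.5] and [0.5, 1.5]: each chunk reflects the other
```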


In [None]:
def late_chunking(
    model_output: "BatchEncoding", span_annotation: list, max_length=None
):
    token_embeddings = model_output[0]
    outputs = []
    for embeddings, annotations in zip(token_embeddings, span_annotation):
        if (
            max_length is not None
        ):  # remove annotations which go beyond the max-length of the model
            annotations = [
                (start, min(end, max_length - 1))
                for (start, end) in annotations
                if start < (max_length - 1)
            ]
        pooled_embeddings = [
            embeddings[start:end].sum(dim=0) / (end - start)
            for start, end in annotations
            if (end - start) >= 1
        ]
        pooled_embeddings = [
            embedding.detach().cpu().numpy() for embedding in pooled_embeddings
        ]
        outputs.append(pooled_embeddings)

    return outputs
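The `max_length` branch in `late_chunking` drops spans that start at or beyond position `max_length - 1` and clips the remaining end positions to `max_length - 1`, presumably so that no span reaches past the tokens the model actually embedded. The same rule, isolated as a hypothetical helper `clip_spans` for a quick check:

```python
def clip_spans(spans: list[tuple[int, int]], max_length: int) -> list[tuple[int, int]]:
    """Hypothetical helper: apply the same clipping rule as late_chunking's max_length branch."""
    limit = max_length - 1
    # Keep only spans that start before the limit, and cut their end at the limit
    return [(start, min(end, limit)) for start, end in spans if start < limit]


print(clip_spans([(1, 4), (4, 8), (8, 12)], max_length=10))  # [(1, 4), (4, 8), (8, 9)]
print(clip_spans([(1, 4), (4, 8), (8, 12)], max_length=8))   # [(1, 4), (4, 7)]
```

The notebook's own call, `late_chunking(model_output, [span_annotations])`, passes no `max_length`, so no clipping happens for the short Berlin text.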

In [None]:
# chunk before
embeddings_traditional_chunking = model.encode(chunks)

# chunk afterwards (context-sensitive chunked pooling)
inputs = tokenizer(input_text, return_tensors="pt")
model_output = model(**inputs)
embeddings = late_chunking(model_output, [span_annotations])[0]

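Finally, we compare the similarity of the word "Berlin" with each chunk. The similarity should be higher for the context-sensitive chunked pooling method, especially for the second and third chunks, which refer to Berlin only indirectly ("Its" and "The city"):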


In [None]:
import numpy as np


def cos_sim(x, y):
    return np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y))


berlin_embedding = model.encode("Berlin")

for chunk, new_embedding, trad_embeddings in zip(
    chunks, embeddings, embeddings_traditional_chunking
):
    print(
        f'similarity_new("Berlin", "{chunk}"):',
        cos_sim(berlin_embedding, new_embedding),
    )
    print(
        f'similarity_trad("Berlin", "{chunk}"):',
        cos_sim(berlin_embedding, trad_embeddings),
    )
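The loop above compares one chunk at a time. The same comparison can be vectorized by normalizing the vectors once and taking a matrix-vector product, which yields a ranking similar in spirit to the cosine kNN search run against Elasticsearch below. A minimal sketch with a hypothetical helper `rank_by_cosine`:

```python
import numpy as np


def rank_by_cosine(query: np.ndarray, chunk_embeddings: np.ndarray) -> list[tuple[int, float]]:
    """Hypothetical helper: return (chunk_index, cosine_similarity) pairs, most similar first."""
    # Normalize the query and every row of the chunk matrix to unit length
    q = query / np.linalg.norm(query)
    m = chunk_embeddings / np.linalg.norm(chunk_embeddings, axis=1, keepdims=True)
    scores = m @ q
    # Sort indices by descending similarity
    order = np.argsort(-scores)
    return [(int(i), float(scores[i])) for i in order]


ranking = rank_by_cosine(np.array([1.0, 0.0]), np.array([[0.0, 1.0], [1.0, 1.0], [2.0, 0.0]]))
print(ranking)  # chunk 2 first (similarity 1.0), then chunk 1 (~0.707), then chunk 0 (0.0)
```

With the notebook's variables this would be `rank_by_cosine(berlin_embedding, np.vstack(embeddings))`, assuming `model.encode` returns a 1-D array for a single string. For a `cosine` `dense_vector` field, Elasticsearch reports `_score = (1 + cosine) / 2`, so for the same query vector its scores differ from these raw values while the order stays the same.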

# Indexing to Elasticsearch

Now, let's index the new late chunking embeddings into Elasticsearch and run queries against them.


In [None]:
%pip install elasticsearch

In [None]:
from elasticsearch import Elasticsearch, helpers, exceptions
from getpass import getpass

In [None]:
# https://www.elastic.co/search-labs/tutorials/install-elasticsearch/elastic-cloud#finding-your-cloud-id
ELASTIC_CLOUD_ID = getpass("Elastic Cloud ID: ")

# https://www.elastic.co/search-labs/tutorials/install-elasticsearch/elastic-cloud#creating-an-api-key
ELASTIC_API_KEY = getpass("Elastic API Key: ")

# Create the client instance
client = Elasticsearch(
    # For local development
    # hosts=["http://localhost:9200"]
    cloud_id=ELASTIC_CLOUD_ID,
    api_key=ELASTIC_API_KEY,
)

## Creating the inference endpoint


In [None]:
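API_KEY = getpass("Hugging Face API key: ")

# Delete the endpoint if it already exists; on the first run it does not, so ignore NotFoundError
try:
    client.inference.delete(inference_id="jina-embeddings-v2-base-en")
except exceptions.NotFoundError:
    pass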
client.inference.put(
    task_type="text_embedding",
    inference_id="jina-embeddings-v2-base-en",
    body={
        "service": "hugging_face",
        "service_settings": {
            "api_key": API_KEY,
            "url": "https://api-inference.huggingface.co/models/jinaai/jina-embeddings-v2-base-en",
        },
    },
)

## Creating the index


In [None]:
client.indices.delete(index="jina-late-chunking", ignore_unavailable=True)
client.indices.create(
    index="jina-late-chunking",
    mappings={
        "properties": {
            "content_embedding": {
                "type": "dense_vector",
                "dims": 768,
                "similarity": "cosine",
                "element_type": "float",
            },
            "content": {"type": "text"},
        }
    },
)
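The mapping fixes `dims` at 768 and uses `cosine` similarity. Elasticsearch rejects vectors whose length does not match `dims` and, for `cosine`, vectors with zero magnitude. A small pre-flight check such as the hypothetical `to_indexable_vector` below can catch these problems before the bulk request; it also converts NumPy arrays to plain Python floats for JSON serialization.

```python
import numpy as np

EXPECTED_DIMS = 768  # must match "dims" in the index mapping


def to_indexable_vector(embedding, dims: int = EXPECTED_DIMS) -> list[float]:
    """Hypothetical helper: validate an embedding for a cosine dense_vector field."""
    vector = np.asarray(embedding, dtype=np.float32)
    if vector.shape != (dims,):
        raise ValueError(f"expected shape ({dims},), got {vector.shape}")
    if not np.all(np.isfinite(vector)):
        raise ValueError("embedding contains NaN or infinite values")
    # Cosine similarity is undefined for zero-magnitude vectors
    if not np.any(vector):
        raise ValueError("zero-magnitude vectors cannot be used with cosine similarity")
    return [float(x) for x in vector]
```

Applied to the notebook's data, this would be `[to_indexable_vector(e) for e in embeddings]`.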

## Loading documents


In [None]:
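# Prepare the documents to be indexed: one document per late chunking embedding
documents = []
for chunk, new_embedding in zip(chunks, embeddings):
    documents.append(
        {
            "_index": "jina-late-chunking",
            "_source": {
                # Convert the NumPy array to a plain list of floats for JSON serialization
                "content_embedding": new_embedding.tolist(),
                "content": chunk,
            },
        }
    )
# Use helpers.bulk to index; refresh=True makes the documents searchable right away
helpers.bulk(client, documents, refresh=True)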
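`helpers.bulk` accepts any iterable of actions, not only a list, so for larger corpora the actions can be produced lazily by a generator instead of being collected in memory first. A sketch with a hypothetical `generate_actions` helper:

```python
from typing import Iterable, Iterator


def generate_actions(index_name: str, chunks: Iterable[str], embeddings: Iterable) -> Iterator[dict]:
    """Hypothetical helper: yield one bulk index action per (chunk, embedding) pair."""
    for chunk, embedding in zip(chunks, embeddings):
        yield {
            "_index": index_name,
            "_source": {
                "content": chunk,
                # Plain floats keep the action JSON-serializable regardless of the input type
                "content_embedding": [float(x) for x in embedding],
            },
        }
```

Usage with the notebook's variables: `helpers.bulk(client, generate_actions("jina-late-chunking", chunks, embeddings), refresh=True)`.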

## Running semantic search


In [None]:
response = client.search(
    index="jina-late-chunking",
    knn={
        "field": "content_embedding",
        "query_vector_builder": {
            "text_embedding": {
                "model_id": "jina-embeddings-v2-base-en",
                "model_text": "berlin",
            }
        },
        "k": 10,
        "num_candidates": 100,
    },
)

print("Late chunking results")
for hit in response["hits"]["hits"]:
    doc_id = hit["_id"]
    score = hit["_score"]
    content = hit["_source"]["content"]
    print(f"Score: {score}\nContent: {content}\n")
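The query above lets Elasticsearch embed `"berlin"` through the Hugging Face inference endpoint. Because the documents were embedded locally, another option is to embed the query locally with the same model and pass the vector directly through the `query_vector` option of the `knn` search, which removes the dependency on the remote endpoint at query time. A sketch with a hypothetical helper `build_knn_clause`:

```python
def build_knn_clause(
    query_vector, field: str = "content_embedding", k: int = 10, num_candidates: int = 100
) -> dict:
    """Hypothetical helper: build a knn clause that searches with a precomputed query vector."""
    return {
        "field": field,
        # Plain floats so the clause is JSON-serializable even for NumPy input
        "query_vector": [float(x) for x in query_vector],
        "k": k,
        "num_candidates": num_candidates,
    }
```

Usage, assuming `model.encode` returns a 1-D array for a single string: `client.search(index="jina-late-chunking", knn=build_knn_clause(model.encode("berlin")))`.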