{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Augment retrieval results by reranking using Sentence Transformers\n", "\n", "Retrieval gives a fast estimate of the documents most relevant to a query. That works well for a first pass over millions of documents, but we can improve relevance by reranking the retrieved documents. We will build a reranker that can be used in a RAG pipeline together with the retrieval microservice from the [Index and retrieve documents for vector search using Sentence Transformers and DuckDB](./retrieve.ipynb) notebook. Finally, we will deploy a microservice that reranks documents for a given query.\n", "\n", "## Dependencies and imports\n", "\n", "Let's install the necessary dependencies." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!pip install gradio gradio-client pandas sentence-transformers -q" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's import the necessary libraries." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import gradio as gr\n", "import pandas as pd\n", "\n", "from gradio_client import Client\n", "from sentence_transformers import CrossEncoder" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Hugging Face as a vector search backend\n", "\n", "As a brief recap of the previous notebook: we use Hugging Face as a vector search backend and call it as a REST API through the Gradio Python client." 
] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loaded as API: https://ai-blueprint-rag-retrieve.hf.space/ ✔\n" ] }, { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>url</th>\n", " <th>text</th>\n", " <th>distance</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>https://www.bbc.com/news/technology-51064369</td>\n", " <td>The last decade was a big one for artificial i...</td>\n", " <td>0.281200</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>http://www.bbc.co.uk/news/technology-25000756</td>\n", " <td>Singularity: The robots are coming to steal ou...</td>\n", " <td>0.365842</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>http://www.bbc.com/news/technology-25000756</td>\n", " <td>Singularity: The robots are coming to steal ou...</td>\n", " <td>0.365842</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>https://www.bbc.co.uk/news/technology-37494863</td>\n", " <td>Google, Facebook, Amazon join forces on future...</td>\n", " <td>0.380820</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>https://www.bbc.co.uk/news/technology-37494863</td>\n", " <td>Google, Facebook, Amazon join forces on future...</td>\n", " <td>0.380820</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " url \\\n", "0 https://www.bbc.com/news/technology-51064369 \n", "1 http://www.bbc.co.uk/news/technology-25000756 \n", "2 http://www.bbc.com/news/technology-25000756 \n", "3 https://www.bbc.co.uk/news/technology-37494863 \n", "4 https://www.bbc.co.uk/news/technology-37494863 \n", "\n", " text distance \n", "0 The last decade was a big one for artificial i... 0.281200 \n", "1 Singularity: The robots are coming to steal ou... 0.365842 \n", "2 Singularity: The robots are coming to steal ou... 0.365842 \n", "3 Google, Facebook, Amazon join forces on future... 0.380820 \n", "4 Google, Facebook, Amazon join forces on future... 0.380820 " ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "gradio_client = Client(\"https://ai-blueprint-rag-retrieve.hf.space/\")\n", "\n", "\n", "def similarity_search(query: str, k: int = 5) -> pd.DataFrame:\n", "    # Query the retrieval Space and turn its headers/data payload into a DataFrame\n", "    results = gradio_client.predict(api_name=\"/similarity_search\", query=query, k=k)\n", "    return pd.DataFrame(data=results[\"data\"], columns=results[\"headers\"])\n", "\n", "\n", "similarity_search(\"What is the future of AI?\", k=5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Reranking retrieved documents\n", "\n", "Whenever we retrieve documents from the vector search backend, we can improve the quality of the documents we pass to the LLM by reranking them by relevance to the query. Rerankers are typically cross-encoders: instead of comparing a query embedding with document embeddings computed independently, they read the query and a document together and output a relevance score, which is slower but usually more accurate. We will use the [sentence-transformers library](https://huggingface.co/sentence-transformers), whose `CrossEncoder` class implements this approach. The [MTEB leaderboard](https://huggingface.co/spaces/mteb/leaderboard) can help you compare candidate models.\n", "\n", "We will first retrieve 50 documents and then use [sentence-transformers/all-MiniLM-L12-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2) to rerank them and return the top 5." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at sentence-transformers/all-MiniLM-L12-v2 and are newly initialized: ['classifier.bias', 'classifier.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] }, { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>url</th>\n", " <th>text</th>\n", " <th>distance</th>\n", " <th>rank</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>8</th>\n", " <td>http://www.bbc.com/news/world-us-canada-39425862</td>\n", " <td>Vector Institute is just the latest in Canada'...</td>\n", " <td>0.424994</td>\n", " <td>0.508780</td>\n", " </tr>\n", " <tr>\n", " <th>12</th>\n", " <td>http://www.bbc.com/news/business-34266425</td>\n", " <td>Google’s Demis Hassabis – misuse of artificial...</td>\n", " <td>0.442649</td>\n", " <td>0.508423</td>\n", " </tr>\n", " <tr>\n", " <th>19</th>\n", " <td>http://news.bbc.co.uk/2/hi/uk_news/england/wea...</td>\n", " <td>A group of scientists in the north-east of Eng...</td>\n", " <td>0.484410</td>\n", " <td>0.508336</td>\n", " </tr>\n", " <tr>\n", " <th>21</th>\n", " <td>https://www.bbc.com/news/technology-47668476</td>\n", " <td>How Pope Francis could shape the future of rob...</td>\n", " <td>0.494108</td>\n", " <td>0.508200</td>\n", " </tr>\n", " <tr>\n", " <th>42</th>\n", " <td>http://news.bbc.co.uk/2/hi/technology/6583893.stm</td>\n", " <td>Scientists have expressed concern about the us...</td>\n", " <td>0.530431</td>\n", " <td>0.507771</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " url \\\n", "8 http://www.bbc.com/news/world-us-canada-39425862 \n", "12 http://www.bbc.com/news/business-34266425 \n", "19 http://news.bbc.co.uk/2/hi/uk_news/england/wea... \n", "21 https://www.bbc.com/news/technology-47668476 \n", "42 http://news.bbc.co.uk/2/hi/technology/6583893.stm \n", "\n", " text distance rank \n", "8 Vector Institute is just the latest in Canada'... 0.424994 0.508780 \n", "12 Google’s Demis Hassabis – misuse of artificial... 0.442649 0.508423 \n", "19 A group of scientists in the north-east of Eng... 0.484410 0.508336 \n", "21 How Pope Francis could shape the future of rob... 0.494108 0.508200 \n", "42 Scientists have expressed concern about the us... 0.530431 0.507771 " ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "reranker = CrossEncoder(\"sentence-transformers/all-MiniLM-L12-v2\")\n", "\n", "\n", "def rerank(query: str, documents: pd.DataFrame) -> pd.DataFrame:\n", "    documents = documents.copy()\n", "    # Drop duplicate passages, e.g. the same article served under both .com and .co.uk URLs\n", "    documents = documents.drop_duplicates(\"text\")\n", "    # Score each (query, passage) pair jointly with the cross-encoder\n", "    documents[\"rank\"] = reranker.predict([[query, hit] for hit in documents[\"text\"]])\n", "    documents = documents.sort_values(by=\"rank\", ascending=False)\n", "    return documents\n", "\n", "\n", "query = \"What is the future of AI?\"\n", "documents = similarity_search(query, k=50)\n", "reranked_documents = rerank(query=query, documents=documents)\n", "reranked_documents[:5]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can see that the order of the returned documents has changed. However, this does not yet show that the reranking improves relevance. The warning above shows that `sentence-transformers/all-MiniLM-L12-v2` was loaded with a newly initialized (untrained) classification head, because it is a bi-encoder embedding model rather than a trained cross-encoder. As a result, all scores lie close to 0.5, and the new order reflects an untrained scoring head rather than learned relevance. For meaningful reranking, use a checkpoint that was trained as a cross-encoder, such as [cross-encoder/ms-marco-MiniLM-L-6-v2](https://huggingface.co/cross-encoder/ms-marco-MiniLM-L-6-v2)." 
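] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A minimal sketch of reranking with a checkpoint trained as a cross-encoder, assuming network access to the Hugging Face Hub. The model `cross-encoder/ms-marco-MiniLM-L-6-v2` is our illustrative choice and not part of the original pipeline, and `rerank_with` is a hypothetical helper that generalizes `rerank` by taking the model and the number of results as arguments. The cell reuses `similarity_search` from earlier in this notebook. Unlike `sentence-transformers/all-MiniLM-L12-v2`, which gets a newly initialized classification head when loaded as a `CrossEncoder` (see the warning above), this checkpoint was trained on MS MARCO passage ranking, so its scores reflect learned query-passage relevance." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Sketch: rerank with a checkpoint trained as a cross-encoder.\n", "# Assumptions: Hub access, and `similarity_search` defined earlier in this notebook.\n", "# `rerank_with` and `trained_reranker` are hypothetical names used for illustration.\n", "import pandas as pd\n", "from sentence_transformers import CrossEncoder\n", "\n", "trained_reranker = CrossEncoder(\"cross-encoder/ms-marco-MiniLM-L-6-v2\")\n", "\n", "\n", "def rerank_with(model: CrossEncoder, query: str, documents: pd.DataFrame, top_k: int = 5) -> pd.DataFrame:\n", "    # Drop duplicate passages before scoring so each passage is scored once\n", "    documents = documents.drop_duplicates(\"text\").copy()\n", "    # Score each (query, passage) pair jointly; higher scores sort first\n", "    documents[\"rank\"] = model.predict([[query, text] for text in documents[\"text\"]])\n", "    return documents.sort_values(by=\"rank\", ascending=False).head(top_k)\n", "\n", "\n", "query = \"What is the future of AI?\"\n", "rerank_with(trained_reranker, query, similarity_search(query, k=50))"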
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Creating a web app and microservice for reranking\n", "\n", "We will use [Gradio](https://github.com/gradio-app/gradio) to create a demo web interface for our reranker. We can develop it locally and then deploy it to Hugging Face Spaces. Finally, we can use the Gradio client as an SDK to interact directly with our reranking microservice.\n", "\n", "### Creating the web app" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "* Running on local URL: http://127.0.0.1:7862\n", "\n", "To create a public link, set `share=True` in `launch()`.\n" ] }, { "data": { "text/html": [ "<div><iframe src=\"http://127.0.0.1:7862/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>" ], "text/plain": [ "<IPython.core.display.HTML object>" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "with gr.Blocks() as demo:\n", "    gr.Markdown(\"\"\"# RAG - Augment \n", " \n", " Applies reranking to the retrieved documents using [sentence-transformers/all-MiniLM-L12-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2)\n", " \n", " Part of [AI blueprint](https://github.com/davidberenstein1957/ai-blueprint) - a blueprint for AI development, focusing on applied examples of RAG, information extraction, analysis and fine-tuning in the age of LLMs and agents.\"\"\")\n", "\n", "    query_input = gr.Textbox(\n", "        label=\"Query\", placeholder=\"Enter your question here...\", lines=3\n", "    )\n", "    documents_input = gr.Dataframe(\n", "        label=\"Documents\", headers=[\"text\"], wrap=True, interactive=True\n", "    )\n", "\n", "    submit_btn = gr.Button(\"Submit\")\n", "    documents_output = gr.Dataframe(\n", "        label=\"Documents\", headers=[\"text\", \"rank\"], wrap=True\n", "    )\n", "\n", "    # Wire the button to `rerank` (defined above) and expose it as the /rerank API endpoint\n", "    submit_btn.click(\n", "        fn=rerank,\n", "        inputs=[query_input, documents_input],\n", "        outputs=[documents_output],\n", "        api_name=\"rerank\",\n", "    )\n", "\n", "demo.launch(share=False)  # set share=True to create a temporary public link" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "<iframe\n", "\tsrc=\"https://ai-blueprint-rag-augment.hf.space\"\n", "\tframeborder=\"0\"\n", "\twidth=\"850\"\n", "\theight=\"450\"\n", "></iframe>" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Deploying the web app to Hugging Face\n", "\n", "We can now [deploy our Gradio application to Hugging Face Spaces](https://huggingface.co/new-space?sdk=gradio&name=rag-augment).\n", "\n", "- Click on the \"Create Space\" button.\n", "- Copy the Gradio interface code and paste it into an `app.py` file. Don't forget to copy the `rerank` function, along with the code that loads the `CrossEncoder` model.\n", "- Create a `requirements.txt` file with `gradio-client` and `sentence-transformers`.\n", "- Set a Hugging Face API token as the `HF_TOKEN` secret in the Space settings, if you are using the Inference API.\n", "\n", "Wait a couple of minutes for the application to deploy and, voilà, we have [a public reranking interface](https://huggingface.co/spaces/ai-blueprint/rag-augment)!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Using the web app as a microservice\n", "\n", "We can now use the [Gradio client as an SDK](https://www.gradio.app/guides/getting-started-with-the-python-client) to interact directly with our reranking microservice. Each Gradio app has API documentation that describes the available endpoints and their parameters; you can open it from the link at the bottom of the app's Space page." 
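] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As a sketch, assuming the Space is running: the Gradio client can also print this API documentation from Python with `Client.view_api()`, which lists the named endpoints and their parameters. The variable name `augment_client` is our own and only used for illustration." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from gradio_client import Client\n", "\n", "# Assumes the reranking Space is up; view_api() prints its named endpoints and their parameters\n", "augment_client = Client(\"https://ai-blueprint-rag-augment.hf.space/\")\n", "augment_client.view_api()"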
] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loaded as API: https://ai-blueprint-rag-augment.hf.space/ ✔\n" ] }, { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>url</th>\n", " <th>text</th>\n", " <th>distance</th>\n", " <th>rank</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>https://www.bbc.co.uk/news/business-48139212</td>\n", " <td>Artificial intelligence (AI) is one of the mos...</td>\n", " <td>0.407243</td>\n", " <td>0.511831</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>http://www.bbc.com/news/technology-39657505</td>\n", " <td>Ted 2017: The robot that wants to go to univer...</td>\n", " <td>0.424357</td>\n", " <td>0.509631</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>http://www.bbc.com/news/world-us-canada-39425862</td>\n", " <td>Vector Institute is just the latest in Canada'...</td>\n", " <td>0.424994</td>\n", " <td>0.508584</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>https://www.bbc.co.uk/news/technology-37494863</td>\n", " <td>Google, Facebook, Amazon join forces on future...</td>\n", " <td>0.380820</td>\n", " <td>0.507728</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>https://www.bbc.com/news/technology-51064369</td>\n", " <td>The last decade was a big one for artificial i...</td>\n", " <td>0.281200</td>\n", " <td>0.506788</td>\n", " </tr>\n", " <tr>\n", " <th>5</th>\n", " <td>http://www.bbc.co.uk/news/technology-25000756</td>\n", " <td>Singularity: The robots are coming to steal ou...</td>\n", " <td>0.365842</td>\n", " <td>0.506259</td>\n", " </tr>\n", " <tr>\n", " <th>6</th>\n", " <td>https://www.bbc.com/news/technology-52415775</td>\n", " <td>UK spies will need to use artificial intellige...</td>\n", " <td>0.414651</td>\n", " <td>0.505149</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " url \\\n", "0 https://www.bbc.co.uk/news/business-48139212 \n", "1 http://www.bbc.com/news/technology-39657505 \n", "2 http://www.bbc.com/news/world-us-canada-39425862 \n", "3 https://www.bbc.co.uk/news/technology-37494863 \n", "4 https://www.bbc.com/news/technology-51064369 \n", "5 http://www.bbc.co.uk/news/technology-25000756 \n", "6 https://www.bbc.com/news/technology-52415775 \n", "\n", " text distance rank \n", "0 Artificial intelligence (AI) is one of the mos... 0.407243 0.511831 \n", "1 Ted 2017: The robot that wants to go to univer... 0.424357 0.509631 \n", "2 Vector Institute is just the latest in Canada'... 0.424994 0.508584 \n", "3 Google, Facebook, Amazon join forces on future... 0.380820 0.507728 \n", "4 The last decade was a big one for artificial i... 0.281200 0.506788 \n", "5 Singularity: The robots are coming to steal ou... 0.365842 0.506259 \n", "6 UK spies will need to use artificial intellige... 0.414651 0.505149 " ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "client = Client(\"https://ai-blueprint-rag-augment.hf.space/\")\n", "\n", "df = similarity_search(\"What is the future of AI?\", k=10)\n", "data = client.predict(\n", "    query=\"What is the future of AI?\",\n", "    # Serialize the DataFrame in the headers/data/metadata format used by gr.Dataframe\n", "    documents={\"headers\": df.columns.tolist(), \"data\": df.values.tolist(), \"metadata\": None},\n", "    api_name=\"/rerank\",\n", ")\n", "pd.DataFrame(data=data[\"data\"], columns=data[\"headers\"])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Conclusion\n", "\n", "We have seen how to create a reranker using the sentence-transformers library and how to deploy it as a microservice on Hugging Face Spaces. The next step is to use a language model to generate a response to a query based on the reranked documents.\n", "\n", "## Next Steps\n", "\n", "- Continue - with [Generate responses based on retrieved documents using SmolLM](./generate.ipynb).\n", "- Contribute - missing something? PRs are always welcome.\n", "- Learn - the theory behind these approaches in [Hugging Face courses](https://huggingface.co/learn) or [smol-course](https://github.com/huggingface/smol-course?tab=readme-ov-file).\n", "- Explore - notebooks with similar techniques on [the Hugging Face Cookbook](https://huggingface.co/learn/cookbook/index)." ] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.8" } }, "nbformat": 4, "nbformat_minor": 2 }