notebooks/integrations/gemma/rag-gemma-huggingface-elastic.ipynb

{ "cells": [ { "cell_type": "markdown", "id": "faa09879-9128-4864-8bb5-945ef9b8e84c", "metadata": {}, "source": [ "# RAG: Using Gemma LLM locally for question answering on private data" ] }, { "cell_type": "markdown", "id": "d047438b-6f18-47ed-aac9-12c741cefd06", "metadata": {}, "source": [ "In this notebook, we build a RAG system with [Google's Gemma](https://ai.google.dev/gemma) model. We generate sparse vectors with [Elastic's ELSER](https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-elser.html) model and store them in Elasticsearch, use semantic retrieval to find the passages most relevant to a question, and pass the top search results to Gemma as context. We load and run Gemma locally with the [Hugging Face Transformers](https://huggingface.co/google/gemma-2b-it) library." ] }, { "cell_type": "markdown", "id": "1bd3acec-d490-4139-bab1-b874e1e7db8d", "metadata": {}, "source": [ "## Setup" ] }, { "cell_type": "markdown", "id": "ef406b8a-03fb-49c5-baed-18e03bcd36d9", "metadata": {}, "source": [ "**Elastic Credentials** - Create an [Elastic Cloud deployment](https://www.elastic.co/search-labs/tutorials/install-elasticsearch/elastic-cloud) to get your Elastic credentials (`ELASTIC_CLOUD_ID`, `ELASTIC_API_KEY`).\n", "\n", "**Hugging Face Token** - To use the [Gemma](https://huggingface.co/google/gemma-2b-it) model, you must accept its terms of use on Hugging Face and generate an [access token](https://huggingface.co/docs/hub/en/security-tokens). A token with the `read` role is enough to download the model.\n", "\n", "**Gemma Model** - We use [gemma-2b-it](https://huggingface.co/google/gemma-2b-it), one of the four open models Google has released. You can use any of the others instead: [gemma-2b](https://huggingface.co/google/gemma-2b), [gemma-7b](https://huggingface.co/google/gemma-7b), or [gemma-7b-it](https://huggingface.co/google/gemma-7b-it)." ] },
{ "cell_type": "markdown", "id": "ac91d7a3-1198-4b11-a9c5-50028abc861b", "metadata": {}, "source": [ "## Install packages" ] }, { "cell_type": "code", "execution_count": null, "id": "fda41538-444c-48d7-80a0-b34b2e158b82", "metadata": {}, "outputs": [], "source": [ "%pip install -q -U elasticsearch langchain langchain-community transformers huggingface_hub torch" ] }, { "cell_type": "markdown", "id": "15c2e924-e5a2-439b-8e98-f13a162db7fe", "metadata": {}, "source": [ "## Import packages" ] }, { "cell_type": "code", "execution_count": null, "id": "7219411b-fae6-4c2a-b170-796bc30ed073", "metadata": {}, "outputs": [], "source": [ "import json\n", "from getpass import getpass\n", "from urllib.request import urlopen\n", "\n", "from langchain.llms import HuggingFacePipeline\n", "from langchain.prompts import ChatPromptTemplate\n", "from langchain.schema.output_parser import StrOutputParser\n", "from langchain.schema.runnable import RunnablePassthrough\n", "from langchain.text_splitter import CharacterTextSplitter\n", "from langchain.vectorstores import ElasticsearchStore\n", "from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline" ] }, { "cell_type": "markdown", "id": "182a413f-e7fd-4361-8096-90736d3df33e", "metadata": {}, "source": [ "## Get Credentials" ] }, { "cell_type": "code", "execution_count": null, "id": "b184b3a5-0cc8-43f9-b15d-f5ccf48f574b", "metadata": {}, "outputs": [], "source": [ "ELASTIC_API_KEY = getpass(\"Elastic API Key: \")\n", "ELASTIC_CLOUD_ID = getpass(\"Elastic Cloud ID: \")\n", "elastic_index_name = \"gemma-rag\"" ] }, { "cell_type": "markdown", "id": "a2efbd81-70b9-409c-ab5f-796d538b42a1", 
"metadata": {}, "source": [ "## Add documents" ] }, { "cell_type": "markdown", "id": "161dfb9d-f11f-4de5-8489-6464ade0cdb2", "metadata": {}, "source": [ "### Download the sample dataset and deserialize the documents" ] }, { "cell_type": "code", "execution_count": null, "id": "49427546-7b37-48f4-a6fe-395736ea2d38", "metadata": {}, "outputs": [], "source": [ "url = \"https://raw.githubusercontent.com/elastic/elasticsearch-labs/main/datasets/workplace-documents.json\"\n", "\n", "# Download the JSON file and parse it into a list of document dicts\n", "with urlopen(url) as response:\n", "    workplace_docs = json.loads(response.read())" ] }, { "cell_type": "markdown", "id": "f3bf0104-8b31-4b39-ad21-b372fd1fa0db", "metadata": {}, "source": [ "### Split Documents into Passages" ] }, { "cell_type": "code", "execution_count": null, "id": "79e55ed1-418e-48ed-b3e3-d28e10744eb5", "metadata": {}, "outputs": [], "source": [ "metadata = []\n", "content = []\n", "\n", "# Keep each document's text and attach its descriptive fields as metadata\n", "for doc in workplace_docs:\n", "    content.append(doc[\"content\"])\n", "    metadata.append(\n", "        {\n", "            \"name\": doc[\"name\"],\n", "            \"summary\": doc[\"summary\"],\n", "            \"rolePermissions\": doc[\"rolePermissions\"],\n", "        }\n", "    )\n", "\n", "# chunk_size is measured in characters. The splitter breaks text at paragraph\n", "# breaks by default, so a paragraph longer than chunk_size becomes its own chunk.\n", "text_splitter = CharacterTextSplitter(chunk_size=50, chunk_overlap=0)\n", "docs = text_splitter.create_documents(content, metadatas=metadata)" ] }, { "cell_type": "markdown", "id": "4264bc1b-23b1-4547-a7f0-670944c3e605", "metadata": {}, "source": [ "## Index Documents into Elasticsearch using ELSER\n", "\n", "Before indexing, make sure you have [downloaded and deployed the ELSER model](https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-elser.html#download-deploy-elser) in your deployment and that it is running on an ML node." 
] }, { "cell_type": "code", "execution_count": null, "id": "eb1db78e-e40a-4a5c-9d15-75ee2a1d0994", "metadata": {}, "outputs": [], "source": [ "# Index the passages into Elasticsearch; ELSER generates a sparse vector for each one\n", "es = ElasticsearchStore.from_documents(\n", "    docs,\n", "    es_cloud_id=ELASTIC_CLOUD_ID,\n", "    es_api_key=ELASTIC_API_KEY,\n", "    index_name=elastic_index_name,\n", "    strategy=ElasticsearchStore.SparseVectorRetrievalStrategy(\n", "        model_id=\".elser_model_2\"\n", "    ),\n", ")\n", "\n", "es" ] }, { "cell_type": "markdown", "id": "02b1ead9-c442-40e9-ba81-d4d286ea878b", "metadata": {}, "source": [ "## Hugging Face login" ] }, { "cell_type": "code", "execution_count": null, "id": "d2f651e4-e760-4b59-a8a3-57c58dfc229f", "metadata": {}, "outputs": [], "source": [ "from huggingface_hub import notebook_login\n", "\n", "notebook_login()" ] }, { "cell_type": "markdown", "id": "7454f551-71a9-4310-bb2a-3fe0e683daab", "metadata": {}, "source": [ "## Load the model and tokenizer (`google/gemma-2b-it`)" ] }, { "cell_type": "code", "execution_count": null, "id": "3e1d98eb-0f4e-4c41-a851-125b75502963", "metadata": {}, "outputs": [], "source": [ "model_id = \"google/gemma-2b-it\"\n", "\n", "model = AutoModelForCausalLM.from_pretrained(model_id)\n", "tokenizer = AutoTokenizer.from_pretrained(model_id)" ] }, { "cell_type": "markdown", "id": "11a12596-2ac1-4101-b189-d21d53d33b04", "metadata": {}, "source": [ "## Create a `text-generation` pipeline and wrap it as a LangChain LLM" ] }, { "cell_type": "code", "execution_count": null, "id": "623e74fb-5707-44f7-9dd8-d9499f7ab61e", "metadata": {}, "outputs": [], "source": [ "pipe = pipeline(\n", "    \"text-generation\",\n", "    model=model,\n", "    tokenizer=tokenizer,\n", "    max_new_tokens=1024,\n", "    # Generation settings go on the pipeline; temperature only takes effect when sampling\n", "    do_sample=True,\n", "    temperature=0.7,\n", ")\n", "\n", "llm = HuggingFacePipeline(pipeline=pipe)" ] }, { "cell_type": "markdown", "id": "49ce0e72-e419-4310-85e9-09077d6c40b2", "metadata": {}, "source": [ "## Format the retrieved documents" ] }, { "cell_type": "code", "execution_count": null, "id": 
"b3c07a75-9220-4a82-a92e-3fc2727ad3ba", "metadata": {}, "outputs": [], "source": [ "def format_docs(docs):\n", "    # Join the retrieved passages into a single context string\n", "    return \"\\n\\n\".join(doc.page_content for doc in docs)" ] }, { "cell_type": "markdown", "id": "f6266222-6ec3-495a-8f14-460549bab89d", "metadata": {}, "source": [ "## Create a chain using a prompt template" ] }, { "cell_type": "code", "execution_count": null, "id": "ec203d1a-104b-4583-9ba1-a6b4b0354367", "metadata": {}, "outputs": [], "source": [ "# Retrieve the 5 most relevant passages for each question\n", "retriever = es.as_retriever(search_kwargs={\"k\": 5})\n", "\n", "template = \"\"\"Answer the question based only on the following context:\n", "\n", "{context}\n", "\n", "Question: {question}\n", "\"\"\"\n", "\n", "prompt = ChatPromptTemplate.from_template(template)\n", "\n", "# Fill the prompt with the retrieved context and the question, then generate an answer\n", "chain = (\n", "    {\"context\": retriever | format_docs, \"question\": RunnablePassthrough()}\n", "    | prompt\n", "    | llm\n", "    | StrOutputParser()\n", ")" ] }, { "cell_type": "markdown", "id": "8ae892dd-7442-4d4d-a804-1d717266e596", "metadata": {}, "source": [ "## Ask a question" ] }, { "cell_type": "code", "execution_count": 55, "id": "ba312f17-44ae-423d-89a0-ea01eccd85b5", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'Answer: The sales goals are to increase revenue, expand market share, and strengthen customer relationships in our target markets.'" ] }, "execution_count": 55, "metadata": {}, "output_type": "execute_result" } ], "source": [ "chain.invoke(\"What are the sales goals?\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.4" } }, "nbformat": 4, "nbformat_minor": 5 }