contrib/recipes/gemma_rag/gemma_local_rag.ipynb (497 lines of code) (raw):

{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "c0fe01ad", "metadata": { "tags": [] }, "outputs": [], "source": [ "##################################################################################\n", "# Copyright 2024 Google LLC\n", "#\n", "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License.\n", "###################################################################################" ] }, { "cell_type": "markdown", "id": "ae6f0ced", "metadata": {}, "source": [ "## Open source local RAG with `gemma` and `T5` models" ] }, { "cell_type": "markdown", "id": "1244f675", "metadata": {}, "source": [ "### Notebook overview\n", "\n", "This notebook shows how to implement a local retrieval-augmented generation (RAG) pipeline with open-source models to extract common review themes by product name.\n", "\n", "It performs the following steps:\n", "\n", "- **1 - Corpus embeddings generation:** This step generates embeddings for a CSV extract of the `review_text` column of the `data_beans.customer_review` table. It uses the Google Sentence-T5 embedding model to project the reviews into a 768-dimensional space.\n", "  - Model details [here](https://arxiv.org/abs/2108.08877).\n", "  - The embedding vectors are stored locally in a ChromaDB vector database.
\n", "\n", "- **2 - Context retrieval:** This step generates the query embedding (using the same Sentence-T5 model) and retrieves the top K most similar items from the vector database.\n", "\n", "- **3 - Result generation:** This step passes the context retrieved in the previous step to the Google Gemma 2B instruction-tuned model, which performs the task.\n", "  - Model details [here](https://arxiv.org/abs/2403.08295).\n", "\n", "\n", "\n", "#### Architecture\n", "\n", "![assets/gemma_local_rag.png](assets/gemma_local_rag.png)\n", "\n", "*NOTE: Because this notebook runs locally, inference performance is determined by the underlying hardware (e.g. GPU).*" ] }, { "cell_type": "markdown", "id": "48c4fb41", "metadata": {}, "source": [ "#### Installation\n", "Install the packages required to run this notebook." ] }, { "cell_type": "code", "execution_count": null, "id": "d104f92a", "metadata": { "tags": [] }, "outputs": [], "source": [ "# Install the packages\n", "! pip install chromadb==0.4.24\n", "! pip install transformers==4.39.1\n", "! pip install sentence-transformers==2.6.0\n", "! pip install torch==2.2.1\n", "! pip install huggingface-hub==0.22.0\n", "!
pip install ipywidgets" ] }, { "cell_type": "markdown", "id": "ed1bfdde", "metadata": {}, "source": [ "#### Import libraries and define variables\n", "Import the Python libraries, define the notebook variables, and set up the device (GPU if available)." ] }, { "cell_type": "code", "execution_count": 1, "id": "00699daf", "metadata": { "tags": [] }, "outputs": [], "source": [ "import os\n", "import logging\n", "import chromadb\n", "import csv\n", "import torch\n", "from tqdm import tqdm\n", "from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig\n", "from sentence_transformers import SentenceTransformer\n", "\n", "\n", "REVIEWS_TEXT_FILE = \"data/customer_reviews.csv\"\n", "CHROMA_DIR = \"chroma\"\n", "EMBEDDING_MODEL = \"sentence-transformers/sentence-t5-xl\"\n", "GENERATION_MODEL = \"google/gemma-2b-it\"\n", "TOP_K_RETRIEVE = 3\n", "DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')" ] }, { "cell_type": "code", "execution_count": 2, "id": "7e7658fa", "metadata": { "tags": [] }, "outputs": [], "source": [ "logging.getLogger().setLevel(logging.ERROR)" ] }, { "cell_type": "markdown", "id": "726fef2b-b6f9-4d23-91fd-aa7c2e6ab23d", "metadata": {}, "source": [ "#### Log in to Hugging Face to get the Gemma model\n", "This notebook uses Gemma via the popular Hugging Face `transformers` library.\n", "You need a Hugging Face account to download the model weights.\n", "\n", "Gemma is hosted in a gated repository, so you also need to accept the Gemma usage terms on the [Hugging Face Gemma page](https://huggingface.co/google/gemma-2b-it).\n", "\n", "Once you have accepted the usage terms, navigate to `Profile > Settings > Access Tokens`, copy a token, and paste it into the prompt shown by the cell below."
] }, { "cell_type": "code", "execution_count": null, "id": "f9808132-8b8b-4afc-9d75-1f68ef6feced", "metadata": { "tags": [] }, "outputs": [], "source": [ "from huggingface_hub import login\n", "login()" ] }, { "cell_type": "markdown", "id": "32873c17", "metadata": {}, "source": [ "#### Auxiliary functions to generate the database and associated embeddings\n", "The following functions load the CSV file into a simple in-memory list, calculate the text embeddings with the Sentence-T5 model, and store the reviews and their embeddings in a local ChromaDB vector database persisted under `CHROMA_DIR`." ] }, { "cell_type": "code", "execution_count": 54, "id": "ef05c398", "metadata": { "tags": [] }, "outputs": [], "source": [ "def gen_database():\n", "    \"\"\"\n", "    Reads customer review data from a CSV file and constructs a list of reviews.\n", "\n", "    Returns:\n", "        list: A list of CSV rows, one per review; each row is a list of column values.\n", "    \"\"\"\n", "    logging.info(\"fn: gen_database()\")\n", "    reviews = []\n", "    # Use a context manager so the file is closed after reading.\n", "    with open(REVIEWS_TEXT_FILE, newline=\"\") as customer_reviews_csv_file:\n", "        customer_reviews_csv = csv.reader(customer_reviews_csv_file)\n", "        _ = next(customer_reviews_csv)  # skip the header row\n", "        for review in customer_reviews_csv:\n", "            reviews.append(review)\n", "    return reviews\n", "\n", "def _calculate_emb(text):\n", "    \"\"\"\n", "    Calculates a text embedding (vector representation) using a SentenceTransformer model.\n", "\n", "    Args:\n", "        text (str | list[str]): The input text, or list of texts, to embed.\n", "\n", "    Returns:\n", "        list: The embedding as a list of floats (a list of such lists when `text` is a list).\n", "    \"\"\"\n", "    logging.info(\"fn: _calculate_emb()\")\n", "    logging.info(f\"text: {text}\")\n", "    model = SentenceTransformer(EMBEDDING_MODEL, device=DEVICE)\n", "    # encode() returns a NumPy array: 1-D for a single string, 2-D for a list of strings.\n", "    embeddings = model.encode(text, device=DEVICE)\n", "    return embeddings.tolist()\n", "\n", "\n", "def _calculate_emb_batch(text_list):\n", "    \"\"\"\n", "    Calculates text embeddings (vector representations) using a SentenceTransformer model.\n", "\n", "    Args:\n", "        text_list (list): The input items for which to generate embeddings.\n", "\n", "    Returns:\n", "        list: One embedding per input item, in the same order.\n", "    \"\"\"\n", "    logging.info(\"fn: _calculate_emb_batch()\")\n", "    logging.info(f\"text_list: {text_list}\")\n", "    model = SentenceTransformer(EMBEDDING_MODEL, device=DEVICE)\n", "    embeddings = []\n", "    for index, item in tqdm(enumerate(text_list)):\n", "        embedding = model.encode(item, device=DEVICE).tolist()\n", "        embeddings.append(embedding)\n", "    del model\n", "    return embeddings\n", "\n", "def gen_catalog_emb_space(text_list):\n", "    \"\"\"\n", "    Creates an embedding space (a ChromaDB collection) for the customer reviews.\n", "\n", "    Args:\n", "        text_list (list): The customer reviews to embed and store.\n", "    \"\"\"\n", "    logging.info(\"fn: gen_catalog_emb_space()\")\n", "    logging.info(f\"text_list: {text_list}\")\n", "    os.makedirs(CHROMA_DIR, exist_ok=True)\n", "    client = chromadb.PersistentClient(path=CHROMA_DIR)\n", "    collection = client.get_or_create_collection(name=\"catalog\")\n", "    vectors = _calculate_emb_batch(text_list)\n", "    for index, item in tqdm(enumerate(text_list)):\n", "        emb = vectors[index]\n", "        collection.upsert(ids=f\"{index}\", embeddings=emb, documents=item)\n" ] }, { "cell_type": "code", "execution_count": 55, "id": "b25d2449", "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "8004it [06:10, 21.59it/s]\n", "8004it [00:42, 186.27it/s]\n" ] } ], "source": [ "database = gen_database()\n", "gen_catalog_emb_space(database)" ] }, { "cell_type": "markdown", "id": "467b8fbe", "metadata": {}, "source": [ "#### RAG functions\n", "The following functions perform the retrieval-augmented generation process." ] }, { "cell_type": "code", "execution_count": 64, "id": "cf176214-052d-4e4f-8b0b-17c98b70f79d", "metadata": { "tags": [] }, "outputs": [], "source":
[ "def _retrieve(query):\n", "    \"\"\"\n", "    Retrieves relevant items from the catalog embedding space based on a query.\n", "    Uses the similarity search function `_search_closests_k`.\n", "\n", "    Args:\n", "        query (str): The user's search query.\n", "    Returns:\n", "        list: A list of matches (customer reviews).\n", "    \"\"\"\n", "    logging.info(\"fn: _retrieve()\")\n", "    logging.info(f\"query: {query}\")\n", "    matches = _search_closests_k(inference_input=query)\n", "    return matches\n", "\n", "def _augmented_generation(retrieved_context, query):\n", "    \"\"\"\n", "    Generates a structured response summarizing common themes from product reviews,\n", "    leveraging a language model and retrieved context.\n", "\n", "    Args:\n", "        retrieved_context (list): A list of customer reviews relevant to the query.\n", "        query (str): The original user's search query.\n", "    Returns:\n", "        str: The decoded model output: the prompt followed by an answer of the form:\n", "            'item_name: ITEM_NAME_HERE, \"common_themes\": [COMMON_THEMES_LIST_HERE]'\n", "    \"\"\"\n", "    logging.info(\"fn: _augmented_generation()\")\n", "    logging.info(f\"retrieved_context: {retrieved_context}\")\n", "    tokenizer = AutoTokenizer.from_pretrained(GENERATION_MODEL)\n", "    model = AutoModelForCausalLM.from_pretrained(GENERATION_MODEL, device_map=DEVICE)\n", "    generation_prompt = f\"\"\"\n", " You are a marketing analyst.\n", " You need to extract common themes from the product reviews.\n", " For example:\n", "\n", " PRODUCT : espresso\n", " REVIEW_LIST: [The espresso was very strong], [The espresso coffee was great and powerful], [The data beans espresso is very powerful]\n", " ANSWER:\n", " 'item_name: 'espresso', \"common_themes\": ['The espresso coffee is a very strong one']].\n", "\n", " PRODUCT : flat white\n", " REVIEW_LIST: [Amazing foamy coffee, loved it], [It was super smooth and nice], [The data beans latte is very good, the milk was soft and foamy and the flavour is great]\n", " ANSWER:\n", " 'item_name: 'flat white', \"common_themes\": ['The flat white is very smooth and the milk foamy']].\n", " \n", " - The item name and the review list are just examples, you can replace them with the actual product name and review list you want to analyze.\n", " - The common themes can be extracted from the review text by analyzing the words and phrases that are repeated frequently.\n", " - Reply only with ANSWER, nothing else.\n", " PRODUCT: {query} \n", " REVIEW_LIST: {retrieved_context}\n", " ANSWER:\n", " \"\"\"\n", "    input_ids = tokenizer(generation_prompt, return_tensors=\"pt\").to(DEVICE)\n", "    outputs = model.generate(**input_ids, max_new_tokens=1024)\n", "    return tokenizer.decode(outputs[0])\n", "\n", "def _search_closests_k(inference_input, top_k=TOP_K_RETRIEVE):\n", "    \"\"\"\n", "    Searches the ChromaDB catalog embedding space and retrieves the top 'k' most similar items.\n", "\n", "    Args:\n", "        inference_input (str): The input query to use for the similarity search.\n", "        top_k (int, optional): The number of top results to retrieve.
\n", "            Defaults to TOP_K_RETRIEVE.\n", "    Returns:\n", "        list: A list of retrieved contexts (customer reviews).\n", "    \"\"\"\n", "    logging.info(\"fn: _search_closests_k()\")\n", "    logging.info(f\"inference_input: {inference_input}\")\n", "    logging.info(f\"top_k: {top_k}\")\n", "    client = chromadb.PersistentClient(path=CHROMA_DIR)\n", "    collection = client.get_or_create_collection(name=\"catalog\")\n", "    inference_embedded = _calculate_emb(inference_input)\n", "    top_similarities = collection.query(query_embeddings=inference_embedded, n_results=top_k)\n", "    retrieved_context = top_similarities['documents']\n", "    return retrieved_context\n", "\n", "def rag(query):\n", "    \"\"\"\n", "    Coordinates the retrieval and augmented generation process (acts as a higher-level function).\n", "\n", "    Args:\n", "        query (str): The user's search query.\n", "    Returns:\n", "        str: The decoded output from the augmented generation step.\n", "    \"\"\"\n", "    logging.info(\"fn: rag()\")\n", "    logging.info(f\"query: {query}\")\n", "    retrieved_context = _retrieve(query)\n", "    return _augmented_generation(retrieved_context, query)" ] }, { "cell_type": "markdown", "id": "d589a584", "metadata": {}, "source": [ "#### Inference\n", "Finally, call the RAG process with the user query." ] }, { "cell_type": "code", "execution_count": 69, "id": "a368d077-0bae-4267-98ef-c374ac597b35", "metadata": { "tags": [] }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "1621d78a5f564597a8fb7e1cb2c7ff1c", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "result = rag(query=[\"latte\"])" ] }, { "cell_type": "code", "execution_count": 70, "id": "deedc65a-2c6a-4bd8-a7a1-c7bed530267d", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "<bos>\n", " You are a marketing analyst.\n", " You need to extract
common themes from the product reviews.\n", " For example:\n", "\n", " PRODUCT : espresso\n", " REVIEW_LIST: [The espresso was very strong], [The espresso coffee was great and powerful], [The data beans espresso is very powerful]\n", " ANSWER:\n", " 'item_name: 'espresso', \"common_themes\": ['The espresso coffee is a very strong one']].\n", "\n", " PRODUCT : flat white\n", " REVIEW_LIST: [Amazing foamy coffee, loved it], [It was super smooth and nice], [The data beans latte is very good, the milk was soft and foamy and the flavour is great]\n", " ANSWER:\n", " 'item_name: 'flat white', \"common_themes\": ['The flat white is very smooth and the milk foamy']].\n", " \n", " - The item name and the review list are just examples, you can replace them with the actual product name and review list you want to analyze.\n", " - The common themes can be extracted from the review text by analyzing the words and phrases that are repeated frequently.\n", " - Reply only with ANSWER, nothing else.\n", " PRODUCT: ['latte'] \n", " REVIEW_LIST: [['I ordered a latte with oat milk, It was delicious. There were lots of places to sit and the atmosphere was very relaxing. The service was fast and friendly and my coffee was hot and fresh.', 'The staff was very friendly and helpful. I ordered a latte and it was delicious. I will definitely be going back.', 'Iced vanilla latte was refreshing. Cozy seating and location is convenient. 
Quick and friendly service too.']]\n", " ANSWER:\n", " 'item_name: 'latte', \"common_themes\": ['The latte was delicious']].<eos>\n" ] } ], "source": [ "print(result)" ] } ], "metadata": { "environment": { "kernel": "python3", "name": "common-cu121.m115", "type": "gcloud", "uri": "gcr.io/deeplearning-platform-release/base-cu121:m115" }, "kernelspec": { "display_name": "Python 3 (Local)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.13" } }, "nbformat": 4, "nbformat_minor": 5 }