In [None]:
##################################################################################
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
###################################################################################

## Open source local RAG with `gemma` and `T5` models

### Notebook overview

This notebook shows how to implement a local RAG procedure with Open Source models to extract common review themes by product name.

It performs the following steps:

- **1 - Corpus embeddings generation:** This step generates embeddings for a CSV extract of `review_text` column the `data_beans.customer_review` table. This step uses the Google Sentence-T5 embedding model to projects the reviews into a 768 dimensional space.
    - Model details [here](https://arxiv.org/abs/2108.08877).
    - Embeddings vector are locally stored on a chromadb vector database. 

- **2 - Context retrieval:** This step generates the embedding (using the same T5-Sentece model) for the query and retrieves the top K most similar items from the vector database.

- **3 - Result generation:** This step uses the retrieved context in the previoup step and perform task resolution using Google Gemma 2-b instructioned tuned model.
     - Model details [here](https://arxiv.org/abs/2403.08295).



#### Architecture

![assets/gemma_local_rag.png](assets/gemma_local_rag.png)

*NOTE: As this notebook is running locally, inference performance will be determined by the underlying hardware (e.g. GPU)*

#### Installation
Install the following packages required to execute this notebook.

In [None]:
# Install the packages
! pip install chromadb==0.4.24
! pip install transformers==4.39.1
! pip install sentence-transformers==2.6.0
! pip install torch==2.2.1
! pip install huggingface-hub==0.22.0
! pip install ipywidgets

#### Import libraries and define variables
Import python libraries, definition of notebook variables and device (GPU) setup

In [1]:
import os
import logging
import chromadb
import csv
import torch
from tqdm import tqdm 
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from sentence_transformers import SentenceTransformer


REVIEWS_TEXT_FILE = "data/customer_reviews.csv"
CHROMA_DIR="chroma"
EMBEDDING_MODEL = "sentence-transformers/sentence-t5-xl"
GENERATION_MODEL = "google/gemma-2b-it"
TOP_K_RETRIEVE = 3
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
logging.getLogger().setLevel(logging.ERROR)

#### Log-in into HugginFace to get the gemma model
This notebook uses gemma via the popular Hugginface `transformers` library.
You need a Hugginface account to download the model weights.

Gemma is located in a gated repo, so you also need to accept Gemma usage terms. [Hugginface gemma](https://huggingface.co/google/gemma-2b-it).

Once usage terms are accepted, navigate to `Profile > Settings > Acces Tokens` and copy and paste the token in the cell below.

In [None]:
from huggingface_hub import login
login()

#### Auxiliary functions to generate database and associated embeddings
The following functions will load the CVS file and generate a simple in-memory database, then we will calculate the text embeddings using the Sentence-T5 model and store the data in a in-memory local vector database using chromadb

In [54]:
def gen_database():
    """
    Reads customer review data from a CSV file and constructs a list of reviews.

    Returns:
        list: A list of customer reviews, where each review is a string. 
    """
    logging.info(f"fn: gen_database()")
    customer_reviews_csv_file = open(REVIEWS_TEXT_FILE)
    customer_reviews_csv = csv.reader(customer_reviews_csv_file)
    reviews = []
    _ = next(customer_reviews_csv)
    for review in customer_reviews_csv:
        reviews.append(review)
    return reviews

def _calculate_emb(text):
    """
    Calculates a text embedding (vector representation) using a SentenceTransformer model.

    Args:
        text (str): The input text for which to generate an embedding.

    Returns:
        list: A list containing the numerical components of the text embedding.
    """
    logging.info(f"fn: _calculate_emb()")
    logging.info(f"text: {text}")
    model = SentenceTransformer(EMBEDDING_MODEL, device=DEVICE)
    embeddings = model.encode(text,device=DEVICE)
    for embedding in embeddings:
        vector = embedding.tolist()
    return vector


def _calculate_emb_batch(text_list):
    """
    Calculates a text embedding (vector representation) using a SentenceTransformer model.

    Args:
        text_list (str): The input text list for which to generate an embedding.

    Returns:
        vectors: A list containing the numerical components of the text embedding.
    """
    logging.info(f"fn: _calculate_emb()")
    logging.info(f"text: {text_list}")
    model = SentenceTransformer(EMBEDDING_MODEL, device=DEVICE)
    embeddings = []
    for index, item in tqdm(enumerate(text_list)):
        embedding = model.encode(item,device=DEVICE).tolist()
        embeddings.append(embedding)
    del model
    return embeddings   

def gen_catalog_emb_space(text_list):
    """
    Creates an embedding space (a ChromaDB collection) for product catalog items.

    Args:
        text_list (list): A list of text descriptions of customer reviews.
    """
    logging.info(f"fn: gen_catalog_emb_space()")
    logging.info(f"text_list: {text_list}")
    os.makedirs(CHROMA_DIR, exist_ok = True)
    client = chromadb.PersistentClient(path=CHROMA_DIR)
    collection = client.get_or_create_collection(name="catalog")
    vectors = _calculate_emb_batch(text_list)
    for index, item in tqdm(enumerate(text_list)):
        emb = vectors[index]
        collection.upsert(ids=f"{index}", embeddings=emb, documents=item) 


In [55]:
database = gen_database()
gen_catalog_emb_space(database)

8004it [06:10, 21.59it/s]
8004it [00:42, 186.27it/s]


#### RAG functions
The following functions perform the retrieval augmented generation process

In [64]:
def _retrieve(query):
    """
    Retrieves relevant items from the catalog embedding space based on a query.
    Uses a similarity search function in '_search_closest_k'.

    Args:
        query (str): The user's search query.
    Returns:
        list: A list of matches (customer reviews ). 
    """
    logging.info(f"fn: _retrieve()")
    logging.info(f"query: {query}")
    matches = _search_closests_k(inference_input=query)
    return matches

def _augmented_generation(retrieved_context,query):
    """
    Generates a structured response (JSON) summarizing common themes from product reviews, 
    leveraging a language model and retrieved context.

    Args:
        retrieved_context (list): A list of customer reviews relevant to the query.
        query (str): The original user's search query.
    Returns:
        str: A JSON-formatted string with the format:
             'item_name: ITEM_NAME_HERE, "common_themes": [COMMON_THEMES_LIST_HERE]'
    """
    logging.info(f"fn: retrieved_context()")
    logging.info(f"retrieved_context: {retrieved_context}")
    tokenizer = AutoTokenizer.from_pretrained(GENERATION_MODEL)
    model = AutoModelForCausalLM.from_pretrained(GENERATION_MODEL, device_map = DEVICE)
    generation_prompt = f"""
            You are a marketing analyst.
            You need to extract common themes from the product reviews.
            For example:

            PRODUCT : espresso
            REVIEW_LIST: [The espresso was very strong], [The espresso coffee was great and powerful], [The data beans espresso is very powerful]
            ANSWER:
            'item_name: 'espresso', "common_themes": ['The espresso coffee is a very strong one']].

            PRODUCT : flat white
            REVIEW_LIST: [Amazing foamy coffee, loved it], [It was super smooth and nice], [The data beans latte is very good, the milk was soft and foamy and the flavour is great]
            ANSWER:
            'item_name: 'flat white', "common_themes": ['The flat white is very smooth and the milk foamy']].
      
            - The item name and the review list are just examples, you can replace them with the actual product name and review list you want to analyze.
            - The common themes can be extracted from the review text by analyzing the words and phrases that are repeated frequently.
            - Reply only with ANSWER, nothing else.
            PRODUCT: {query}  
            REVIEW_LIST:  {retrieved_context}
            ANSWER:
           """
    input_ids = tokenizer(generation_prompt, return_tensors="pt").to(DEVICE)
    outputs = model.generate(**input_ids,max_new_tokens=1024)
    return tokenizer.decode(outputs[0])
    
def _search_closests_k(inference_input,top_k=TOP_K_RETRIEVE):
    """
    Searches the ChromaDB catalog embedding space and retrieves the top 'k' most similar items.

    Args:
        inference_input (str): The input query to use for the similarity search.
        top_k (int, optional): The number of top results to retrieve. 
                               Defaults to TOP_K_RETRIEVE.
    Returns:
        list: A list of retrieved contexts (customer reviews).
    """
    logging.info(f"fn: _search_closests_k()")
    logging.info(f"inference_input: {inference_input}")
    logging.info(f"top_k: {top_k}")
    client = chromadb.PersistentClient(path=CHROMA_DIR)
    collection = client.get_or_create_collection(name="catalog")
    inference_embedded = _calculate_emb(inference_input)
    top_similarities =  collection.query(query_embeddings=inference_embedded,n_results=top_k)
    retrieved_context = top_similarities['documents']
    return retrieved_context

def rag(query):
    """
    Coordinates the retrieval and augmented generation process (acts as a higher-level function).

    Args:
        query (str): The user's search query.
    Returns:
        str:  The JSON formatted output from the augmented generation.
    """
    logging.info(f"fn: rag()")
    logging.info(f"query: {query}")
    retrieved_context = _retrieve(query)
    return _augmented_generation(retrieved_context,query)

#### Inference
Finally we call the RAG process providing the user query

In [69]:
result = rag(query=["latte"])

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [70]:
print(result)

<bos>
            You are a marketing analyst.
            You need to extract common themes from the product reviews.
            For example:

            PRODUCT : espresso
            REVIEW_LIST: [The espresso was very strong], [The espresso coffee was great and powerful], [The data beans espresso is very powerful]
            ANSWER:
            'item_name: 'espresso', "common_themes": ['The espresso coffee is a very strong one']].

            PRODUCT : flat white
            REVIEW_LIST: [Amazing foamy coffee, loved it], [It was super smooth and nice], [The data beans latte is very good, the milk was soft and foamy and the flavour is great]
            ANSWER:
            'item_name: 'flat white', "common_themes": ['The flat white is very smooth and the milk foamy']].
      
            - The item name and the review list are just examples, you can replace them with the actual product name and review list you want to analyze.
            - The common themes can be extracted fro