{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Uq-CsHxA7hjZ"
},
"outputs": [],
"source": [
"# Copyright 2024 Google LLC\n",
"#\n",
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9695203a746c"
},
"source": [
"# Building a Multimodal Chatbot for Warranty Claims using Gemini and Vector Search in Vertex AI\n",
"\n",
"<table align=\"left\">\n",
" <td style=\"text-align: center\">\n",
" <a href=\"https://colab.research.google.com/github/GoogleCloudPlatform/generative-ai/blob/main/gemini/use-cases/retrieval-augmented-generation/retail_warranty_claim_chatbot.ipynb\">\n",
" <img src=\"https://cloud.google.com/ml-engine/images/colab-logo-32px.png\" alt=\"Google Colaboratory logo\"><br> Run in Colab\n",
" </a>\n",
" </td>\n",
" <td style=\"text-align: center\">\n",
" <a href=\"https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fgenerative-ai%2Fmain%2Fgemini%2Fuse-cases%2Fretrieval-augmented-generation%2Fretail_warranty_claim_chatbot.ipynb\">\n",
" <img width=\"32px\" src=\"https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN\" alt=\"Google Cloud Colab Enterprise logo\"><br> Run in Colab Enterprise\n",
" </a>\n",
" </td> \n",
" <td style=\"text-align: center\">\n",
" <a href=\"https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/use-cases/retrieval-augmented-generation/retail_warranty_claim_chatbot.ipynb\">\n",
" <img src=\"https://cloud.google.com/ml-engine/images/github-logo-32px.png\" alt=\"GitHub logo\"><br> View on GitHub\n",
" </a>\n",
" </td>\n",
" <td style=\"text-align: center\">\n",
" <a href=\"https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/generative-ai/main/gemini/use-cases/retrieval-augmented-generation/retail_warranty_claim_chatbot.ipynb\">\n",
" <img src=\"https://lh3.googleusercontent.com/UiNooY4LUgW_oTvpsNhPpQzsstV5W8F7rYgxgGBD85cWJoLmrOzhVs_ksK_vgx40SHs7jCqkTkCk=e14-rj-sc0xffffff-h130-w32\" alt=\"Vertex AI logo\"><br> Open in Vertex AI Workbench\n",
" </a>\n",
" </td>\n",
"</table>\n",
"\n",
"<div style=\"clear: both;\"></div>\n",
"\n",
"<b>Share to:</b>\n",
"\n",
"<a href=\"https://www.linkedin.com/sharing/share-offsite/?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/use-cases/retrieval-augmented-generation/retail_warranty_claim_chatbot.ipynb\" target=\"_blank\">\n",
" <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/8/81/LinkedIn_icon.svg\" alt=\"LinkedIn logo\">\n",
"</a>\n",
"\n",
"<a href=\"https://bsky.app/intent/compose?text=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/use-cases/retrieval-augmented-generation/retail_warranty_claim_chatbot.ipynb\" target=\"_blank\">\n",
" <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/7/7a/Bluesky_Logo.svg\" alt=\"Bluesky logo\">\n",
"</a>\n",
"\n",
"<a href=\"https://twitter.com/intent/tweet?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/use-cases/retrieval-augmented-generation/retail_warranty_claim_chatbot.ipynb\" target=\"_blank\">\n",
" <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/5/5a/X_icon_2.svg\" alt=\"X logo\">\n",
"</a>\n",
"\n",
"<a href=\"https://reddit.com/submit?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/use-cases/retrieval-augmented-generation/retail_warranty_claim_chatbot.ipynb\" target=\"_blank\">\n",
" <img width=\"20px\" src=\"https://redditinc.com/hubfs/Reddit%20Inc/Brand/Reddit_Logo.png\" alt=\"Reddit logo\">\n",
"</a>\n",
"\n",
"<a href=\"https://www.facebook.com/sharer/sharer.php?u=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/use-cases/retrieval-augmented-generation/retail_warranty_claim_chatbot.ipynb\" target=\"_blank\">\n",
" <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/5/51/Facebook_f_logo_%282019%29.svg\" alt=\"Facebook logo\">\n",
"</a> "
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-3K6twUOR-jD"
},
"source": [
"| | |\n",
"|-|-|\n",
"|Author(s) | [Zachary Thorman](https://github.com/zthor5), [Charles Elliott](https://github.com/charleselliott) |"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "INIA40pMSYOn"
},
"source": [
"## Overview\n",
"\n",
"This notebook walks through building a warranty claims chatbot that uses Vector Search and the Gemini API in Vertex AI on Google Cloud. For the purposes of this notebook, we will use a fictitious shoe startup called [AquaStride](https://storage.googleapis.com/github-repo/generative-ai/gemini/use-cases/rag/warranty-claim-chatbot/aquastride-company-overview.pdf).\n",
"\n",
" - For teaching purposes, you'll ingest the sample data by converting PDFs -> Images -> Text -> Embeddings -> Vector DB.\n",
" - In this notebook, you will create a custom RAG implementation, deployed on Vector Search. You can also use other managed services like [Vertex AI Search](https://cloud.google.com/enterprise-search?hl=en) as a vector database.\n",
" - You will also use [Function Calling](https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/function-calling) in the Gemini API to route each user intent to the function that handles it.\n",
"\n",
"The sample code shown in this notebook originally appeared in the [Building out code pipelines for your Gen AI customer service app](https://www.youtube.com/live/Zm255g3URpw?feature=shared&t=2845) session at the [Google Startup School](https://startup.google.com/programs/startup-school/) on May 28th, 2024.\n",
"\n",
"<img src=\"https://storage.googleapis.com/github-repo/generative-ai/gemini/use-cases/rag/warranty-claim-chatbot/user-flow-diagram.png\" width=\"70%\">"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "r11Gu7qNgx1p"
},
"source": [
"## Getting Started\n",
"\n",
"In this section, you will install the necessary dependencies and define the Google Cloud project where you want to connect to Vertex AI."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "No17Cw5hgx12"
},
"source": [
"### Install Vertex AI SDK and other required packages"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Hfn_UeXyTHaH"
},
"outputs": [],
"source": [
"%pip install --upgrade -q pymupdf gradio google-cloud-aiplatform langchain_google_vertexai pillow regex langchain==0.1.20"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jXHfaVS66_01"
},
"source": [
"### Import libraries"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "lslYAvw37JGQ"
},
"outputs": [],
"source": [
"# Standard library: encoding, timestamps, file system, timing, and unique IDs\n",
"import base64\n",
"from datetime import datetime\n",
"import os\n",
"import sys\n",
"import time\n",
"import uuid\n",
"\n",
"# PDF rendering, image handling, data manipulation, and the user interface\n",
"from PIL import Image as PIL_Image\n",
"import fitz\n",
"import gradio as gr\n",
"import pandas as pd\n",
"import regex as re\n",
"\n",
"# Vertex AI SDK (used for Vector Search)\n",
"from google.cloud import aiplatform\n",
"\n",
"# LangChain components for loading and chunking documents\n",
"from langchain.text_splitter import CharacterTextSplitter\n",
"from langchain_community.document_loaders import DataFrameLoader\n",
"\n",
"# Vertex AI generative models and text embeddings\n",
"import vertexai\n",
"from vertexai.generative_models import (\n",
"    FunctionDeclaration,\n",
"    GenerativeModel,\n",
"    Image,\n",
"    Part,\n",
"    Tool,\n",
")\n",
"from vertexai.language_models import TextEmbeddingModel\n",
"import vertexai.preview.generative_models as generative_models\n",
"from vertexai.preview.generative_models import ToolConfig"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "j7UyNVSiyQ96"
},
"source": [
"### Restart runtime\n",
"\n",
"To use the newly installed packages in this Jupyter runtime, you must restart the runtime. You can do this by running the cell below, which restarts the current kernel.\n",
"\n",
"The restart might take a minute or longer. Restarting clears everything already imported, so after the kernel restarts, re-run the **Import libraries** cell above and then continue to the next step."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YmY9HVVGSBW5"
},
"outputs": [],
"source": [
"import IPython\n",
"\n",
"app = IPython.Application.instance()\n",
"app.kernel.do_shutdown(True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EXQZrM5hQeKb"
},
"source": [
"<div class=\"alert alert-block alert-warning\">\n",
"<b>⚠️ Wait for the kernel to finish restarting before you continue. ⚠️</b>\n",
"</div>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dmWOrTJ3gx13"
},
"source": [
"### Authenticate your notebook environment (Colab only)\n",
"\n",
"If you are running this notebook on Google Colab, run the cell below to authenticate your environment.\n",
"\n",
"This step is not required if you are using [Vertex AI Workbench](https://cloud.google.com/vertex-ai-workbench)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NyKGtVQjgx13"
},
"outputs": [],
"source": [
"import sys\n",
"\n",
"# Additional authentication is required for Google Colab\n",
"if \"google.colab\" in sys.modules:\n",
"    # Authenticate user to Google Cloud\n",
"    from google.colab import auth\n",
"\n",
"    auth.authenticate_user()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DF4l8DTdWgPY"
},
"source": [
"### Define Google Cloud project information, initialize Vertex AI, and add Secrets\n",
"\n",
"To get started using Vertex AI, you must have an existing Google Cloud project and [enable the Vertex AI API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com).\n",
"\n",
"Learn more about [setting up a project and a development environment](https://cloud.google.com/vertex-ai/docs/start/cloud-environment)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Nqwi-5ufWp_B"
},
"outputs": [],
"source": [
"# Define your Google Cloud project information.\n",
"# Replace the placeholder below with your own project ID and location.\n",
"\n",
"PROJECT_ID = \"[your-project-id]\" # @param {type:\"string\"}\n",
"LOCATION = \"us-central1\" # @param {type:\"string\"}\n",
"\n",
"\n",
"vertexai.init(project=PROJECT_ID, location=LOCATION)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lP1EMJL4ddcA"
},
"source": [
"### Initializing Gemini and Text Embedding models\n",
"\n",
"Here we initialize the models that will be used for embeddings & answering questions against the PDFs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sNOqAHc3dbDf"
},
"outputs": [],
"source": [
"# Defines the Generative Models Configuration\n",
"generation_config = {\n",
"    \"max_output_tokens\": 8192,\n",
"    \"temperature\": 0,\n",
"    \"top_p\": 0.95,\n",
"}\n",
"\n",
"# Loading Gemini Model\n",
"multimodal_model = GenerativeModel(\n",
"    \"gemini-2.0-flash\", generation_config=generation_config\n",
")\n",
"\n",
"# Initializing embedding model\n",
"text_embedding_model = TextEmbeddingModel.from_pretrained(\"text-embedding-005\")\n",
"\n",
"# Download backup blank file to use if needed when no results (Not Required for RAG)\n",
"! wget -O no-matching-pages.png https://storage.googleapis.com/github-repo/generative-ai/gemini/use-cases/rag/warranty-claim-chatbot/no-matching-pages.png"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ECB4ixV6VF0B"
},
"source": [
"# Helper Functions for RAG\n",
"\n",
"\n",
"In this section, you will ingest sample data by converting PDFs -> Images -> Text -> Embeddings -> Vector DB.\n",
"\n",
"The following cells define helper functions used in later sections. Run the group of collapsed cells at once, or review each one at your discretion."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a0nUDhV2ZvUU"
},
"source": [
"### Create and clean images folder"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Pr7J1V06ZtU3"
},
"outputs": [],
"source": [
"# Pass the folder path for storing the images\n",
"\n",
"\n",
"def create_clean_image_folder(Image_Path):\n",
"    # Create the directory if it doesn't exist\n",
"    if not os.path.exists(Image_Path):\n",
"        os.makedirs(Image_Path)\n",
"    # Remove any images left over from a previous run\n",
"    image_star = Image_Path + \"*\"\n",
"    !rm -rf {image_star}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XPQ5eqaTcLP7"
},
"source": [
"### Split PDF to images and extract data using Gemini\n",
"\n",
"This function renders each PDF page as an image, then extracts text and tabular data from each image using a multimodal model (Gemini).\n",
"It retries when a request fails, skips an image after repeated failures, and returns the extracted information in a DataFrame.\n",
"\n",
"You can modify this approach in a number of ways, such as to use [Document AI](https://cloud.google.com/blog/products/ai-machine-learning/document-ai-custom-extractor-powered-by-generative-ai-is-now-ga) for OCR Parsing. Feel free to try alternatives!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "pSKL45w4O5ob"
},
"outputs": [],
"source": [
"def split_pdf_extract_data(pdfList, folder_uri):\n",
"    # To get better resolution\n",
"    zoom_x = 2.0  # horizontal zoom\n",
"    zoom_y = 2.0  # vertical zoom\n",
"    mat = fitz.Matrix(zoom_x, zoom_y)  # zoom factor 2 in each dimension\n",
"\n",
"    for indiv_Pdf in pdfList:\n",
"        doc = fitz.open(indiv_Pdf)  # open document\n",
"        for page in doc:  # iterate through the pages\n",
"            pix = page.get_pixmap(matrix=mat)  # render page to an image\n",
"            outpath = f\"{folder_uri}{indiv_Pdf}_{page.number}.png\"\n",
"            pix.save(outpath)  # store image as a PNG\n",
"\n",
"    # List the rendered page images (sorted so the processing order is deterministic)\n",
"    image_names = sorted(os.listdir(folder_uri))\n",
"    Max_images = len(image_names)\n",
"\n",
"    # Create empty lists to store image information\n",
"    page_source = []\n",
"    page_content = []\n",
"    page_id = []\n",
"\n",
"    p_id = 0  # Initialize image ID counter\n",
"    rest_count = 0  # Initialize counter for error handling\n",
"\n",
"    while p_id < Max_images:\n",
"        try:\n",
"            # Construct the full path to the current image\n",
"            image_path = folder_uri + image_names[p_id]\n",
"\n",
"            # Load the image\n",
"            image = Image.load_from_file(image_path)\n",
"\n",
"            # Generate prompts for text and table extraction\n",
"            prompt_text = \"Extract all text content in the image\"\n",
"            prompt_table = (\n",
"                \"Detect table in this image. Extract content maintaining the structure\"\n",
"            )\n",
"            prompt_image = \"Detect images in this image. Extract content in the form of alternative text or subtitles to each sub-image\"\n",
"\n",
"            # Extract text using your multimodal model\n",
"            contents = [image, prompt_text]\n",
"            response = multimodal_model.generate_content(contents)\n",
"            text_content = response.text\n",
"\n",
"            # Extract table using your multimodal model\n",
"            contents = [image, prompt_table]\n",
"            response = multimodal_model.generate_content(contents)\n",
"            table_content = response.text\n",
"\n",
"            # Extract information from images (i.e. Subtitle / Alternative text). | Currently Disabled\n",
"            # contents = [image, prompt_image]\n",
"            # response = multimodal_model.generate_content(contents)\n",
"            # image_content = response.text\n",
"\n",
"            # Log progress and store results\n",
"            print(f\"processed image no: {p_id}\")\n",
"            page_source.append(image_path)\n",
"            page_content.append(\n",
"                text_content + \"\\n\" + table_content\n",
"            )  # + \"\\n\" + image_content)\n",
"            page_id.append(p_id)\n",
"            p_id += 1\n",
"            rest_count = 0  # Reset the retry counter after a success\n",
"\n",
"        except Exception as err:\n",
"            # Handle errors during processing\n",
"            print(err)\n",
"            print(\"Taking Some Rest\")\n",
"            # Pause for 12 seconds because of the default Vertex AI quota\n",
"            time.sleep(12)\n",
"            rest_count += 1\n",
"            if rest_count == 5:  # Give up on this image after 5 consecutive errors\n",
"                rest_count = 0\n",
"                print(f\"Cannot process image no: {image_path}\")\n",
"                p_id += 1  # Move to the next image\n",
"\n",
"    # Create a DataFrame to store extracted information\n",
"    df = pd.DataFrame(\n",
"        {\"page_id\": page_id, \"page_source\": page_source, \"page_content\": page_content}\n",
"    )\n",
"    del page_id, page_source, page_content  # Conserve memory\n",
"\n",
"    return df"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FrqESrwnzUXW"
},
"source": [
"### Create the chunks and embeddings"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7e087QRTVNbz"
},
"outputs": [],
"source": [
"def generate_text_embedding(text) -> list:\n",
"    \"\"\"Text embedding with a Large Language Model.\"\"\"\n",
"    embeddings = text_embedding_model.get_embeddings([text])\n",
"    vector = embeddings[0].values\n",
"    return vector"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kHtZhmkYVCFP"
},
"outputs": [],
"source": [
"# Returns a chunked embeddings dataframe\n",
"\n",
"\n",
"def create_chunked_embeddings(df):\n",
"    # Create a DataFrameLoader to prepare data for LangChain\n",
"    loader = DataFrameLoader(df, page_content_column=\"page_content\")\n",
"\n",
"    # Load documents from the 'page_content' column of your DataFrame\n",
"    documents = loader.load()\n",
"\n",
"    # Log the number of documents loaded\n",
"    print(f\"# of documents loaded (pre-chunking) = {len(documents)}\")\n",
"\n",
"    # Create a text splitter to divide documents into smaller chunks\n",
"    text_splitter = CharacterTextSplitter(\n",
"        chunk_size=10000,  # Target size of approximately 10000 characters per chunk\n",
"        chunk_overlap=200,  # overlap between chunks\n",
"    )\n",
"\n",
"    # Split the loaded documents\n",
"    doc_splits = text_splitter.split_documents(documents)\n",
"\n",
"    # Add a 'chunk' ID to each document split's metadata for tracking\n",
"    for idx, split in enumerate(doc_splits):\n",
"        split.metadata[\"chunk\"] = idx\n",
"\n",
"    # Log the number of documents after splitting\n",
"    print(f\"# of documents = {len(doc_splits)}\")\n",
"\n",
"    texts = [doc.page_content for doc in doc_splits]\n",
"    text_embeddings_list = []\n",
"    id_list = []\n",
"    page_source_list = []\n",
"    for doc in doc_splits:\n",
"        chunk_id = uuid.uuid4()  # Unique ID for this chunk\n",
"        text_embeddings_list.append(generate_text_embedding(doc.page_content))\n",
"        id_list.append(str(chunk_id))\n",
"        page_source_list.append(doc.metadata[\"page_source\"])\n",
"        time.sleep(12)  # So that we don't run into Quota Issue\n",
"\n",
"    # Creating a dataframe of ID, embeddings, page_source and text\n",
"    embedding_df = pd.DataFrame(\n",
"        {\n",
"            \"id\": id_list,\n",
"            \"embedding\": text_embeddings_list,\n",
"            \"page_source\": page_source_list,\n",
"            \"text\": texts,\n",
"        }\n",
"    )\n",
"    return embedding_df"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6pu1a3zjfQ0D"
},
"source": [
"### Save the embeddings in a JSON file\n",
"To load the embeddings into Vector Search, save them to a file in JSON Lines (JSONL) format, with one JSON object per line. See more information in the docs at [Input data format and structure](https://cloud.google.com/vertex-ai/docs/matching-engine/match-eng-setup/format-structure#data-file-formats).\n",
"\n",
"First, export the `id` and `embedding` columns from the DataFrame in JSONL format, and save it.\n",
"\n",
"Then, create a new Cloud Storage bucket and copy the file to it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mSqVnRpqh7Yc"
},
"outputs": [],
"source": [
"def create_json_file(embedding_df, RAG_unique_identifier):\n",
"    # save id and embedding as a json file\n",
"    json_file_name = RAG_unique_identifier + \".json\"\n",
"    jsonl_string = embedding_df[[\"id\", \"embedding\"]].to_json(\n",
"        orient=\"records\", lines=True\n",
"    )\n",
"    with open(json_file_name, \"w\") as f:\n",
"        f.write(jsonl_string)\n",
"\n",
"    # Show the first few lines of the json file\n",
"    #! head -n 3 {json_file_name}\n",
"    return json_file_name"
]
},
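{
"cell_type": "markdown",
"metadata": {},
"source": [
"The JSONL layout can be sketched with a toy DataFrame. This is a minimal illustration, not part of the pipeline: the IDs and 3-dimensional vectors below are made up (the notebook's real IDs are UUIDs and its embeddings have 768 values). It shows that `to_json(orient=\"records\", lines=True)` writes one JSON object per line, each of which parses on its own.\n",
"\n",
"```python\n",
"import json\n",
"\n",
"import pandas as pd\n",
"\n",
"# Hypothetical toy data: real IDs are UUIDs and real vectors have 768 values.\n",
"toy_df = pd.DataFrame(\n",
"    {\n",
"        \"id\": [\"chunk-a\", \"chunk-b\"],\n",
"        \"embedding\": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],\n",
"        \"text\": [\"first chunk\", \"second chunk\"],\n",
"    }\n",
")\n",
"\n",
"# Keep only the two exported columns and write one JSON object per line.\n",
"jsonl_string = toy_df[[\"id\", \"embedding\"]].to_json(orient=\"records\", lines=True)\n",
"\n",
"for line in jsonl_string.splitlines():\n",
"    record = json.loads(line)  # Each line is a standalone JSON object\n",
"    print(record[\"id\"], len(record[\"embedding\"]))\n",
"```\n",
"\n",
"The `text` column is left out of the export on purpose, mirroring `create_json_file`, which writes only `id` and `embedding`."
]
},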
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "CzwDWJfzAk3n"
},
"outputs": [],
"source": [
"def upload_file_to_gcs(json_file_name, bucket_location):\n",
"    # Generates a unique ID for session\n",
"    UID = datetime.now().strftime(\"%m%d%H%M%S\")\n",
"    # Creates a GCS bucket\n",
"    BUCKET_URI = f\"gs://{bucket_location}--{UID}\"\n",
"    ! gsutil mb -l $LOCATION -p {PROJECT_ID} {BUCKET_URI}\n",
"    ! gsutil cp {json_file_name} {BUCKET_URI}\n",
"    return BUCKET_URI"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xxdbjKw1XDxl"
},
"source": [
"### Create an index in Vector Search\n",
"\n",
"The embeddings are now ready to load into Vector Search. Its APIs are available under the [aiplatform](https://cloud.google.com/python/docs/reference/aiplatform/latest/google.cloud.aiplatform) package of the SDK."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xpMUXqWQ75uu"
},
"source": [
"Create a [MatchingEngineIndex](https://cloud.google.com/python/docs/reference/aiplatform/latest/google.cloud.aiplatform.MatchingEngineIndex) with its `create_tree_ah_index` function (Matching Engine is the previous name of Vector Search)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kKDw5VXMkXb3"
},
"outputs": [],
"source": [
"def create_index(vec_search_index_name, bucket_location):\n",
"    return aiplatform.MatchingEngineIndex.create_tree_ah_index(\n",
"        display_name=f\"{vec_search_index_name}\",\n",
"        contents_delta_uri=bucket_location,\n",
"        dimensions=768,\n",
"        approximate_neighbors_count=20,\n",
"        distance_measure_type=\"DOT_PRODUCT_DISTANCE\",\n",
"    )"
]
},
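{
"cell_type": "markdown",
"metadata": {},
"source": [
"To build intuition for `DOT_PRODUCT_DISTANCE`, the sketch below ranks a few made-up vectors by their dot product with a query. It is a brute-force illustration only: Vector Search uses an approximate tree-AH index rather than scanning every vector, and the names `toy_index`, `dot`, and `brute_force_neighbors` are hypothetical.\n",
"\n",
"```python\n",
"# Hypothetical 3-dimensional vectors; the notebook's real embeddings have 768 values.\n",
"toy_index = {\n",
"    \"doc-a\": [1.0, 0.0, 0.0],\n",
"    \"doc-b\": [0.6, 0.8, 0.0],\n",
"    \"doc-c\": [0.0, 0.0, 1.0],\n",
"}\n",
"\n",
"\n",
"def dot(u, v):\n",
"    \"\"\"Dot product of two equal-length vectors.\"\"\"\n",
"    return sum(a * b for a, b in zip(u, v))\n",
"\n",
"\n",
"def brute_force_neighbors(query, index, num_neighbors=2):\n",
"    \"\"\"Rank every stored vector by dot product with the query, highest first.\"\"\"\n",
"    scored = [(doc_id, dot(query, vec)) for doc_id, vec in index.items()]\n",
"    scored.sort(key=lambda pair: pair[1], reverse=True)\n",
"    return scored[:num_neighbors]\n",
"\n",
"\n",
"# The query points mostly along the first axis, so doc-a ranks ahead of doc-b.\n",
"print(brute_force_neighbors([0.9, 0.1, 0.0], toy_index))\n",
"```\n",
"\n",
"A larger dot product means a closer match, which is why the sketch sorts in descending order."
]
},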
{
"cell_type": "markdown",
"metadata": {
"id": "2rFam_w9U0dI"
},
"source": [
"Calling the `create_tree_ah_index` function starts building an index. This takes a few minutes if the dataset is small; otherwise it takes about 50 minutes or more, depending on the size of the dataset.\n",
"\n",
"You can check status of the index creation on [the Vector Search Console > INDEXES tab](https://console.cloud.google.com/vertex-ai/matching-engine/indexes).\n",
"\n",
"\n",
"\n",
"---\n",
"\n",
"See [this document](https://cloud.google.com/vertex-ai/docs/vector-search/create-manage-index) for more details on creating your Index and the parameters."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3fDs-FDS789e"
},
"source": [
"### Create an index endpoint and deploy the index\n",
"\n",
"To use the Index, you need to create an [Index Endpoint](https://cloud.google.com/vertex-ai/docs/vector-search/deploy-index-public). It works as a server instance accepting query requests for your Index.\n",
"\n",
"You can view your public endpoints [on Google Cloud's Vertex Endpoints](https://console.cloud.google.com/vertex-ai/matching-engine/index-endpoints)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "peH6WpSj789m"
},
"outputs": [],
"source": [
"def create_Index_Endpoint(my_index, vec_search_index_name):\n",
"    my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint.create(\n",
"        display_name=f\"{vec_search_index_name}\",\n",
"        public_endpoint_enabled=True,\n",
"    )\n",
"\n",
"    DEPLOYED_INDEX_NAME = vec_search_index_name.replace(\n",
"        \"-\", \"_\"\n",
"    )  # Can't have '-' in deployment name, only alphanumeric and _ allowed\n",
"    UID = datetime.now().strftime(\"%m%d%H%M%S\")\n",
"    DEPLOYED_INDEX_ID = f\"{DEPLOYED_INDEX_NAME}_{UID}\"\n",
"    # deploy the Index to the Index Endpoint\n",
"    my_index_endpoint.deploy_index(index=my_index, deployed_index_id=DEPLOYED_INDEX_ID)\n",
"\n",
"    return my_index_endpoint, DEPLOYED_INDEX_ID"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5YtepoMX789m"
},
"source": [
"This demo utilizes a [Public Endpoint](https://cloud.google.com/vertex-ai/docs/vector-search/setup/setup#choose-endpoint) and does not support [Virtual Private Cloud (VPC)](https://cloud.google.com/vpc/docs/private-services-access). Unless you have a specific requirement for VPC, it is recommended to use a Public Endpoint.\n",
"\n",
"Despite the term \"public\" in its name, it does not imply open access to the public internet. Without explicit IAM permissions, no one can access the endpoint."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xu9ZmWcpXQ55"
},
"source": [
"The first time you deploy an Index to an Index Endpoint, it takes around 25 minutes to automatically build and initiate the backend for it. After the first deployment, deployments finish in seconds. To see the status of the index deployment, open [the Vector Search Console > INDEX ENDPOINTS tab](https://console.cloud.google.com/vertex-ai/matching-engine/index-endpoints) and click the Index Endpoint."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fM3a9RQx4pQO"
},
"source": [
"### Ask Questions to the PDF\n",
"This code snippet establishes a question-answering (QA) system. It leverages a vector search engine to find relevant information from a dataset and then uses the LLM to generate and refine the final answer to a user's query."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "dRme7I2hhqC8"
},
"outputs": [],
"source": [
"def Test_LLM_Response(txt):\n",
" \"\"\"\n",
" Determines whether a given text response generated by an LLM indicates a lack of information.\n",
"\n",
" Args:\n",
" txt (str): The text response generated by the LLM.\n",
"\n",
" Returns:\n",
" bool: True if the LLM's response suggests it was able to generate a meaningful answer,\n",
" False if the response indicates it could not find relevant information.\n",
"\n",
" This function works by presenting a formatted classification prompt to the LLM (`gemini_pro_model`).\n",
" The prompt includes the original text and specific categories indicating whether sufficient information was available.\n",
" The function analyzes the LLM's classification output to make the determination.\n",
" \"\"\"\n",
"\n",
" classification_prompt = f\"\"\" Classify the text as one of the following categories:\n",
" -Information Present\n",
" -Information Not Present\n",
" Text=The provided context does not contain information.\n",
" Category:Information Not Present\n",
" Text=I cannot answer this question from the provided context.\n",
" Category:Information Not Present\n",
" Text:{txt}\n",
" Category:\"\"\"\n",
" classification_response = multimodal_model.generate_content(\n",
" classification_prompt\n",
" ).text\n",
"\n",
" if \"Not Present\" in classification_response:\n",
" return False # Indicates that the LLM couldn't provide an answer\n",
" else:\n",
" return True # Suggests the LLM generated a meaningful response\n",
"\n",
"\n",
"def get_prompt_text(question, context):\n",
" \"\"\"\n",
" Generates a formatted prompt string suitable for a language model, combining the provided question and context.\n",
"\n",
" Args:\n",
" question (str): The user's original question.\n",
" context (str): The relevant text to be used as context for the answer.\n",
"\n",
" Returns:\n",
" str: A formatted prompt string with placeholders for the question and context, designed to guide the language model's answer generation.\n",
" \"\"\"\n",
" prompt = \"\"\"\n",
" Answer the question using the context below. Respond with only information from the text provided\n",
" Question: {question}\n",
" Context : {context}\n",
" \"\"\".format(\n",
" question=question, context=context\n",
" )\n",
" return prompt\n",
"\n",
"\n",
"def get_answer(\n",
" embedding_df, my_index_endpoint, DEPLOYED_INDEX_ID, query=\"No Query was provided.\"\n",
"):\n",
" \"\"\"\n",
" Retrieves an answer to a provided query using multimodal RAG.\n",
"\n",
" This function leverages a vector search system to find relevant text documents from a\n",
" pre-indexed store of multimodal data. Then, it uses a large language model (LLM) to generate\n",
" an answer, using the retrieved documents as context.\n",
"\n",
" Args:\n",
" query (str): The user's original query.\n",
"\n",
" Returns:\n",
" dict: A dictionary containing the following keys:\n",
" * 'result' (str): The LLM-generated answer.\n",
" * 'neighbor_index' (int): The index of the most relevant document used for generation\n",
" (for fetching image path).\n",
"\n",
" Raises:\n",
" RuntimeError: If no valid answer could be generated within the specified search attempts.\n",
" \"\"\"\n",
"\n",
" neighbor_index = 0 # Initialize index for tracking the most relevant document\n",
" answer_found_flag = 0 # Flag to signal if an acceptable answer is found\n",
" result = \"\" # Initialize the answer string\n",
" # Use a default image if the reference is not found\n",
" page_source = \"./no-matching-pages.png\" # Initialize the blank image\n",
" query_embeddings = generate_text_embedding(\n",
" query\n",
" ) # Generate embeddings for the query\n",
"\n",
" response = my_index_endpoint.find_neighbors(\n",
" deployed_index_id=DEPLOYED_INDEX_ID,\n",
" queries=[query_embeddings],\n",
" num_neighbors=5,\n",
" ) # Retrieve up to 5 relevant documents from the vector store\n",
"\n",
" while answer_found_flag == 0 and neighbor_index < 4:\n",
" context = embedding_df[\n",
" embedding_df[\"id\"] == response[0][neighbor_index].id\n",
" ].text.values[\n",
" 0\n",
" ] # Extract text context from the relevant document\n",
"\n",
" prompt = get_prompt_text(\n",
" query, context\n",
" ) # Create a prompt using the question and context\n",
" result = multimodal_model.generate_content(\n",
" prompt\n",
" ).text # Generate an answer with the LLM\n",
"\n",
" if Test_LLM_Response(result):\n",
" answer_found_flag = 1 # Exit loop when getting a valid response\n",
" else:\n",
" neighbor_index += (\n",
" 1 # Try the next retrieved document if the answer is unsatisfactory\n",
" )\n",
"\n",
" if answer_found_flag == 1:\n",
" page_source = embedding_df[\n",
" embedding_df[\"id\"] == response[0][neighbor_index].id\n",
" ].page_source.values[\n",
" 0\n",
" ] # Extract image_path from the relevant document\n",
" return result, page_source"
]
},
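{
"cell_type": "markdown",
"metadata": {},
"source": [
"The fallback strategy in `get_answer` can be sketched without any cloud calls. In this minimal illustration, `fake_generate` and `fake_is_answer` are hypothetical stand-ins for the Gemini calls, and the chunks are made-up text rather than content from the sample PDFs. The loop tries each retrieved chunk in rank order and stops at the first one that yields a usable answer.\n",
"\n",
"```python\n",
"# Hypothetical retrieved chunks, ordered from most to least similar to the query.\n",
"retrieved_chunks = [\n",
"    \"Shipping takes five business days.\",\n",
"    \"Returns are accepted within 30 days.\",\n",
"]\n",
"\n",
"\n",
"def fake_generate(question, context):\n",
"    \"\"\"Stand-in for the LLM: answers only if the question's last word is in the context.\"\"\"\n",
"    keyword = question.split()[-1].strip(\"?\").lower()\n",
"    if keyword in context.lower():\n",
"        return context\n",
"    return \"The provided context does not contain information.\"\n",
"\n",
"\n",
"def fake_is_answer(text):\n",
"    \"\"\"Stand-in for the LLM-based classifier.\"\"\"\n",
"    return \"does not contain\" not in text\n",
"\n",
"\n",
"def answer_with_fallback(question, chunks):\n",
"    \"\"\"Try each chunk in rank order and stop at the first usable answer.\"\"\"\n",
"    for rank, chunk in enumerate(chunks):\n",
"        candidate = fake_generate(question, chunk)\n",
"        if fake_is_answer(candidate):\n",
"            return candidate, rank\n",
"    return None, None  # No chunk produced a usable answer\n",
"\n",
"\n",
"print(answer_with_fallback(\"What is the policy on returns?\", retrieved_chunks))\n",
"```\n",
"\n",
"Here the first chunk is rejected and the second is accepted, so the returned rank identifies which source page to show alongside the answer."
]
},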
{
"cell_type": "markdown",
"metadata": {
"id": "igIvg1yMtJWy"
},
"source": [
"# Create a RAG endpoint\n",
"\n",
"In this section, you will load sample data into a Vector Search endpoint. In this example you'll be using PDF files that contain [a company overview](https://storage.googleapis.com/github-repo/generative-ai/gemini/use-cases/rag/warranty-claim-chatbot/aquastride-company-overview.pdf) and a [list of product SKUs](https://storage.googleapis.com/github-repo/generative-ai/gemini/use-cases/rag/warranty-claim-chatbot/aquastride-sku-sn-database.pdf).\n",
"\n",
"For production workloads, it is **recommended** to use a managed database for improved performance and efficiency."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "85FJ7Fisx_Po"
},
"source": [
"## Create RAG Function"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8PCa3GqAtHnz"
},
"outputs": [],
"source": [
"def create_RAG(RAG_unique_identifier, rag_list_pdfs):\n",
"    # Creates a unique folder for the segmented PDF images. (Each page of the PDF is converted into a .PNG)\n",
"    folder_url = f\"./{RAG_unique_identifier}_images/\"\n",
"    create_clean_image_folder(folder_url)\n",
"\n",
"    # Creates the embeddings dataframe of the PDF Images.\n",
"    company_dataframe = split_pdf_extract_data(rag_list_pdfs, folder_url)\n",
"    company_embeddings_dataframe = create_chunked_embeddings(company_dataframe)\n",
"\n",
"    # Creates unique names for the Google Cloud Vector Search & GCS Bucket URL.\n",
"    vec_search_index_name = f\"vec-search-index-{RAG_unique_identifier}\"\n",
"    bucket_name = f\"vec-search-bucket-{RAG_unique_identifier}\"\n",
"\n",
"    # Uploads the embeddings to GCS as a JSON file.\n",
"    json_file_name = create_json_file(\n",
"        company_embeddings_dataframe, RAG_unique_identifier\n",
"    )\n",
"    bucket_location = upload_file_to_gcs(json_file_name, bucket_name)\n",
"\n",
"    # This function may take up to 25 minutes to run to deploy the custom Vector Search to a Public Endpoint.\n",
"    index = create_index(vec_search_index_name, bucket_location)\n",
"    my_index_endpoint, index_id = create_Index_Endpoint(index, vec_search_index_name)\n",
"\n",
"    # Create a reusable object for each RAG model to call upon\n",
"    RAG_model_info = {\n",
"        \"bucket_uri\": bucket_location,\n",
"        \"index\": index,\n",
"        \"embeddings_dataframe\": company_embeddings_dataframe,\n",
"        \"index_id\": index_id,\n",
"        \"my_index_endpoint\": my_index_endpoint,\n",
"    }\n",
"\n",
"    return RAG_model_info"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9XBy4dKWvLgb"
},
"source": [
"## Testing the RAG performance"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SEQa8WoMvLFi"
},
"outputs": [],
"source": [
"# Download your PDFs here using the wget command.\n",
"! wget -q -O aquastride_company.pdf 'https://storage.googleapis.com/github-repo/generative-ai/gemini/use-cases/rag/warranty-claim-chatbot/aquastride-company-overview.pdf'\n",
"! wget -q -O aquastride_DB.pdf 'https://storage.googleapis.com/github-repo/generative-ai/gemini/use-cases/rag/warranty-claim-chatbot/aquastride-sku-sn-database.pdf'\n",
"\n",
"# Needs to be lowercase characters with no spaces; e.g. \"test\", \"aquastride\".\n",
"RAG_unique_identifier = \"aquastride\" # @param {type: \"string\"}\n",
"\n",
"# List the PDFs to be processed via the RAG Endpoint.\n",
"pdf_list = [\"aquastride_company.pdf\", \"aquastride_DB.pdf\"]\n",
"\n",
"# Creates the RAG model endpoint on Vertex AI Vector Search.\n",
"rag_info = create_RAG(RAG_unique_identifier, pdf_list)"
]
},
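{
"cell_type": "markdown",
"metadata": {},
"source": [
"The identifier set above is interpolated into the index name and the GCS bucket name (`vec-search-bucket-<identifier>`), so a badly formed value only surfaces once those resources are created. A minimal sketch of an up-front check, assuming a stricter rule than Cloud Storage itself requires (lowercase letters and digits only, at most 40 characters); `is_valid_rag_identifier` is a hypothetical helper, not part of this notebook:\n",
"\n",
"```python\n",
"import re\n",
"\n",
"\n",
"def is_valid_rag_identifier(identifier):\n",
"    # Assumption: lowercase letters and digits only, capped at 40 characters so\n",
"    # the derived bucket name stays within the 63-character bucket name limit.\n",
"    return re.fullmatch(r'[a-z0-9]{1,40}', identifier) is not None\n",
"\n",
"\n",
"print(is_valid_rag_identifier('aquastride'))\n",
"print(is_valid_rag_identifier('Aqua Stride'))\n",
"```\n",
"\n",
"Running a check like this before `create_RAG` lets a badly formed name fail fast instead of partway through the pipeline."
]
},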
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "DFIuF4ki7ajE"
},
"outputs": [],
"source": [
"# Provide a Query to test the deployed endpoint.\n",
"# Highly recommended to use a call to a Database (i.e. Cloud SQL) with the extracted Serial number.\n",
"query = \"Provided the Serial_No (CZE5F6G7) and SKU (DepthStrider_23_Red_Norm), Determine the cx_name who purchased this serial number.\\\\n Output the Owner (cx_name) and the address (cx_address) in this format: \\\\nThank you [cx_name] for your purchase! We have you on file at [cx_address].\" # @param {type: \"string\"}\n",
"\n",
"# Responds with the result of the query against the RAG endpoint & its source.\n",
"result, page_source = get_answer(\n",
" rag_info[\"embeddings_dataframe\"],\n",
" rag_info[\"my_index_endpoint\"],\n",
" rag_info[\"index_id\"],\n",
" query,\n",
")\n",
"\n",
"# If the endpoint returns irrelevant context to the LLM, respond with the below.\n",
"if page_source == \"./no-matching-pages.png\":\n",
" result = (\n",
" \"I could not find your answer within the Data. Can you rephrase your question?\"\n",
" )\n",
"\n",
"# Print the results and it's page source.\n",
"print(f\"Response: {result}\\nPage Source: {page_source}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "C0nOkTLPf0x4"
},
"source": [
"# Implement application logic and function calling\n",
"\n",
"In this section, you will implement logic with [Gemini Function Calling](https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/function-calling) to handle tasks related to warranty claim support such as extracting information from images of shoe tags or inspecting pictures of shoes for physical damage."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XiAPlcx9JQCa"
},
"source": [
"## Initialize and configure the Gemini model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "koLU4OYFt7rt"
},
"outputs": [],
"source": [
"image_determination_prompt = \"\"\"\n",
"You will be provided with an image. Analyze the image and perform the following:\n",
"\n",
"**Image Classification:**\n",
"\n",
"* **Shoe Tag:** If the image is a shoe tag, extract the following information and generate a corresponding JSON schema:\n",
"\n",
" ```json\n",
" {\n",
" \"brand\": \"[Brand Name/Website]\",\n",
" \"product\": \"[Product Name/SKU]\",\n",
" \"serialNumber\": \"[Serial Number]\",\n",
" \"sizing\": {\n",
" \"us\": \"[US Size]\",\n",
" \"uk\": \"[UK Size]\",\n",
" \"eur\": \"[EUR Size]\",\n",
" \"chn\": \"[CHN Size]\"\n",
" },\n",
" \"madeIn\": \"[Manufacturing Location]\"\n",
" }\n",
" ```\n",
"\n",
"* **Damaged Shoe:** If the image shows a shoe with visible damage, assess the damage and generate a JSON schema for damage reporting:\n",
"\n",
" ```json\n",
" {\n",
" \"damagedAreas\": [\"[Area 1]\", \"[Area 2]\", ...],\n",
" \"damageType\": \"[Damage Type]\",\n",
" \"severity\": \"[Severity Level]\",\n",
" \"additionalNotes\": \"[Optional Additional Notes]\"\n",
" }\n",
" ```\n",
"\n",
"* **Other:** If the image is neither a shoe tag nor a damaged shoe, respond with: \"I am unable to help you with that image because it does not help with warranty evaluations.\"\n",
"\n",
"**Damage Assessment (if applicable):**\n",
"\n",
"* **Identify Damaged Areas:** Pinpoint the specific locations of the damage on the shoe (e.g., sole, upper, laces).\n",
"* **Describe Damage Type:** Classify the type of damage (e.g., wear and tear, tear, stain, discoloration, structural damage).\n",
"* **Assess Severity:** Estimate the severity of the damage (e.g., minor, moderate, severe).\n",
"\n",
"**Additional Notes:**\n",
"\n",
"* **Clarity and Detail:** Be as specific as possible when describing the damage.\n",
"* **Image Quality:** If the image quality is too poor to assess the damage or extract information, indicate this in the output.\n",
"* **Human Intervention:** For complex or ambiguous cases, suggest that the customer contact a human agent for further assistance.\n",
"* **Missing Data Handling:** If any piece of information is not present, include the corresponding key in the JSON schema but leave the value as an empty string (\"\").\n",
"\"\"\"\n",
"\n",
"system_prompt = \"\"\"\n",
"**Persona:** You are Bubbles, AquaStride's friendly and helpful AI assistant, here to help with warranty claims. Your tone is positive and upbeat, but also efficient and clear.\n",
"\n",
"**ReACT Framework:**\n",
"\n",
"**Remember:** Keep track of the conversation history to know which step the customer is on.\n",
"**Evaluate:** Based on the customer's response, determine if they have provided the necessary information to move to the next step.\n",
"**Act:** Provide the appropriate response:\n",
"If the customer provides the required information, move to the next step.\n",
"If the customer is missing information, politely prompt them again.\n",
"If the customer is struggling, offer alternative solutions like contacting the call center.\n",
"**Confirm:** Before moving to the next step, ensure the customer understands and is ready to proceed.\n",
"\n",
"**Return Process Dialogue Flow:**\n",
"\n",
"**Step 1: Introduction and Explanation**\n",
"\n",
 Hey there!">
"> Hey there! I'm Bubbles, your friendly AquaStride assistant! It sounds like you might need to make a warranty claim on a pair of our awesome shoes. That's no problem, I'm here to help you dive right into the process! First things first, could you please share a picture of the inner shoe tag? This helps us quickly identify your shoes and get started.\n",
"\n",
"**Step 2: Image Upload and Verification (Secret Step)**\n",
"\n",
"> (Upon receiving the image, extract the SKU, Serial Number, and Manufacturing Date. Verify this information against the customer database to confirm the purchase was from a legitimate retailer and check for existing customer details.)\n",
"\n",
"**Step 3: Purchase Detail Confirmation**\n",
"\n",
 Thanks for sharing that!">
"> Thanks for sharing that! Based on the tag information, it looks like these shoes were purchased on [Date] from [Retailer/Website]. Is that correct? Please confirm your full name and email address associated with the purchase so we can access your information quickly.\n",
"\n",
"**Step 4: Handling Missing Information or Errors**\n",
"\n",
"**If information is missing or incorrect:**\n",
 Hmmm, something seems">
"> Hmmm, something seems a bit off. Could you please double-check the information you provided? If you're still having trouble, no worries! You can reach out to our super helpful contact center at 1-800-AquaOops, and they'll be happy to assist you further.\n",
"**If information is verified and correct:**\n",
"> Perfect! Now that we have all the details, let's move on to the next step… (Continue with the return process according to AquaStride's specific procedures).\n",
"\n",
"**Throughout the interaction:**\n",
"\n",
"Maintain a friendly and helpful tone.\n",
"Use emojis to enhance the lighthearted personality.\n",
"Keep responses concise and easy to understand.\n",
"Offer reassurance and support throughout the process.\n",
" \"\"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Ls93Za_kJaR-"
},
"outputs": [],
"source": [
"generation_config = {\n",
" \"max_output_tokens\": 2048,\n",
" \"temperature\": 0.4,\n",
" \"top_p\": 1,\n",
" \"top_k\": 32,\n",
"}\n",
"\n",
"safety_config = [\n",
" generative_models.SafetySetting(\n",
" category=generative_models.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,\n",
" threshold=generative_models.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,\n",
" ),\n",
" generative_models.SafetySetting(\n",
" category=generative_models.HarmCategory.HARM_CATEGORY_HARASSMENT,\n",
" threshold=generative_models.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,\n",
" ),\n",
"]\n",
"\n",
"text_model = GenerativeModel(\n",
" \"gemini-2.0-flash\",\n",
" generation_config=generation_config,\n",
" safety_settings=safety_config,\n",
" system_instruction=[system_prompt],\n",
")\n",
"image_analysis_model = GenerativeModel(\n",
" \"gemini-2.0-flash\", system_instruction=[image_determination_prompt]\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "n6bzsw0vFJM8"
},
"source": [
"## Creating function declarations\n",
"\n",
"[Function Calling in Gemini](https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/function-calling#function-declarations) allows the generative model to output structured data objects that can be used to interact with external systems and return the context to Gemini.\n",
"\n",
"Here you'll write function declarations to extract information from images of shoe tags, inspect pictures of shoes for damage, or inform the user that the image is not of a shoe tag or damaged shoe."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8Lxbv8GigrOv"
},
"outputs": [],
"source": [
"fn_json_from_tag = FunctionDeclaration(\n",
" name=\"extract_json_from_tag\",\n",
" description=\"This function is used to clean JSON packages from text, that contains: brand, product, serialNumber, sizing, and madeIn.\",\n",
" parameters={\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"records\": {\n",
" \"type\": \"array\",\n",
" \"description\": \"A shoe tag\",\n",
" \"items\": {\n",
" \"description\": \"Data for a querying the database on found information\",\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"brand\": {\"type\": \"string\", \"description\": \"The brand website\"},\n",
" \"product\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The SKU of the shoe tag. i.e.: TrailBlazer_23_Orange_Norm\",\n",
" },\n",
" \"serialNumber\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The Serial Number of the shoe tag, commonly denoted as: SN\",\n",
" },\n",
" \"sizing\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The shoe sizes from the shoe tag. i.e. ['us: 7'], ['uk: 2.5'], ...\",\n",
" },\n",
" \"madeIn\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The location the shoe was made in\",\n",
" },\n",
" },\n",
" \"required\": [\n",
" \"serialNumber\",\n",
" \"sizing\",\n",
" ], # Defines what is required to for a successful call.\n",
" },\n",
" }\n",
" },\n",
" },\n",
")\n",
"\n",
"fn_json_shoe_damage = FunctionDeclaration(\n",
" name=\"extract_json_shoe_damage\",\n",
" description=\"This function is used to clean JSON packages from text, that contains: damagedAreas, damageType, severity, additionalNotes.\",\n",
" parameters={\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"records\": {\n",
" \"type\": \"array\",\n",
" \"description\": \"A damaged shoe\",\n",
" \"items\": {\n",
" \"description\": \"Data for a querying the database on found information\",\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"damagedAreas\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The areas of damage found on the shoe. i.e. ('[Area 1]', '[Area 2]', ...)\",\n",
" },\n",
" \"damageType\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The type of damage\",\n",
" },\n",
" \"severity\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The Severity Level\",\n",
" },\n",
" \"additionalNotes\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"Optional Additional Notes\",\n",
" },\n",
" },\n",
" \"required\": [\"damagedAreas\", \"damageType\"],\n",
" },\n",
" }\n",
" },\n",
" },\n",
")\n",
"\n",
"\n",
"fn_not_related = FunctionDeclaration(\n",
" name=\"catch_text_regarding_warranty\",\n",
" description=\"This function is used when there is no json format. Respond whenever there is text about warranty evaluations.\",\n",
" parameters={\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"records\": {\n",
" \"type\": \"array\",\n",
" \"description\": \"A sentence similar to this: I am unable to help you with that image because it does not help with warranty evaluations.\",\n",
" \"items\": {\n",
" \"description\": \"A simple sentence\",\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"sentence\": {\"type\": \"string\", \"description\": \"A sentence.\"}\n",
" },\n",
" \"required\": [\"sentence\"],\n",
" },\n",
" }\n",
" },\n",
" },\n",
")"
]
},
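{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each declaration above wraps its fields in a `records` array and names the fields it requires. The sketch below checks a call's arguments against such a `required` list before they are used in a query. `missing_required` and `sample_args` are hypothetical names, and a plain dict stands in for the SDK's function-call arguments (an assumption that they support dict-style access):\n",
"\n",
"```python\n",
"def missing_required(args, required):\n",
"    # Return the required keys that are absent or empty in the first record.\n",
"    records = args.get('records') or []\n",
"    if not records:\n",
"        return list(required)\n",
"    first_record = records[0]\n",
"    return [key for key in required if not first_record.get(key)]\n",
"\n",
"\n",
"# Hypothetical arguments, reusing the serial number and SKU from the test query above.\n",
"sample_args = {\n",
"    'records': [\n",
"        {\n",
"            'serialNumber': 'CZE5F6G7',\n",
"            'product': 'DepthStrider_23_Red_Norm',\n",
"            'sizing': 'us: 7',\n",
"        }\n",
"    ]\n",
"}\n",
"\n",
"print(missing_required(sample_args, ['serialNumber', 'sizing']))\n",
"print(missing_required({'records': [{'product': 'x'}]}, ['serialNumber', 'sizing']))\n",
"```\n",
"\n",
"An empty list means every required field is present; otherwise the chatbot could ask for a clearer photo instead of querying with missing values."
]
},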
{
"cell_type": "markdown",
"metadata": {
"id": "71k6RnyB_9Uh"
},
"source": [
"## Create the required methods and function calling helper"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "R1GYUe6dS02U"
},
"outputs": [],
"source": [
"# Will be passed the function name to handle.\n",
"\n",
"\n",
"def flow_manager(current_function_call):\n",
" global current_step\n",
" response = \"\"\n",
" match current_function_call.name:\n",
" case \"catch_text_regarding_warranty\":\n",
" return \"Please respond to the recent question ๐ฅน\"\n",
" case \"extract_json_shoe_damage\":\n",
" damage_area = current_function_call.args[\"records\"][0].get(\"damagedAreas\")\n",
" damage_type = current_function_call.args[\"records\"][0].get(\"damageType\")\n",
" prompt_DB = f\"\"\"\n",
"**Context:** You are a warranty analyst assisting with a claim for Aquastrides shoes. Your role is to provide a preliminary assessment based on the warranty policy.\n",
"\n",
"**Information Provided:**\n",
"\n",
"**Damaged area:** {damage_area}\n",
"**Type of damage:** {damage_type}\n",
"\n",
"**Task:**\n",
"\n",
"1. **Analyze** the provided damage information in relation to the Aquastrides Warranty Policy.\n",
"2. **Identify** if this type of damage, in the specified area, is typically covered or excluded under the warranty. Be lenient in claims.\n",
"3. **Provide a concise decision:**\n",
" * \"Covered\" - If the damage appears consistent with warranty coverage.\n",
" * \"Not Covered\" - If the damage appears inconsistent with warranty coverage. Only when it is very obvious that it should not apply.\n",
"\n",
"**Important:** Provide a definitive approval or denial. Your assessment guides the next steps in the workflow.\n",
"\"\"\"\n",
" result, page_source = get_answer(\n",
" rag_info[\"embeddings_dataframe\"],\n",
" rag_info[\"my_index_endpoint\"],\n",
" rag_info[\"index_id\"],\n",
" prompt_DB,\n",
" )\n",
" current_step = 2\n",
" return result + \"\\n Are you okay with this decision? ๐ค\"\n",
"\n",
" case \"extract_json_from_tag\":\n",
" sn = current_function_call.args[\"records\"][0].get(\"serialNumber\")\n",
" sku = current_function_call.args[\"records\"][0].get(\"product\")\n",
" prompt_DB = f\"Provided the Serial_No ({sn}) and SKU ({sku}), Determine the cx_name who purchased this serial number.\\n Output the Owner (cx_name) and the address (cx_address) in this format: \\nThank you [cx_name] for your purchase! We have you on file at [cx_address].\"\n",
" result, page_source = get_answer(\n",
" rag_info[\"embeddings_dataframe\"],\n",
" rag_info[\"my_index_endpoint\"],\n",
" rag_info[\"index_id\"],\n",
" prompt_DB,\n",
" )\n",
" # response = response + f\"{result}(The following was based on: {page_source}. SN: {sn} / SKU: {sku})\"\n",
" response = (\n",
" result\n",
" + f\"\\n\\n Now that we have handled verification ๐ฅณ, can you please submit an image of the damaged component of your shoe? ๐ค\"\n",
" )\n",
" current_step = 1\n",
"\n",
" return response\n",
" case _:\n",
" return \"Called Default. No Function Call Found\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "199THhxAv0S2"
},
"outputs": [],
"source": [
"def convert_image_for_analysis(image):\n",
" image = PIL_Image.open(image)\n",
" image_path = os.path.join(\"\", \"uploaded_image.png\")\n",
" image.save(image_path, format=\"PNG\")\n",
"\n",
" # Encode image to base64\n",
" with open(image_path, \"rb\") as image_file:\n",
" encoded_image = base64.b64encode(image_file.read()).decode(\"utf-8\")\n",
" # Use Part.from_data instead of Part.from_uri\n",
" image_part = Part.from_data(mime_type=\"image/png\", data=encoded_image)\n",
" return image_part\n",
"\n",
"\n",
"def is_valid_email(email):\n",
" pattern = r\"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$\"\n",
" return re.match(pattern, email.strip()) is not None"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vGlVKwgCWuI_"
},
"source": [
"## Build the demo app interface with Gradio"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "T0aOkvh0WmoP"
},
"outputs": [],
"source": [
"# Function and variables\n",
"warranty_tool = Tool(\n",
" function_declarations=[fn_json_from_tag, fn_json_shoe_damage, fn_not_related]\n",
")\n",
"current_step = 0\n",
"\n",
"# Enables Forced Function Calling to make sure that it selects a function every time\n",
"tool_config = ToolConfig(\n",
" function_calling_config=ToolConfig.FunctionCallingConfig(\n",
" mode=ToolConfig.FunctionCallingConfig.Mode.ANY,\n",
" allowed_function_names=[\n",
" \"extract_json_from_tag\",\n",
" \"extract_json_shoe_damage\",\n",
" \"catch_text_regarding_warranty\",\n",
" ],\n",
" )\n",
")\n",
"\n",
"\n",
"# Main bot function\n",
"def bot(message, history):\n",
" global current_step\n",
"\n",
" # Previous message should ideally contain the warranty decision\n",
" previous_message = history[-1] if history else \"\"\n",
" try:\n",
" # --- Image Processing ---\n",
" if message.get(\"files\"):\n",
" # Prepares Image from Gradio into format accepted by Generative Models\n",
" converted_image = convert_image_for_analysis(message[\"files\"][0])\n",
"\n",
" # Converts the image into JSON (Text)\n",
" image_output = image_analysis_model.generate_content(\n",
" [image_determination_prompt, converted_image],\n",
" generation_config=generation_config,\n",
" safety_settings=safety_config,\n",
" ).text\n",
"\n",
" # Gets the Function call for the image\n",
" image_analysis_output = image_analysis_model.generate_content(\n",
" image_output,\n",
" generation_config=generation_config,\n",
" safety_settings=safety_config,\n",
" tools=[warranty_tool],\n",
" tool_config=tool_config,\n",
" )\n",
"\n",
" # Passes the Function Call to the Function Manager to handle the image as needed.\n",
" current_output = flow_manager(\n",
" image_analysis_output.candidates[0].function_calls[0]\n",
" )\n",
"\n",
" # Output to User\n",
" return current_output\n",
"\n",
" # Generate text response using the model\n",
"\n",
" match current_step:\n",
" # Case 0: Handles anything around trying to upload the image of your inner shoe Tag\n",
" case 0:\n",
" response = text_model.generate_content(\n",
" message[\"text\"],\n",
" generation_config=generation_config,\n",
" safety_settings=safety_config,\n",
" )\n",
" return response.text # Needs a Try-catch in case safety filters blocks\n",
"\n",
" # Case 1: \"Focused on responses about the uploaded tag (Anything around [ Now that we have handled verification ๐ฅณ, can you please submit an image of the damaged component of your shoe? ๐ค])\"\n",
" case 1:\n",
" test = 1\n",
" # \"Issues with Tag image analysis / Uploading their damage to shoe\"\n",
" content = f\"Respond to the previous context focusing on helping the user submit a photo of their shoe showing the damaged components: {message['text']}| If there is a history, make your response based on the previous chat messages as well:\\n{str(''.join(chat) for chat in history)}\"\n",
" response = text_model.generate_content(\n",
" content,\n",
" generation_config=generation_config,\n",
" safety_settings=safety_config,\n",
" )\n",
" return response.candidates[0].text\n",
"\n",
" # Case 2: Focused on issues surrounding the Warranty approval process (\"Are you okay with this decision? ๐ค\")\n",
" case 2:\n",
" if not \"not covered\" in previous_message[1].lower() and (\n",
" \"yes\" in message[\"text\"].lower()\n",
" or \"agree\" in message[\"text\"].lower()\n",
" ):\n",
" current_step = 3 # Move to shipping details\n",
" response = \"Great! To get your shoes back to us for repair/replacement, please provide me with your email address. We'll send you a prepaid shipping label, box, and instructions right away to the address on file ๐.\"\n",
" return response\n",
"\n",
" if \"not covered\" in previous_message[1].lower() and (\n",
" \"yes\" in message[\"text\"].lower()\n",
" or \"agree\" in message[\"text\"].lower()\n",
" ):\n",
" current_step = 4 # Move to customer support referral\n",
" response = \"Thank you for your understanding. For further assistance with your warranty claim, please contact our Customer Support team at support@aquastrider.com. They'll be happy to help!\\n Is there anything else I can help you with or learn about our other products? ๐\"\n",
" return response\n",
"\n",
" else: # Handle negative sentiment, concerns, or not covered cases\n",
" current_step = 4 # Move to customer support referral\n",
" response = \"I understand you may have some concerns. For further assistance with your warranty claim, please contact our Customer Support team at support@aquastrider.com. They'll be happy to help!\\n Is there anything else I can help you with or learn about our other products? ๐\"\n",
" return response\n",
" case 3:\n",
" if is_valid_email(message[\"text\"].lower()):\n",
" current_step = 4\n",
" return f\"Thank you for providing the email! The return box & Label will be sent out immediately! โกโกโก Please check your email for confirmation. \\nDo you have any other questions about our products or about Aquastride?\"\n",
" else:\n",
" return f\"Please provide a valid email! ๐ซ \"\n",
" case 4:\n",
" if \"no\" in message[\"text\"].lower():\n",
" return \"Thank you for chatting with us today. See you next time! ๐\"\n",
" response = f\"Help user questions about AquaStrides The Company. User Question: {message['text']} | Chat History: {str(''.join(chat) for chat in history) if history else ''}\"\n",
" result, page_source = get_answer(\n",
" rag_info[\"embeddings_dataframe\"],\n",
" rag_info[\"my_index_endpoint\"],\n",
" rag_info[\"index_id\"],\n",
" response,\n",
" )\n",
" return (\n",
" str(result)\n",
" + \"\\n\\n Feel free to ask me any other questions, if not, Have a wave of a day! ๐ค \"\n",
" )\n",
"\n",
" case _:\n",
" return \"We are in the endgame now. (The Avengers: Infinity War)\"\n",
" return \"Something Sneaky happened.\"\n",
"\n",
" except Exception as e:\n",
" return f\"A bad error occurred: {str(e)}\""
]
},
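{
"cell_type": "markdown",
"metadata": {},
"source": [
"The step-2 branch of `bot` above reduces to a small decision: a covered claim that the customer accepts moves to step 3 (email collection), and every other combination moves to step 4 (support referral and general questions). A minimal sketch of that rule as a pure function, which is easier to test than the Gradio callback; `next_step_after_decision` is a hypothetical helper, not part of this notebook:\n",
"\n",
"```python\n",
"def next_step_after_decision(previous_bot_message, user_text):\n",
"    # 3 = collect the email address, 4 = support referral / general questions.\n",
"    covered = 'not covered' not in previous_bot_message.lower()\n",
"    agreed = any(word in user_text.lower() for word in ('yes', 'agree'))\n",
"    return 3 if covered and agreed else 4\n",
"\n",
"\n",
"print(next_step_after_decision('Covered. Are you okay with this decision?', 'Yes, I agree'))\n",
"print(next_step_after_decision('Not Covered. Are you okay with this decision?', 'yes'))\n",
"```\n",
"\n",
"Because the check is a substring match, a reply such as `I disagree` also counts as agreement, both in this sketch and in `bot`; matching whole words would be stricter."
]
},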
{
"cell_type": "markdown",
"metadata": {
"id": "TnmZZ5dCW0Bl"
},
"source": [
"# Run your demo app\n",
"\n",
"This will instantiate the demo app and allow you to interact with and test your chatbot. Click on the link in the output of this cell to access a live instance of the demo app."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "KbkXapGXWzXa"
},
"outputs": [],
"source": [
"# Downloading Sample Images to use for the Demo\n",
"! wget -q -O my_shoe_tag.png 'https://storage.googleapis.com/github-repo/generative-ai/gemini/use-cases/rag/warranty-claim-chatbot/my-aquastride-shoe-tag.png'\n",
"! wget -q -O damaged_shoe.png 'https://storage.googleapis.com/github-repo/generative-ai/gemini/use-cases/rag/warranty-claim-chatbot/shoe-damaged.png'\n",
"\n",
"# Set the Current Step that the user flow is on to zero. A diagram is referenced in the section above to help understand the flow.\n",
"current_step = 0\n",
"\n",
"demo = gr.ChatInterface(\n",
" fn=bot,\n",
" examples=[\n",
" {\"text\": \"Hello!\", \"files\": []},\n",
" {\"text\": \"Here is my tag!\", \"files\": [\"my_shoe_tag.png\"]},\n",
" {\"text\": \"Sure! Here is my damaged shoe!\", \"files\": [\"damaged_shoe.png\"]},\n",
" ],\n",
" title=\"AquaStride Warranty Claim Bot!\",\n",
" multimodal=True,\n",
" textbox=gr.MultimodalTextbox(interactive=True, file_types=[\"image\"]),\n",
")\n",
"\n",
"demo.launch(debug=True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hoNcNFhCgwkb"
},
"source": [
"# Clean Up\n",
"\n",
"Delete the Google Cloud Assets and clean up your environment:\n",
"\n",
"- Shut down the Gradio Instance\n",
"- Delete the [Public Endpoint](https://cloud.google.com/python/docs/reference/aiplatform/1.20.0/google.cloud.aiplatform.MatchingEngineIndexEndpoint) / GCS Bucket / [Index](https://cloud.google.com/python/docs/reference/aiplatform/1.23.0/google.cloud.aiplatform.MatchingEngineIndex)\n",
"- If preferred, you can do this via the console:\n",
" - You can navigate to [Google Cloud Vector Search](https://console.cloud.google.com/vertex-ai/matching-engine/index-endpoints) and undeploy and delete your endpoint here\n",
" - You can navigate to the [Google Cloud Storage Bucket](https://console.cloud.google.com/storage/browser) and delete the bucket here"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ekz_lwAalg17"
},
"outputs": [],
"source": [
"# Delete your GCS Bucket\n",
"! gcloud alpha storage rm --recursive {rag_info[\"bucket_uri\"]}\n",
"\n",
"# Undeploy your Index Endpoint\n",
"rag_info[\"my_index_endpoint\"].delete(force=True)\n",
"\n",
"# Delete your Index. This command will take 15-25 minutes to delete.\n",
"rag_info[\"index\"].delete()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5AdZWRaKB0NW"
},
"source": [
"For the final step, delete your index from [the Google Cloud Vector Search UI](https://console.cloud.google.com/vertex-ai/matching-engine/indexes)."
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [
"ECB4ixV6VF0B"
],
"name": "retail_warranty_claim_chatbot.ipynb",
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}