infrastructure/cymbal-store-embeddings/cymbal_store.py (411 lines of code) (raw):

# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import mesop as me
from data_model import State, Models, ModelDialogState, Conversation, ChatMessage
import gemini_model
import openai_model
import gemma_model
import os
import logging
import json
import sqlalchemy
from connect_tcp import connect_tcp_socket
from dataclasses import dataclass, field
from typing import Literal
import base64

# Author of a chat message: the end user or the LLM. Defined once for the
# whole module (it used to be declared twice).
Role = Literal["user", "model"]


# Dialog


@me.content_component
def dialog(is_open: bool):
    """Render a modal overlay whose content is supplied through ``me.slot()``.

    Args:
      is_open: When False the overlay is still rendered but with
        ``display: none``, so it is hidden rather than removed.
    """
    with me.box(
        style=me.Style(
            background="rgba(0,0,0,0.4)",
            display="block" if is_open else "none",
            height="100%",
            overflow_x="auto",
            overflow_y="auto",
            position="fixed",
            width="100%",
            z_index=1000,
        )
    ):
        # Full-viewport grid that centers the dialog card.
        with me.box(
            style=me.Style(
                align_items="center",
                display="grid",
                height="100vh",
                justify_items="center",
            )
        ):
            # The white dialog card itself.
            with me.box(
                style=me.Style(
                    background="#fff",
                    border_radius=20,
                    box_sizing="content-box",
                    box_shadow=(
                        "0 3px 1px -2px #0003, 0 2px 2px #00000024, 0 1px 5px #0000001f"
                    ),
                    margin=me.Margin.symmetric(vertical="0", horizontal="auto"),
                    padding=me.Padding.all(20),
                )
            ):
                me.slot()


@me.content_component
def dialog_actions():
    """Right-aligned row for dialog buttons; buttons are passed via ``me.slot()``."""
    with me.box(
        style=me.Style(
            display="flex", justify_content="end", margin=me.Margin(top=20)
        )
    ):
        me.slot()


# App

_ROLE_USER = "user"
_ROLE_ASSISTANT = "model"

_COLOR_BACKGROUND = me.theme_var("background")
_COLOR_CHAT_BUBBLE_YOU = me.theme_var("surface-container-low")
_COLOR_CHAT_BUBBLE_BOT = me.theme_var("secondary-container")
_DEFAULT_PADDING = me.Padding.all(20)
_DEFAULT_BORDER_SIDE = me.BorderSide(
    width="1px", style="solid", color=me.theme_var("secondary-fixed")
)

_STYLE_APP_CONTAINER = me.Style(
    background=_COLOR_BACKGROUND,
    display="flex",
    flex_direction="column",
    height="100%",
    margin=me.Margin.symmetric(vertical=0, horizontal="auto"),
    width="min(1024px, 100%)",
    box_shadow=("0 3px 1px -2px #0003, 0 2px 2px #00000024, 0 1px 5px #0000001f"),
    padding=me.Padding(top=20, left=20, right=20),
)
_STYLE_TITLE = me.Style(padding=me.Padding(left=10))
_STYLE_CHAT_BOX = me.Style(
    flex_grow=1,
    overflow_y="scroll",
    padding=_DEFAULT_PADDING,
    margin=me.Margin(bottom=20),
    border_radius="10px",
    border=me.Border(
        left=_DEFAULT_BORDER_SIDE,
        right=_DEFAULT_BORDER_SIDE,
        top=_DEFAULT_BORDER_SIDE,
        bottom=_DEFAULT_BORDER_SIDE,
    ),
)
_STYLE_CHAT_INPUT = me.Style(width="100%")
_STYLE_CHAT_INPUT_BOX = me.Style(
    padding=me.Padding(top=30), display="flex", flex_direction="row"
)
_STYLE_CHAT_BUTTON = me.Style(margin=me.Margin(top=8, left=8))
_STYLE_CHAT_BUBBLE_NAME = me.Style(
    font_weight="bold",
    font_size="13px",
    padding=me.Padding(left=15, right=15, bottom=5),
)
_STYLE_CHAT_BUBBLE_PLAINTEXT = me.Style(margin=me.Margin.symmetric(vertical=15))

_LABEL_BUTTON = "send"
_LABEL_BUTTON_IN_PROGRESS = "pending"
_LABEL_INPUT = "Enter your prompt"


def _make_style_chat_bubble_wrapper(role: Role) -> me.Style:
    """Generates styles for chat bubble position.

    Args:
      role: Chat bubble alignment depends on the role (user bubbles align to
        the end, model bubbles to the start).
    """
    align_items = "end" if role == _ROLE_USER else "start"
    return me.Style(
        display="flex",
        flex_direction="column",
        align_items=align_items,
    )


def _make_chat_bubble_style(role: Role) -> me.Style:
    """Generates styles for chat bubble.

    Args:
      role: Chat bubble background color depends on the role.
    """
    background = (
        _COLOR_CHAT_BUBBLE_YOU if role == _ROLE_USER else _COLOR_CHAT_BUBBLE_BOT
    )
    return me.Style(
        width="80%",
        font_size="16px",
        line_height="1.5",
        background=background,
        border_radius="15px",
        padding=me.Padding(right=15, left=15, bottom=3),
        margin=me.Margin(bottom=10),
        border=me.Border(
            left=_DEFAULT_BORDER_SIDE,
            right=_DEFAULT_BORDER_SIDE,
            top=_DEFAULT_BORDER_SIDE,
            bottom=_DEFAULT_BORDER_SIDE,
        ),
    )


# Lazily-initialised connection pool shared by all requests (see init_db()).
db = None
logger = logging.getLogger(__name__)


def init_connection_pool() -> sqlalchemy.engine.base.Engine:
    """Sets up the connection pool for the app.

    Only TCP connections are supported: ``INSTANCE_HOST`` (e.g. 127.0.0.1)
    must be set, and ``connect_tcp_socket`` builds the engine.

    Raises:
      ValueError: if ``INSTANCE_HOST`` is not defined.
    """
    if os.environ.get("INSTANCE_HOST"):
        return connect_tcp_socket()
    # TODO: support INSTANCE_CONNECTION_NAME (Cloud SQL Python Connector,
    # optionally with DB_IAM_USER) once the connector helpers are available.
    raise ValueError(
        "Missing database connection type. Please define INSTANCE_HOST"
    )


def init_db() -> sqlalchemy.engine.base.Engine:
    """Create the module-level connection pool on first use and return it."""
    global db
    if db is None:
        db = init_connection_pool()
    return db


def get_products(
    db: sqlalchemy.engine.base.Engine,
    embeddings: str,
    store_id: int = 1583,
    limit: int = 5,
) -> list:
    """Return the in-stock products closest to ``embeddings`` in one store.

    Args:
      db: Engine used to open the connection.
      embeddings: String form of the query vector; it is cast to ``vector``
        in SQL and compared with the ``<=>`` operator.
      store_id: Store whose inventory is searched (previously hard-coded).
      limit: Maximum number of products returned.

    Returns:
      A list of dicts with keys ``product_name``, ``description`` (first 80
      characters), ``sale_price`` and ``zip_code``, ordered by ascending
      distance. Returns an empty list if the query fails (the error is
      logged), so callers can still answer without product context.
    """
    stmt = sqlalchemy.text(
        """
        SELECT
            cp.product_name,
            left(cp.product_description, 80) AS description,
            cp.sale_price,
            cs.zip_code,
            (ce.embedding <=> (:embeddings)::vector) AS distance
        FROM cymbal_products cp
        JOIN cymbal_embedding ce ON ce.uniq_id = cp.uniq_id
        JOIN cymbal_inventory ci ON ci.uniq_id = cp.uniq_id
        JOIN cymbal_stores cs ON cs.store_id = ci.store_id
            AND ci.inventory > 0
            AND cs.store_id = :store_id
        ORDER BY distance ASC
        LIMIT :limit;
        """
    )
    try:
        with db.connect() as conn:
            rows = conn.execute(
                stmt,
                parameters={
                    "embeddings": embeddings,
                    "store_id": store_id,
                    "limit": limit,
                },
            ).fetchall()
    except Exception:
        # Previously the loop below then hit an unbound variable; degrade to
        # "no products" instead so the chat still gets an answer.
        logger.exception("Product similarity search failed")
        return []
    return [
        {
            "product_name": row[0],
            "description": row[1],
            "sale_price": row[2],
            "zip_code": row[3],
        }
        for row in rows
    ]


def change_model_option(e: me.CheckboxChangeEvent):
    """Add or remove the toggled model from the pending dialog selection."""
    s = me.state(ModelDialogState)
    if e.checked:
        if e.key not in s.selected_models:
            s.selected_models.append(e.key)
    elif e.key in s.selected_models:
        s.selected_models.remove(e.key)


def set_gemini_api_key(e: me.InputBlurEvent):
    """Store the Gemini API key typed in the settings dialog."""
    me.state(State).gemini_api_key = e.value


def set_openai_api_key(e: me.InputBlurEvent):
    """Store the OpenAI API key typed in the settings dialog."""
    me.state(State).openai_api_key = e.value


def set_gemma_endpoint(e: me.InputBlurEvent):
    """Store the Gemma endpoint id typed in the settings dialog."""
    me.state(State).gemma_endpoint_id = e.value


def set_tei_endpoint(e: me.InputBlurEvent):
    """Store the TEI embedding URL typed in the settings dialog."""
    me.state(State).tei_embedding_url = e.value


def model_picker_dialog():
    """Render the settings dialog: credentials/endpoints and model checkboxes.

    A model's checkbox is disabled until its credential is set (Gemini or
    OpenAI API key, Gemma endpoint id); other models are always enabled.
    """
    state = me.state(State)
    with dialog(state.is_model_picker_dialog_open):
        with me.box(style=me.Style(display="flex", flex_direction="column", gap=12)):
            me.text("API keys")
            me.input(
                label="Gemini API Key",
                value=state.gemini_api_key,
                on_blur=set_gemini_api_key,
            )
            me.input(
                label="OpenAI API Key",
                value=state.openai_api_key,
                on_blur=set_openai_api_key,
            )
            me.input(
                label="Gemma Endpoint",
                value=state.gemma_endpoint_id,
                on_blur=set_gemma_endpoint,
            )
            me.input(
                label="TEI Endpoint",
                value=state.tei_embedding_url,
                on_blur=set_tei_endpoint,
            )
        me.text("Pick a model")
        for model in Models:
            if model.name.startswith("GEMINI"):
                disabled = not state.gemini_api_key
            elif model.name.startswith("OPENAI"):
                disabled = not state.openai_api_key
            elif model.name.startswith("GEMMA"):
                disabled = not state.gemma_endpoint_id
            else:
                disabled = False
            me.checkbox(
                key=model.value,
                label=model.value,
                checked=model.value in state.models,
                disabled=disabled,
                on_change=change_model_option,
                style=me.Style(
                    display="flex",
                    flex_direction="column",
                    gap=4,
                    padding=me.Padding(top=12),
                ),
            )
        with dialog_actions():
            me.button("Cancel", on_click=close_model_picker_dialog)
            me.button("Confirm", on_click=confirm_model_picker_dialog)


def close_model_picker_dialog(e: me.ClickEvent):
    """Hide the settings dialog without applying the pending selection."""
    state = me.state(State)
    state.is_model_picker_dialog_open = False


def confirm_model_picker_dialog(e: me.ClickEvent):
    """Hide the settings dialog and apply the pending model selection."""
    dialog_state = me.state(ModelDialogState)
    state = me.state(State)
    state.is_model_picker_dialog_open = False
    state.models = dialog_state.selected_models


# Background image with light and dark theme


def _load_background_image(path: str) -> str:
    """Return the PNG at ``path`` as an inline base64 CSS background value."""
    # `with` closes the file; the previous bare open() calls leaked handles.
    with open(path, "rb") as image_file:
        encoded = base64.b64encode(image_file.read()).decode("ascii")
    return f"url(data:image/png;base64,{encoded}) center/cover no-repeat"


_BACKGROUND_IMAGE_LIGHT = _load_background_image("pub/app_background_light.png")
_BACKGROUND_IMAGE_DARK = _load_background_image("pub/app_background_dark.png")
_BACKGROUND_IMAGE = _BACKGROUND_IMAGE_LIGHT

ROOT_BOX_STYLE = me.Style(
    background=_BACKGROUND_IMAGE,
    height="100%",
    font_family="Inter",
    display="flex",
    flex_direction="column",
)


def _root_box_style() -> me.Style:
    """Root container style whose background follows the current theme.

    Computed per render, because a module-level style (ROOT_BOX_STYLE) is
    frozen at import time and can never reflect a theme toggle.
    """
    if me.theme_brightness() == "dark":
        background = _BACKGROUND_IMAGE_DARK
    else:
        background = _BACKGROUND_IMAGE_LIGHT
    return me.Style(
        background=background,
        height="100%",
        font_family="Inter",
        display="flex",
        flex_direction="column",
    )


def toggle_theme(e: me.ClickEvent):
    """Switch between light and dark mode; the background follows on re-render."""
    if me.theme_brightness() == "light":
        me.set_theme_mode("dark")
    else:
        me.set_theme_mode("light")


def on_input_enter(e: me.InputEnterEvent):
    """Submit the prompt when Enter is pressed in the input box."""
    state = me.state(State)
    state.input = e.value
    yield from send_prompt(e)


@me.page(
    path="/",
    stylesheets=[
        "https://fonts.googleapis.com/css2?family=Inter:wght@100..900&display=swap"
    ],
    title="Cymbal store assistant",
)
def page():
    """Render the chat UI: theme toggle, transcript, prompt input, model picker."""
    # Initialise the Postgres connection pool within the request context.
    init_db()
    bot_user = "model"
    state = me.state(State)

    with me.box(style=_root_box_style()):
        model_picker_dialog()

        with me.box(style=_STYLE_APP_CONTAINER):
            with me.content_button(
                type="icon",
                style=me.Style(position="absolute", right=4, top=8),
                on_click=toggle_theme,
            ):
                me.icon(
                    "light_mode" if me.theme_brightness() == "dark" else "dark_mode"
                )

            me.text("Store Virtual Assistant", type="headline-5", style=_STYLE_TITLE)

            with me.box(style=_STYLE_CHAT_BOX):
                for conversation in state.conversations:
                    for message in conversation.messages:
                        with me.box(style=_make_style_chat_bubble_wrapper(message.role)):
                            if message.role == _ROLE_ASSISTANT:
                                me.text(bot_user, style=_STYLE_CHAT_BUBBLE_NAME)
                            with me.box(style=_make_chat_bubble_style(message.role)):
                                if message.role == _ROLE_USER:
                                    me.text(
                                        message.content,
                                        style=_STYLE_CHAT_BUBBLE_PLAINTEXT,
                                    )
                                else:
                                    me.markdown(message.content)

            with me.box(style=_STYLE_CHAT_INPUT_BOX):
                with me.box(style=me.Style(flex_grow=1)):
                    me.input(
                        label=_LABEL_INPUT,
                        # Workaround: update key to clear input.
                        key=f"input-{len(state.conversations)}",
                        on_blur=on_blur,
                        on_enter=on_input_enter,
                        style=_STYLE_CHAT_INPUT,
                    )
                    with me.box(
                        style=me.Style(
                            display="flex",
                            padding=me.Padding(left=12, bottom=12),
                            cursor="pointer",
                        ),
                        on_click=switch_model,
                    ):
                        me.text(
                            "Model:",
                            style=me.Style(font_weight=500, padding=me.Padding(right=6)),
                        )
                        if state.models:
                            me.text(", ".join(state.models))
                        else:
                            me.text("(no model selected)")
                with me.content_button(
                    color="primary",
                    type="flat",
                    disabled=state.in_progress,
                    on_click=send_prompt,
                    style=_STYLE_CHAT_BUTTON,
                ):
                    me.icon(
                        _LABEL_BUTTON_IN_PROGRESS if state.in_progress else _LABEL_BUTTON
                    )


def switch_model(e: me.ClickEvent):
    """Open the settings dialog, seeding it with the currently selected models."""
    state = me.state(State)
    state.is_model_picker_dialog_open = True
    dialog_state = me.state(ModelDialogState)
    dialog_state.selected_models = state.models[:]


def on_blur(e: me.InputBlurEvent):
    """Keep state.input in sync with the prompt box when it loses focus."""
    state = me.state(State)
    state.input = e.value


# System-prompt fragments shared by the retrieval-augmented models.
_PERSONA = "You are friendly assistance in a store helping to find a products based on the client's request"
_SAFEGUARDS = "You should give information about the product, price and any supplemental information. Do not invent any new products and use for the answer the product defined in the context"

# Upper bound on intent-classification calls per prompt; the previous
# `while True` could call the LLM forever on persistently malformed output.
_MAX_INTENT_ATTEMPTS = 3


def _product_context(products_list) -> str:
    """Build the system-prompt fragment listing the retrieved products."""
    return (
        """
Based on the client request we have loaded a list of products closely related to search.
The list in JSON format with list of values like {"product_name":"name","description":"some description","sale_price":10,"zip_code": 10234}
Here is the list of products:\n
"""
        + str(products_list)
    )


def _classify_intent(classify, prompt: str) -> dict:
    """Decode the JSON intent for ``prompt``, retrying a bounded number of times.

    Args:
      classify: Model-specific classifier that takes the prompt and returns a
        JSON string (e.g. ``gemini_model.classify_intent``).
      prompt: The user's message.

    Returns:
      The decoded intent dict, or ``{"shouldRecommendProduct": False}`` when no
      valid JSON object is produced within ``_MAX_INTENT_ATTEMPTS`` attempts.
    """
    for attempt in range(1, _MAX_INTENT_ATTEMPTS + 1):
        intent_str = classify(prompt)
        logger.info("INTENT: %s", intent_str)
        try:
            intent = json.loads(intent_str)
        except (json.JSONDecodeError, TypeError) as err:
            logger.warning("Intent attempt %d is not valid JSON: %s", attempt, err)
            continue
        if isinstance(intent, dict):
            return intent
        logger.warning("Intent attempt %d is not a JSON object", attempt)
    return {"shouldRecommendProduct": False}


def _system_instruction_parts(intent: dict, prompt: str, embed) -> list:
    """Return persona + safeguards, plus product context when recommending.

    Args:
      intent: Decoded intent; ``shouldRecommendProduct`` and ``summary`` are read.
      prompt: The user's message, used if the intent carries no summary.
      embed: Callable mapping text to its embedding vector.
    """
    parts = [_PERSONA, _SAFEGUARDS]
    if intent.get("shouldRecommendProduct") is True:
        embedding = embed(intent.get("summary", prompt))
        products_list = get_products(db, str(embedding))
        logger.info("PRODUCTS LIST: %s", products_list)
        parts.append(_product_context(products_list))
    return parts


def _call_model(model: str, prompt: str, history: list):
    """Return the streaming response iterator of ``model`` for ``prompt``.

    Raises:
      Exception: if ``model`` has no handler.
    """
    if model == Models.GEMINI_2_0_FLASH.value:
        intent = _classify_intent(gemini_model.classify_intent, prompt)
        system_instruction = _system_instruction_parts(
            intent,
            prompt,
            lambda text: gemini_model.generate_embedding(text)["embedding"],
        )
        return gemini_model.send_prompt_flash(prompt, history, system_instruction)
    if model == Models.OPENAI.value:
        return openai_model.call_openai_gpt4o_mini(prompt, history)
    if model == Models.GEMMA_3.value:
        intent = _classify_intent(gemma_model.classify_intent, prompt)
        parts = _system_instruction_parts(
            intent, prompt, gemma_model.generate_embedding
        )
        # Gemma takes a single string; separate the fragments so sentences do
        # not run together ("...requestYou should...").
        return gemma_model.call_gemma(prompt, history, "\n".join(parts))
    raise Exception("Unhandled model", model)


def send_prompt(e: me.ClickEvent):
    """Send the current prompt to every selected model and stream the replies.

    Generator event handler: each ``yield`` lets Mesop re-render, so the user
    message, the pending bot bubble and every streamed chunk show up as they
    arrive. ``state.in_progress`` is set while running so the send button is
    disabled.

    Raises:
      Exception: if a conversation uses a model with no handler.
    """
    state = me.state(State)
    if not state.conversations:
        for model in state.models:
            state.conversations.append(Conversation(model=model, messages=[]))
    prompt = state.input
    state.input = ""
    state.in_progress = True
    yield

    try:
        for conversation in state.conversations:
            messages = conversation.messages
            # Snapshot the history before this turn is appended.
            history = messages[:]
            messages.append(ChatMessage(role="user", content=prompt))
            messages.append(ChatMessage(role="model", in_progress=True))
            yield

            llm_response = _call_model(conversation.model, prompt, history)
            for chunk in llm_response:
                messages[-1].content += chunk
                yield
            messages[-1].in_progress = False
            yield
    finally:
        state.in_progress = False
    yield