infrastructure/movie-search-app/movie_search.py (385 lines of code) (raw):

# Copyright 2024 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import mesop as me from data_model import State, Models, ModelDialogState, Conversation, ChatMessage import gemini_model import pinecone_model import os import logging import json import sqlalchemy from connect_tcp import connect_tcp_socket from dataclasses import dataclass, field from typing import Literal Role = Literal["user", "model"] # Dialog @me.content_component def dialog(is_open: bool): with me.box( style=me.Style( background="rgba(0,0,0,0.4)", display="block" if is_open else "none", height="100%", overflow_x="auto", overflow_y="auto", position="fixed", width="100%", z_index=1000, ) ): with me.box( style=me.Style( align_items="center", display="grid", height="100vh", justify_items="center", ) ): with me.box( style=me.Style( background="#fff", border_radius=20, box_sizing="content-box", box_shadow=( "0 3px 1px -2px #0003, 0 2px 2px #00000024, 0 1px 5px #0000001f" ), margin=me.Margin.symmetric(vertical="0", horizontal="auto"), padding=me.Padding.all(20), ) ): me.slot() @me.content_component def dialog_actions(): with me.box( style=me.Style( display="flex", justify_content="end", margin=me.Margin(top=20) ) ): me.slot() # App Role = Literal["user", "model"] _ROLE_USER = "user" _ROLE_ASSISTANT = "model" _COLOR_BACKGROUND = me.theme_var("background") _COLOR_CHAT_BUBBLE_YOU = me.theme_var("surface-container-low") _COLOR_CHAT_BUBBLE_BOT = me.theme_var("secondary-container") _DEFAULT_PADDING = me.Padding.all(20) _DEFAULT_BORDER_SIDE = me.BorderSide( width="1px", style="solid", color=me.theme_var("secondary-fixed") ) _STYLE_APP_CONTAINER = me.Style( background=_COLOR_BACKGROUND, display="flex", flex_direction="column", height="100%", margin=me.Margin.symmetric(vertical=0, horizontal="auto"), width="min(1024px, 100%)", box_shadow=("0 3px 1px -2px #0003, 0 2px 2px #00000024, 0 1px 5px #0000001f"), padding=me.Padding(top=20, left=20, right=20), ) _STYLE_TITLE = me.Style(padding=me.Padding(left=10)) _STYLE_CHAT_BOX = me.Style( flex_grow=1, overflow_y="scroll", padding=_DEFAULT_PADDING, margin=me.Margin(bottom=20), border_radius="10px", border=me.Border( left=_DEFAULT_BORDER_SIDE, right=_DEFAULT_BORDER_SIDE, top=_DEFAULT_BORDER_SIDE, bottom=_DEFAULT_BORDER_SIDE, ), ) _STYLE_CHAT_INPUT = me.Style(width="100%") _STYLE_CHAT_INPUT_BOX = me.Style( padding=me.Padding(top=30), display="flex", flex_direction="row" ) _STYLE_CHAT_BUTTON = me.Style(margin=me.Margin(top=8, left=8)) _STYLE_CHAT_BUBBLE_NAME = me.Style( font_weight="bold", font_size="13px", padding=me.Padding(left=15, right=15, bottom=5), ) _STYLE_CHAT_BUBBLE_PLAINTEXT = me.Style(margin=me.Margin.symmetric(vertical=15)) _LABEL_BUTTON = "send" _LABEL_BUTTON_IN_PROGRESS = "pending" _LABEL_INPUT = "Enter your prompt" def _make_style_chat_bubble_wrapper(role: Role) -> me.Style: """Generates styles for chat bubble position. Args: role: Chat bubble alignment depends on the role """ align_items = "end" if role == _ROLE_USER else "start" return me.Style( display="flex", flex_direction="column", align_items=align_items, ) def _make_chat_bubble_style(role: Role) -> me.Style: """Generates styles for chat bubble. Args: role: Chat bubble background color depends on the role """ background = ( _COLOR_CHAT_BUBBLE_YOU if role == _ROLE_USER else _COLOR_CHAT_BUBBLE_BOT ) return me.Style( width="80%", font_size="16px", line_height="1.5", background=background, border_radius="15px", padding=me.Padding(right=15, left=15, bottom=3), margin=me.Margin(bottom=10), border=me.Border( left=_DEFAULT_BORDER_SIDE, right=_DEFAULT_BORDER_SIDE, top=_DEFAULT_BORDER_SIDE, bottom=_DEFAULT_BORDER_SIDE, ), ) db = None logger = logging.getLogger() def init_connection_pool() -> sqlalchemy.engine.base.Engine: """Sets up connection pool for the app.""" if os.environ.get("INSTANCE_HOST"): db_host = os.environ[ 'INSTANCE_HOST' ] # e.g. '127.0.0.1' ('172.17.0.1' if deployed to GAE Flex) else: db_host = "127.0.0.1" logging.warning("INSTANCE_HOST is not set using default: %s", db_host) print("INSTANCE_HOST is not set using default: %s", db_host) # use a TCP socket when INSTANCE_HOST (e.g. 127.0.0.1) is defined if db_host: return connect_tcp_socket() # # use the connector when INSTANCE_CONNECTION_NAME (e.g. project:region:instance) is defined # if os.environ.get("INSTANCE_CONNECTION_NAME"): # # Either a DB_USER or a DB_IAM_USER should be defined. If both are # # defined, DB_IAM_USER takes precedence. # return ( # connect_with_connector_auto_iam_authn() # if os.environ.get("DB_IAM_USER") # else connect_with_connector() # ) # raise ValueError( # "Missing database connection type. Please define one of INSTANCE_HOST, INSTANCE_UNIX_SOCKET, or INSTANCE_CONNECTION_NAME" # ) def init_db() -> sqlalchemy.engine.base.Engine: """Initiates connection to database and its structure.""" global db if db is None: db = init_connection_pool() def get_movies(db: sqlalchemy.engine.base.Engine, embeddings: str) -> dict: movies=[] stmt = sqlalchemy.text( """ SELECT mj.langchain_metadata->'title' as title, mj.langchain_metadata->'summary' as summary, mj.langchain_metadata->'director' as director, mj.langchain_metadata->'actors' as actors, (mj.embedding <=> (:embeddings)::vector) as distance FROM alloydb_table mj ORDER BY distance ASC LIMIT 5; """ ) try: with db.connect() as conn: app_movies = conn.execute(stmt, parameters={"embeddings": embeddings}).fetchall() except Exception as e: logger.exception(e) for row in app_movies: movies.append({"title":row[0],"summary":row[1],"director":row[2],"actors": row[3]}) return movies def change_model_option(e: me.CheckboxChangeEvent): s = me.state(ModelDialogState) if e.checked: s.selected_models.append(e.key) else: s.selected_models.remove(e.key) def set_gemini_api_key(e: me.InputBlurEvent): me.state(State).gemini_api_key = e.value def set_pinecone_api_key(e: me.InputBlurEvent): me.state(State).pinecone_api_key = e.value def model_picker_dialog(): state = me.state(State) with dialog(state.is_model_picker_dialog_open): with me.box(style=me.Style(display="flex", flex_direction="column", gap=12)): me.text("API keys") me.input( label="Gemini API Key", value=state.gemini_api_key, on_blur=set_gemini_api_key, ) me.input( label="Pinecone API Key", value=state.pinecone_api_key, on_blur=set_pinecone_api_key, ) me.text("Pick a backend") for model in Models: if model.name.startswith("GEMINI"): disabled = not state.gemini_api_key elif model.name.startswith("PINECONE"): disabled = not state.pinecone_api_key or not state.gemini_api_key else: disabled = False me.checkbox( key=model.value, label=model.value, checked=model.value in state.models, disabled=disabled, on_change=change_model_option, style=me.Style( display="flex", flex_direction="column", gap=4, padding=me.Padding(top=12), ), ) with dialog_actions(): me.button("Cancel", on_click=close_model_picker_dialog) me.button("Confirm", on_click=confirm_model_picker_dialog) def close_model_picker_dialog(e: me.ClickEvent): state = me.state(State) state.is_model_picker_dialog_open = False def confirm_model_picker_dialog(e: me.ClickEvent): dialog_state = me.state(ModelDialogState) state = me.state(State) state.is_model_picker_dialog_open = False state.models = dialog_state.selected_models ROOT_BOX_STYLE = me.Style( background="#e7f2ff", height="100%", font_family="Inter", display="flex", flex_direction="column", ) @me.page( path="/", stylesheets=[ "https://fonts.googleapis.com/css2?family=Inter:wght@100..900&display=swap" ], title = "Movie Search Assistant" ) def page(): bot_user = "model" print("starting") global db # initialize db within request context if not db: # initiate a connection pool to a Postgres database db = init_connection_pool() model_picker_dialog() def toggle_theme(e: me.ClickEvent): if me.theme_brightness() == "light": me.set_theme_mode("dark") else: me.set_theme_mode("light") def on_input_enter(e: me.InputEnterEvent): state = me.state(State) state.input = e.value print(state.input) yield from send_prompt(e) with me.box(style=_STYLE_APP_CONTAINER): with me.content_button( type="icon", style=me.Style(position="absolute", right=4, top=8), on_click=toggle_theme, ): me.icon("light_mode" if me.theme_brightness() == "dark" else "dark_mode") title = "Movie Search Virtual Assistant" if title: me.text(title, type="headline-5", style=_STYLE_TITLE) with me.box(style=_STYLE_CHAT_BOX): state = me.state(State) for conversation in state.conversations: for message in conversation.messages: with me.box(style=_make_style_chat_bubble_wrapper(message.role)): if message.role == _ROLE_ASSISTANT: me.text(bot_user, style=_STYLE_CHAT_BUBBLE_NAME) with me.box(style=_make_chat_bubble_style(message.role)): if message.role == _ROLE_USER: me.text(message.content, style=_STYLE_CHAT_BUBBLE_PLAINTEXT) else: me.markdown(message.content) with me.box(style=_STYLE_CHAT_INPUT_BOX): with me.box(style=me.Style(flex_grow=1)): me.input( label=_LABEL_INPUT, # Workaround: update key to clear input. key=f"input-{len(state.conversations)}", on_blur=on_blur, on_enter=on_input_enter, style=_STYLE_CHAT_INPUT, ) with me.box( style=me.Style( display="flex", padding=me.Padding(left=12, bottom=12), cursor="pointer", ), on_click=switch_model, ): me.text( "Backend:", style=me.Style(font_weight=500, padding=me.Padding(right=6)), ) if state.models: me.text(", ".join(state.models)) else: me.text("(no backend selected)") with me.content_button( color="primary", type="flat", disabled=state.in_progress, on_click=send_prompt, style=_STYLE_CHAT_BUTTON, ): me.icon( _LABEL_BUTTON_IN_PROGRESS if state.in_progress else _LABEL_BUTTON ) def switch_model(e: me.ClickEvent): state = me.state(State) state.is_model_picker_dialog_open = True dialog_state = me.state(ModelDialogState) dialog_state.selected_models = state.models[:] def on_blur(e: me.InputBlurEvent): state = me.state(State) state.input = e.value def send_prompt(e: me.ClickEvent): state = me.state(State) if not state.conversations: for model in state.models: state.conversations.append(Conversation(model=model, messages=[])) input = state.input state.input = "" yield for conversation in state.conversations: model = conversation.model messages = conversation.messages history = messages[:] messages.append(ChatMessage(role="user", content=input)) messages.append(ChatMessage(role="model", in_progress=True)) yield if model == Models.GEMINI_1_5_FLASH.value: while True: intent_str = gemini_model.classify_intent(input) print(intent_str) logging.info(f"MOVIES LIST: {intent_str}") try: json_intent = json.loads(intent_str) except json.JSONDecodeError as e: print(f"Error decoding JSON: {e}") continue break if json_intent["shouldRecommendMovie"] is True: search_embedding = gemini_model.generate_embedding(json_intent["summary"]) movies_list = get_movies(db, str(search_embedding["embedding"])) logging.info(f"MOVIES LIST: {movies_list}") print(movies_list) persona="You are friendly assistance helping to find a movie or show based on the client's request" safeguards="You should give information about the movie or show, year, main actors and any supplemental information. Do not invent any new movies, names and use for the answer the list of shows defined in the context" context=""" Based on the client request we have loaded a list of shows closely related to search. The list in JSON format with list of values like {"title":"Sparring","summary":"some description","director":"somebody","genre": "Drama", "actors": "Mathieu Kassovitz, Souleymane M'Baye"} Here is the list of shows:\n """+str(movies_list) system_instruction=[persona,safeguards,context] else: persona="You are friendly assistance helping to find a movie or show based on the client's request" safeguards="You should give information about the movie or show, year, main actors and any supplemental information. Do not invent any new movies, names and use for the answer the list of shows defined in the context" system_instruction=[persona,safeguards] llm_response = gemini_model.send_prompt_flash(input, history,system_instruction) elif model == Models.PINECONE.value: while True: intent_str = pinecone_model.classify_intent(input) print(intent_str) logging.info(f"INTENT: {intent_str}") try: json_intent = json.loads(intent_str) except json.JSONDecodeError as e: print(f"Error decoding JSON: {e}") continue break if json_intent["shouldRecommendMovie"] is True: search_embedding = pinecone_model.generate_embedding(json_intent["summary"]) movies_list = pinecone_model.get_movies(search_embedding["embedding"]) logging.info(f"MOVIES LIST: {movies_list}") print(movies_list) persona="You are friendly assistance helping to find a movie or show based on the client's request" safeguards="You should give information about the movie or show, year, main actors and any supplemental information. Do not invent any new movies, names and use for the answer the list of shows defined in the context" context=""" Based on the client request we have loaded a list of shows closely related to search. The list in JSON format with list of values like {"title":"Sparring","summary":"some description","director":"somebody","genre": "Drama", "actors": "Mathieu Kassovitz, Souleymane M'Baye"} Here is the list of shows:\n """+str(movies_list) system_instruction=[persona,safeguards,context] else: persona="You are friendly assistance helping to find a movie or show based on the client's request" safeguards="You should give information about the movie or show, year, main actors and any supplemental information. Do not invent any new movies, names and use for the answer the list of shows defined in the context" system_instruction=[persona,safeguards] llm_response = pinecone_model.send_prompt_flash(input, history,system_instruction) # llm_response = pinecone.call_pinecone(input, history) else: raise Exception("Unhandled model", model) for chunk in llm_response: messages[-1].content += chunk yield messages[-1].in_progress = False yield