# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import mesop as me
from data_model import State, Models, ModelDialogState, Conversation, ChatMessage
import gemini_model
import openai_model
import gemma_model
import os
import logging
import json
import sqlalchemy
from connect_tcp import connect_tcp_socket
from dataclasses import dataclass, field
from typing import Literal
import base64

# Chat roles: a message is authored either by the end user or by the model.
Role = Literal["user", "model"]

# Dialog
@me.content_component
def dialog(is_open: bool):
    """Renders a modal dialog overlay with slotted content.

    The slot content is shown in a white, rounded panel centered on top of a
    translucent full-screen scrim.

    Args:
      is_open: Whether the dialog (scrim and panel) is visible.
    """
    scrim_style = me.Style(
        background="rgba(0,0,0,0.4)",
        # The scrim stays in the tree; visibility is toggled via display.
        display="block" if is_open else "none",
        height="100%",
        overflow_x="auto",
        overflow_y="auto",
        position="fixed",
        width="100%",
        z_index=1000,
    )
    centering_style = me.Style(
        align_items="center",
        display="grid",
        height="100vh",
        justify_items="center",
    )
    panel_style = me.Style(
        background="#fff",
        border_radius=20,
        box_sizing="content-box",
        box_shadow=(
            "0 3px 1px -2px #0003, 0 2px 2px #00000024, 0 1px 5px #0000001f"
        ),
        margin=me.Margin.symmetric(vertical="0", horizontal="auto"),
        padding=me.Padding.all(20),
    )
    with me.box(style=scrim_style):
        with me.box(style=centering_style):
            with me.box(style=panel_style):
                me.slot()

@me.content_component
def dialog_actions():
    """Renders a right-aligned action row for a dialog (e.g. Cancel/Confirm)."""
    action_row_style = me.Style(
        display="flex", justify_content="end", margin=me.Margin(top=20)
    )
    with me.box(style=action_row_style):
        me.slot()

# App-wide constants: role identifiers, theme colors, and shared styles.

# NOTE(review): `Role` is already defined earlier in this module with the
# identical value; this re-definition is redundant — consider removing one.
Role = Literal["user", "model"]
# Role string for messages authored by the end user.
_ROLE_USER = "user"
# Role string for messages authored by the model.
_ROLE_ASSISTANT = "model"

# Theme-aware colors for the page background and the two chat-bubble kinds.
_COLOR_BACKGROUND = me.theme_var("background")
_COLOR_CHAT_BUBBLE_YOU = me.theme_var("surface-container-low")
_COLOR_CHAT_BUBBLE_BOT = me.theme_var("secondary-container")

_DEFAULT_PADDING = me.Padding.all(20)
# Border side shared by the chat box and the chat bubbles.
_DEFAULT_BORDER_SIDE = me.BorderSide(
  width="1px", style="solid", color=me.theme_var("secondary-fixed")
)

# Main column holding the whole app, capped at 1024px wide and centered.
_STYLE_APP_CONTAINER = me.Style(
  background=_COLOR_BACKGROUND,
  display="flex",
  flex_direction="column",
  height="100%",
  margin=me.Margin.symmetric(vertical=0, horizontal="auto"),
  width="min(1024px, 100%)",
  box_shadow=("0 3px 1px -2px #0003, 0 2px 2px #00000024, 0 1px 5px #0000001f"),
  padding=me.Padding(top=20, left=20, right=20),
)
_STYLE_TITLE = me.Style(padding=me.Padding(left=10))

# Scrollable message-history area that grows to fill available height.
_STYLE_CHAT_BOX = me.Style(
  flex_grow=1,
  overflow_y="scroll",
  padding=_DEFAULT_PADDING,
  margin=me.Margin(bottom=20),
  border_radius="10px",
  border=me.Border(
    left=_DEFAULT_BORDER_SIDE,
    right=_DEFAULT_BORDER_SIDE,
    top=_DEFAULT_BORDER_SIDE,
    bottom=_DEFAULT_BORDER_SIDE,
  ),
)
_STYLE_CHAT_INPUT = me.Style(width="100%")
# Row containing the prompt input plus the send button.
_STYLE_CHAT_INPUT_BOX = me.Style(
  padding=me.Padding(top=30), display="flex", flex_direction="row"
)
_STYLE_CHAT_BUTTON = me.Style(margin=me.Margin(top=8, left=8))
# Small bold label shown above a model chat bubble.
_STYLE_CHAT_BUBBLE_NAME = me.Style(
  font_weight="bold",
  font_size="13px",
  padding=me.Padding(left=15, right=15, bottom=5),
)
_STYLE_CHAT_BUBBLE_PLAINTEXT = me.Style(margin=me.Margin.symmetric(vertical=15))

# Material icon names for the send button (idle vs. in-progress) and the
# prompt input's label text.
_LABEL_BUTTON = "send"
_LABEL_BUTTON_IN_PROGRESS = "pending"
_LABEL_INPUT = "Enter your prompt"

def _make_style_chat_bubble_wrapper(role: Role) -> me.Style:
  """Generates styles for chat bubble position.

  User messages are aligned to the right; model messages to the left.

  Args:
    role: Chat bubble alignment depends on the role
  """
  if role == _ROLE_USER:
    alignment = "end"
  else:
    alignment = "start"
  return me.Style(
    display="flex",
    flex_direction="column",
    align_items=alignment,
  )

def _make_chat_bubble_style(role: Role) -> me.Style:
  """Generates styles for chat bubble.

  User and model bubbles share layout but differ in background color.

  Args:
    role: Chat bubble background color depends on the role
  """
  if role == _ROLE_USER:
    bubble_background = _COLOR_CHAT_BUBBLE_YOU
  else:
    bubble_background = _COLOR_CHAT_BUBBLE_BOT
  return me.Style(
    width="80%",
    font_size="16px",
    line_height="1.5",
    background=bubble_background,
    border_radius="15px",
    padding=me.Padding(right=15, left=15, bottom=3),
    margin=me.Margin(bottom=10),
    border=me.Border(
      left=_DEFAULT_BORDER_SIDE,
      right=_DEFAULT_BORDER_SIDE,
      top=_DEFAULT_BORDER_SIDE,
      bottom=_DEFAULT_BORDER_SIDE,
    ),
  )

# Lazily-created, process-wide SQLAlchemy engine; populated on first use by
# init_db()/page(). None until a request initializes it.
db = None
# Root logger used for database error reporting in this module.
logger = logging.getLogger()

def init_connection_pool() -> sqlalchemy.engine.base.Engine:
    """Sets up the connection pool for the app.

    Returns:
      A SQLAlchemy engine connected over a TCP socket.

    Raises:
      ValueError: If no supported connection type is configured in the
        environment.
    """
    # A TCP socket is used when INSTANCE_HOST (e.g. 127.0.0.1) is defined.
    instance_host = os.environ.get("INSTANCE_HOST")
    if instance_host:
        return connect_tcp_socket()

    # NOTE: Cloud SQL connector support (INSTANCE_CONNECTION_NAME with
    # connect_with_connector / connect_with_connector_auto_iam_authn, the
    # latter taking precedence when DB_IAM_USER is set) is currently disabled.
    raise ValueError(
        "Missing database connection type. Please define one of INSTANCE_HOST, INSTANCE_UNIX_SOCKET, or INSTANCE_CONNECTION_NAME"
    )
def init_db() -> sqlalchemy.engine.base.Engine:
    """Initiates the connection to the database, creating the pool once.

    Returns:
      The shared module-level engine. (Previously this returned None despite
      its annotation; it now returns the engine so callers can use the result.)
    """
    global db
    if db is None:
        db = init_connection_pool()
    return db


def get_products(db: sqlalchemy.engine.base.Engine, embeddings: str) -> list:
    """Fetches the products most similar to the given query embedding.

    Runs a pgvector nearest-neighbor search over in-stock products for a
    single hard-coded store (store_id 1583).

    Args:
      db: Connection pool (engine) for the Postgres database.
      embeddings: String representation of the query embedding vector.

    Returns:
      Up to five dicts with product_name, description, sale_price and
      zip_code, ordered by ascending vector distance. Empty on query failure.
    """
    stmt = sqlalchemy.text(
        """
        SELECT
                cp.product_name,
                left(cp.product_description,80) as description,
                cp.sale_price,
                cs.zip_code,
                (ce.embedding <=> (:embeddings)::vector) as distance
        FROM
                cymbal_products cp
        JOIN cymbal_embedding ce on
            ce.uniq_id = cp.uniq_id
        JOIN cymbal_inventory ci on
                ci.uniq_id=cp.uniq_id
        JOIN cymbal_stores cs on
                cs.store_id=ci.store_id
                AND ci.inventory>0
                AND cs.store_id = 1583
        ORDER BY
                distance ASC
        LIMIT 5;
        """
    )
    products = []
    try:
        with db.connect() as conn:
            rows = conn.execute(stmt, parameters={"embeddings": embeddings}).fetchall()
    except Exception as e:
        # Previously the result variable was left unbound on failure, raising
        # NameError immediately after this block; degrade to an empty list.
        logger.exception(e)
        return products
    for row in rows:
        products.append(
            {
                "product_name": row[0],
                "description": row[1],
                "sale_price": row[2],
                "zip_code": row[3],
            }
        )
    return products

def change_model_option(e: me.CheckboxChangeEvent):
    """Adds or removes a model (by checkbox key) from the dialog selection."""
    dialog_state = me.state(ModelDialogState)
    if e.checked:
        dialog_state.selected_models.append(e.key)
    else:
        dialog_state.selected_models.remove(e.key)

def set_gemini_api_key(e: me.InputBlurEvent):
    """Stores the Gemini API key when the input loses focus."""
    state = me.state(State)
    state.gemini_api_key = e.value

def set_openai_api_key(e: me.InputBlurEvent):
    """Stores the OpenAI API key when the input loses focus."""
    state = me.state(State)
    state.openai_api_key = e.value

def set_gemma_endpoint(e: me.InputBlurEvent):
    """Stores the Gemma endpoint id when the input loses focus."""
    state = me.state(State)
    state.gemma_endpoint_id = e.value

def set_tei_endpoint(e: me.InputBlurEvent):
    """Stores the TEI embedding URL when the input loses focus."""
    state = me.state(State)
    state.tei_embedding_url = e.value

def model_picker_dialog():
    """Renders the dialog for entering credentials and selecting models."""
    state = me.state(State)
    with dialog(state.is_model_picker_dialog_open):
        with me.box(style=me.Style(display="flex", flex_direction="column", gap=12)):
            me.text("API keys")
            me.input(
                label="Gemini API Key",
                value=state.gemini_api_key,
                on_blur=set_gemini_api_key,
            )
            me.input(
                label="OpenAI API Key",
                value=state.openai_api_key,
                on_blur=set_openai_api_key,
            )
            me.input(
                label="Gemma Endpoint",
                value=state.gemma_endpoint_id,
                on_blur=set_gemma_endpoint,
            )
            me.input(
                label="TEI Endpoint",
                value=state.tei_embedding_url,
                on_blur=set_tei_endpoint,
            )
        me.text("Pick a model")
        # A model family is selectable only once its credential/endpoint is
        # filled in; families without a known prefix are always enabled.
        credential_by_prefix = (
            ("GEMINI", state.gemini_api_key),
            ("OPENAI", state.openai_api_key),
            ("GEMMA", state.gemma_endpoint_id),
        )
        for model in Models:
            disabled = False
            for prefix, credential in credential_by_prefix:
                if model.name.startswith(prefix):
                    disabled = not credential
                    break
            me.checkbox(
                key=model.value,
                label=model.value,
                checked=model.value in state.models,
                disabled=disabled,
                on_change=change_model_option,
                style=me.Style(
                    display="flex",
                    flex_direction="column",
                    gap=4,
                    padding=me.Padding(top=12),
                ),
            )
        with dialog_actions():
            me.button("Cancel", on_click=close_model_picker_dialog)
            me.button("Confirm", on_click=confirm_model_picker_dialog)

def close_model_picker_dialog(e: me.ClickEvent):
    """Dismisses the model picker dialog without applying the selection."""
    me.state(State).is_model_picker_dialog_open = False

def confirm_model_picker_dialog(e: me.ClickEvent):
    """Applies the dialog's model selection to app state and closes it."""
    state = me.state(State)
    dialog_state = me.state(ModelDialogState)
    state.models = dialog_state.selected_models
    state.is_model_picker_dialog_open = False

# Background image with light and dark theme.


def _encode_background_image(path: str) -> str:
    """Reads a PNG file and returns a CSS background value inlining it as base64."""
    # Use a context manager so the file handle is closed (the previous code
    # opened the files and never closed them).
    with open(path, "rb") as image_file:
        encoded = base64.b64encode(image_file.read()).decode("ascii")
    return f"url(data:image/png;base64,{encoded}) center/cover no-repeat"


_BACKGROUND_IMAGE_LIGHT = _encode_background_image("pub/app_background_light.png")
_BACKGROUND_IMAGE_DARK = _encode_background_image("pub/app_background_dark.png")

# Default background (light). Computed once at import time.
_BACKGROUND_IMAGE = _BACKGROUND_IMAGE_LIGHT


ROOT_BOX_STYLE = me.Style(
    background=_BACKGROUND_IMAGE,
    height="100%",
    font_family="Inter",
    display="flex",
    flex_direction="column",
)



@me.page(
    path="/",
    stylesheets=[
        "https://fonts.googleapis.com/css2?family=Inter:wght@100..900&display=swap"
    ],
    title="Cymbal store assistant",
)
def page():
    """Renders the main chat page.

    Lazily initializes the database pool within the request context, then
    renders the theme toggle, model picker dialog, conversation history, and
    the prompt input row.
    """
    global db
    # Initiate a connection pool to the Postgres database on first request.
    if not db:
        db = init_connection_pool()

    # Pick the background for the current theme on every render. (The previous
    # code assigned to _BACKGROUND_IMAGE inside the click handler, which only
    # created a local variable — and ROOT_BOX_STYLE bakes the background in at
    # import time anyway, so the theme background never actually changed.)
    root_style = me.Style(
        background=(
            _BACKGROUND_IMAGE_DARK
            if me.theme_brightness() == "dark"
            else _BACKGROUND_IMAGE_LIGHT
        ),
        height="100%",
        font_family="Inter",
        display="flex",
        flex_direction="column",
    )

    def toggle_theme(e: me.ClickEvent):
        # Flipping the theme triggers a re-render, which recomputes root_style.
        if me.theme_brightness() == "light":
            me.set_theme_mode("dark")
        else:
            me.set_theme_mode("light")

    def on_input_enter(e: me.InputEnterEvent):
        # Pressing Enter submits the prompt immediately.
        state = me.state(State)
        state.input = e.value
        yield from send_prompt(e)

    bot_user = "model"
    with me.box(style=root_style):
        model_picker_dialog()

        with me.box(style=_STYLE_APP_CONTAINER):
            with me.content_button(
                type="icon",
                style=me.Style(position="absolute", right=4, top=8),
                on_click=toggle_theme,
            ):
                me.icon("light_mode" if me.theme_brightness() == "dark" else "dark_mode")

            me.text("Store Virtual Assistant", type="headline-5", style=_STYLE_TITLE)

            state = me.state(State)
            with me.box(style=_STYLE_CHAT_BOX):
                for conversation in state.conversations:
                    for message in conversation.messages:
                        # The bubble is nested inside the alignment wrapper so
                        # user/model messages align right/left. (Previously the
                        # bubble was a sibling of the wrapper, so the alignment
                        # style was never applied to it.)
                        with me.box(style=_make_style_chat_bubble_wrapper(message.role)):
                            if message.role == _ROLE_ASSISTANT:
                                me.text(bot_user, style=_STYLE_CHAT_BUBBLE_NAME)
                            with me.box(style=_make_chat_bubble_style(message.role)):
                                if message.role == _ROLE_USER:
                                    me.text(message.content, style=_STYLE_CHAT_BUBBLE_PLAINTEXT)
                                else:
                                    me.markdown(message.content)

            with me.box(style=_STYLE_CHAT_INPUT_BOX):
                with me.box(style=me.Style(flex_grow=1)):
                    me.input(
                        label=_LABEL_INPUT,
                        # Workaround: update key to clear input after a send.
                        key=f"input-{len(state.conversations)}",
                        on_blur=on_blur,
                        on_enter=on_input_enter,
                        style=_STYLE_CHAT_INPUT,
                    )
                    # Clickable model summary that opens the model picker.
                    with me.box(
                        style=me.Style(
                            display="flex",
                            padding=me.Padding(left=12, bottom=12),
                            cursor="pointer",
                        ),
                        on_click=switch_model,
                    ):
                        me.text(
                            "Model:",
                            style=me.Style(font_weight=500, padding=me.Padding(right=6)),
                        )
                        if state.models:
                            me.text(", ".join(state.models))
                        else:
                            me.text("(no model selected)")
                with me.content_button(
                    color="primary",
                    type="flat",
                    disabled=state.in_progress,
                    on_click=send_prompt,
                    style=_STYLE_CHAT_BUTTON,
                ):
                    me.icon(
                        _LABEL_BUTTON_IN_PROGRESS if state.in_progress else _LABEL_BUTTON
                    )

def switch_model(e: me.ClickEvent):
    """Opens the model picker, pre-seeding it with the current selection."""
    state = me.state(State)
    dialog_state = me.state(ModelDialogState)
    # Copy so edits in the dialog don't mutate the applied selection.
    dialog_state.selected_models = state.models[:]
    state.is_model_picker_dialog_open = True


def on_blur(e: me.InputBlurEvent):
    """Persists the prompt text into state when the input loses focus."""
    me.state(State).input = e.value


# Shared prompt fragments for the store-assistant persona (text unchanged).
_PERSONA = "You are friendly assistance in a store helping to find a products based on the client's request"
_SAFEGUARDS = "You should give information about the product, price and any supplemental information. Do not invent any new products and use for the answer the product defined in the context"


def _classify_intent_json(classify_fn, prompt: str, max_attempts: int = 3) -> dict:
    """Classifies the user's intent via `classify_fn`, expecting a JSON reply.

    Retries up to `max_attempts` times on malformed JSON — the previous code
    retried forever (`while True: ... continue`), which could hang a request
    on a persistently misbehaving model.

    Returns:
      The parsed intent dict, or {} if no valid JSON was produced.
    """
    for _ in range(max_attempts):
        intent_str = classify_fn(prompt)
        logging.info(f"INTENT: {intent_str}")
        try:
            return json.loads(intent_str)
        except json.JSONDecodeError as err:
            logging.warning(f"Error decoding JSON: {err}")
    return {}


def _product_context(products_list: list) -> str:
    """Builds the grounding context listing the retrieved products."""
    # Text kept byte-identical to the original inline prompt string.
    return """
                    Based on the client request we have loaded a list of products closely related to search.
                    The list in JSON format with list of values like {"product_name":"name","description":"some description","sale_price":10,"zip_code": 10234}
                    Here is the list of products:\n
                """ + str(products_list)


def send_prompt(e: me.ClickEvent):
    """Sends the current prompt to every selected model, streaming replies.

    A generator event handler: it yields after each state mutation so Mesop
    re-renders incrementally while model responses stream in.
    """
    state = me.state(State)
    # Lazily create one conversation per selected model on the first send.
    if not state.conversations:
        for model in state.models:
            state.conversations.append(Conversation(model=model, messages=[]))
    prompt = state.input  # renamed: `input` shadowed the builtin
    state.input = ""
    yield

    for conversation in state.conversations:
        model = conversation.model
        messages = conversation.messages
        # History excludes the message being sent now.
        history = messages[:]
        messages.append(ChatMessage(role="user", content=prompt))
        messages.append(ChatMessage(role="model", in_progress=True))
        yield

        if model == Models.GEMINI_2_0_FLASH.value:
            json_intent = _classify_intent_json(gemini_model.classify_intent, prompt)
            # .get() avoids a KeyError when the model omits the field; a
            # missing/failed classification falls back to a plain answer.
            if json_intent.get("shouldRecommendProduct") is True:
                search_embedding = gemini_model.generate_embedding(json_intent["summary"])
                products_list = get_products(db, str(search_embedding["embedding"]))
                logging.info(f"PRODUCTS LIST: {products_list}")
                system_instruction = [_PERSONA, _SAFEGUARDS, _product_context(products_list)]
            else:
                system_instruction = [_PERSONA, _SAFEGUARDS]
            llm_response = gemini_model.send_prompt_flash(prompt, history, system_instruction)

        elif model == Models.OPENAI.value:
            llm_response = openai_model.call_openai_gpt4o_mini(prompt, history)

        elif model == Models.GEMMA_3.value:
            json_intent = _classify_intent_json(gemma_model.classify_intent, prompt)
            if json_intent.get("shouldRecommendProduct") is True:
                search_embedding = gemma_model.generate_embedding(json_intent["summary"])
                products_list = get_products(db, str(search_embedding))
                logging.info(f"PRODUCTS LIST: {products_list}")
                # Gemma takes a single concatenated instruction string rather
                # than a list of fragments.
                system_instruction = f"{_PERSONA}{_SAFEGUARDS}{_product_context(products_list)}"
            else:
                system_instruction = f"{_PERSONA}{_SAFEGUARDS}"
            llm_response = gemma_model.call_gemma(prompt, history, system_instruction)

        else:
            raise Exception("Unhandled model", model)

        # Stream chunks into the in-progress assistant message.
        for chunk in llm_response:
            messages[-1].content += chunk
            yield
        messages[-1].in_progress = False
        yield
