# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
    File describing the layout of the User Interface
    using Streamlit library widgets
    Streamlit entry point function
"""
import os
import json
import configparser
from io import StringIO
from pathlib import Path
import requests
from dotenv import load_dotenv
from loguru import logger

import streamlit as st
from streamlit_modal import Modal
from st_pages import Page, show_pages

# import the executors
from utils import linear_gen_sql, cot_gen_sql, rag_gen_sql, lite_gen_sql
# import support functions
from utils import get_feedback, message_queue, \
    add_question_to_db, create_look_ui
# import auth functions
from utils import view_auth_google, view_login_google, \
    back_to_login_page, run_visualization

# Resolved absolute path of this script.
curr_path = Path(__file__).resolve()

# Load variables from a local .env file into the process environment; the
# executor base URLs are read later with os.getenv.
load_dotenv()
SHOW_SUCCESS = False

st.set_page_config(
    page_title='NL2SQL Studio',
    page_icon="📊",
    layout='wide',
    initial_sidebar_state="expanded",
)

# Register the multipage navigation: home page, Database AI chat agent
# and the evaluation page.
show_pages([
    Page("nl2sqlstudio_ui.py", "NL2SQL Home", "🏠"),
    Page("pages/chat_agent_app.py", "Database AI", ":robot_face:"),
    Page("pages/Evaluation.py", "Evaluation", ":books:"),
])


# Sidebar labels for the two NL2SQL frameworks
CORE = "NL2SQL Studio Core"
LITE = "NL2SQL Studio Lite"
# Prompting techniques offered for the Core framework
LINEAR = "Linear Executor"
# RAG = "Rag Executor"
COT = "Chain of Thought"
# Prompting techniques offered for the Lite framework
ZERO_SHOT = "Zero Shot"
FEW_SHOT = "Few Shot"
# Names of the environment variables holding the executor base URLs; also
# compared against st.session_state.sql_generated_by to pick the backend
# that receives user feedback.
GEN_BY_CORE = "CORE_EXECUTORS"
GEN_BY_LITE = "LITE_EXECUTORS"
# Spider dataset page linked from the help text.  It must not carry leading
# whitespace because it is interpolated directly into a Markdown link target.
ZI_URL = ("https://www.kaggle.com/datasets/"
          "jeromeblanchet/yale-universitys-spider-10-nlp-dataset/data")


def define_session_variables() -> None:
    """
        Seed the session-state keys used by the app with their initial
        values.  Callers invoke this only while the ``init`` key is absent
        from session state.
    """
    logger.info("Defining Session variables")
    initial_values = {
        'messages': [],
        'question': '',
        'new_question': False,
        'user_response': 0,
        'user_responded': False,
        'fb_count': 1,
        'refresh': True,
        'add_sample_question': False,
        'sample_question': '',
        'sample_sql': '',
        'add_question_status': False,
        'result_id': '',
        'generation_engine': None,
        'sql_generated_by': None,
    }
    # Insertion order of the dict matches the original assignment order.
    for state_key, start_value in initial_values.items():
        st.session_state[state_key] = start_value


def define_modals() -> None:
    """
        Build the Modal dialogs (sample-query selector, project
        configuration, sample Q&A entry and success notice) and keep them
        in session state so the layout code can open and close them.
    """
    built = {
        'q_s_modal': Modal("Query Selector", key="qsm",
                           padding=10, max_width=700),
        'pc_modal': Modal('Project Configuration', key='pcm',
                          padding=10, max_width=545),
        'qa_modal': Modal('Sample QnA', key='qam',
                          padding=10, max_width=545),
        'info_modal': Modal('Success !!', key='info',
                            padding=2, max_width=200),
    }

    for state_name in ('pc_modal', 'qa_modal', 'info_modal', 'q_s_modal'):
        st.session_state[state_name] = built[state_name]


def define_pre_auth_layout() -> None:
    """
        Render the landing page shown before the user logs in: the centred
        logo with a styled "Login" link (target from ``view_login_google``)
        underneath it.
    """
    def render_login_button() -> None:
        """
            Inject the button CSS and emit the HTML anchor for the Login link
        """
        st.markdown("""
            <style>
            .big-font {
                font-size:20px !important;
                background-color: royalblue;
                border-radius: 20%/50%;
                width: 105px;
                height: 1.75em;
                padding: 0px 0px 0px 30px ;
            }
            </style>
            """, unsafe_allow_html=True)
        login_url = view_login_google()
        anchor_html = ('<div class="big-font"> <a href="'
                       + login_url
                       + '" style="color:white" target="_self">Login</a> </div>')
        logger.info(f"Auth url = {login_url}")
        st.markdown(anchor_html, unsafe_allow_html=True)

    # Three columns; the logo and the nested button row use the middle one.
    _, centre_col, _ = st.columns([0.35, 0.35, 0.3])
    with centre_col:
        st.image('solid_g-logo-2.png')
        _, button_col, _ = st.columns([0.3, 0.4, 0.3])
        with button_col:
            render_login_button()


def define_post_auth_layout() -> None:
    """
        Streamlit UI layout shown after login: page styles, sidebar widgets,
        the main chat panel and the modal dialogs.

        The work is split into nested helpers that are executed in the
        order listed at the bottom of this function.
    """
    def markdown_styles() -> None:
        """
            Inject the CSS that tightens page padding and centres images
        """
        st.markdown("""
            <style>
                .block-container {
                        padding-top: 0.25rem;
                        padding-bottom: 0rem;
                        padding-left: 5rem;
                        padding-right: 5rem;
                    }
            </style>
            """, unsafe_allow_html=True)

        st.markdown(
            """
            <style>
                [data-testid=stImage]{
                    text-align: center;
                    display: block;
                    margin-left: auto;
                    margin-right: auto;
                    width: 100%;
                }
            </style>
            """, unsafe_allow_html=True
        )

        st.markdown(
            """
            <style>
                [data-testid=stSidebar] [data-testid=stImage]{
                    text-align: center;
                    display: block;
                    margin-left: auto;
                    margin-right: auto;
                    width: 100%;
                }
            </style>
            """, unsafe_allow_html=True
        )

        st.markdown(
            """
            <style>
                [data-testid=stContainer] [data-testid=stImage]{
                    text-align: center;
                    display: block;
                    margin-left: auto;
                    margin-right: auto;
                    width: 100%;
                    border: 2px;
                    min-height: 30%;
                    max-height: 50%;
                }
            </style>
            """, unsafe_allow_html=True
        )

    # Side Bar definition

    def sidebar_components() -> None:
        """
            UI controls in the sidebar panel: logo, user-guide link, logout
            button, framework and prompting-technique selectors,
            configuration buttons and the "Generate and Execute" toggle
        """
        with st.sidebar.container():
            column_1, column_2 = st.columns(2)
            with column_1:
                st.image('google.png')
            with column_2:
                url = 'https://googlecloudplatform.github.io/nl2sql-studio/'
                st.markdown("[User Guide](%s)" % url)
                logout_state = st.button("Logout")
        with st.sidebar.container():
            st.write("     ")
        gen_engine = st.sidebar.selectbox(
            "Choose NL2SQL framework",
            (LITE, CORE)
            )
        logger.info(f"Generation using : {gen_engine}")
        if gen_engine == CORE:
            st.session_state.generation_engine = CORE
            with st.sidebar.container(height=140):
                st.session_state.model = st.radio('Select Prompting Technique',
                                                  [LINEAR, COT])
        elif gen_engine == LITE:
            st.session_state.generation_engine = LITE
            with st.sidebar.container(height=115):
                st.session_state.lite_model = st.radio(
                    'Select Prompting Technique',
                    [FEW_SHOT, ZERO_SHOT])
        else:
            st.session_state.generation_engine = None

        with st.sidebar.expander("Configuration Settings"):
            proj_conf = st.button("Project Configuration")
            rag_input = st.button("Questions  &  Queries", disabled=False)

        with st.sidebar.container(height=60):
            st.session_state.execution = st.checkbox(
                "Generate and Execute",
                disabled=False
                )

        if proj_conf:
            pc_modal = st.session_state.pc_modal
            pc_modal.open()

        if rag_input:
            qa_modal = st.session_state.qa_modal
            qa_modal.open()

        if logout_state:
            logger.info("Logging out")
            st.session_state.token = None
            st.session_state.login_status = False
            back_to_login_page()
            st.query_params.clear()

    def main_page() -> None:
        """
            UI Controls and interaction on the Main panel
        """
        def logo_image() -> None:
            """
                Display the application logo
            """
            st.image('solid_g-logo-2.png')

        def help_info() -> None:
            """
                Help tooltip (right-hand column) describing the demo dataset
            """
            with st.container():
                column_1, column_2, column_3 = st.columns([0.25, 0.85, 0.1])
                with column_1:
                    # Nothing to do. Leave blank
                    pass
                with column_2:
                    # Nothing to do. Leave blank
                    pass
                with column_3:
                    st.markdown('',
                                help=f"""For the purpose of this demo we have
                                setup a demo project with id
                                'sl-test-project-363109' created a dataset in
                                BigQuery named 'spider'. This dataset
                                contains 4 tables with information that has
                                schedule of singers concert
                                [Spider Dataset]({ZI_URL}).
                                This is the default dataset to generate SQLs
                                from related natural language statements.  For
                                custom query generation, specify the Project
                                ID, Dataset and Metadata of tables in the
                                Configuration settings in the Sidebar
                                panel""")

        def input_container() -> None:
            """
                Define the Question input entry and Sample queries button
            """
            inp_container = st.container()
            with inp_container:
                column_1, column_2 = st.columns([0.86, 0.14])
                with column_2:
                    q_s = st.button('Sample Queries', key='qs_button')
                with column_1:
                    if question := st.chat_input("Enter your question here"):
                        message_queue(question)
                        st.session_state.question = question
                        st.session_state.new_question = True
                        st.session_state.user_responded = False
            st.session_state.ic = inp_container
            if q_s:
                q_s_modal = st.session_state.q_s_modal
                q_s_modal.open()

        def qa_msgs_container() -> None:
            """
                Main chat session window of the screen.  Stores the message
                and feedback containers in session state (``mc`` / ``fc``)
                for later rendering.
            """
            msg_container_main = st.container(height=425)
            with msg_container_main:
                column_1, column_2 = st.columns([0.90, 0.10])
                with column_1:
                    msg_container = st.container()
                with column_2:
                    fb_container = st.container()

            st.session_state.fc = fb_container
            st.session_state.mc = msg_container
            get_feedback()

        def disclaimer() -> None:
            """
                Disclaimer message at the bottom of the screen
            """
            st.markdown("<p style='text-align: center; font-style: italic;\
                        font-size: 0.75rem;'>\
                        The SQL generated by this tool may be inaccurate\
                        or incomplete. Always review and test the code before\
                        executing it against your database.</p>",
                        unsafe_allow_html=True)

        def sample_queries_modal_active() -> None:
            """
                Modal output when the Sample queries button is pressed:
                lists each line of sample_questions.txt as a code block
            """
            q_s_modal = st.session_state.q_s_modal
            if q_s_modal.is_open():
                with open('sample_questions.txt',
                          'r',
                          encoding="utf-8") as input_file:
                    questions_list = input_file.readlines()

                with q_s_modal.container():
                    for question in questions_list:
                        st.code(question)

                    if st.button("Close"):
                        q_s_modal.close()

        def qa_modal_active() -> None:
            """
                Modal output when the Questions and Queries button
                on the Side bar panel is pressed.  The sample question and
                SQL are added to both the Core and the Lite executors.
            """
            qa_modal = st.session_state.qa_modal
            if qa_modal.is_open():
                with qa_modal.container():
                    samp_question = st.text_input('Enter sample question')
                    samp_sql = st.text_input(("Enter corresponding SQL"))
                    if st.session_state.add_question_status:
                        st.success("Success ! Question added to DB ")
                    if st.button('Add question'):
                        add_question_to_db(samp_question, samp_sql,
                                           'CORE_EXECUTORS')
                        add_question_to_db(samp_question, samp_sql,
                                           'LITE_EXECUTORS')
                        info_modal = st.session_state.info_modal
                        info_modal.open()
                        qa_modal.close(True)

        def save_project_config(project, dataset, uploaded_file) -> list:
            """
                Send the project configuration and the uploaded metadata
                file to every executor backend.

                Args:
                    project: GCP project name entered by the user.
                    dataset: BigQuery dataset name entered by the user.
                    uploaded_file: File returned by ``st.file_uploader``;
                        its bytes are decoded as UTF-8 before upload.

                Returns:
                    A list of error messages, one per failure; an empty
                    list means every request succeeded.
            """
            logger.info(f"Uploading file : {uploaded_file.name}")
            string_data = uploaded_file.getvalue().decode("utf-8")
            files = {'file': (uploaded_file.name, string_data)}
            token = f"Bearer {st.session_state.access_token}"
            body = {"proj_name": project,
                    "bq_dataset": dataset,
                    "metadata_file": uploaded_file.name}
            headers = {"Content-type": "application/json",
                       "Authorization": token}

            errors = []
            for executor in (GEN_BY_CORE, GEN_BY_LITE):
                url = os.getenv(executor)
                if not url:
                    # Without the base URL there is nothing to post to.
                    errors.append(
                        f"Environment variable {executor} is not set"
                        )
                    continue
                logger.info(f"(Project config for : {executor}")
                try:
                    response = requests.post(
                        url=url+"/projconfig",
                        data=json.dumps(body),
                        headers=headers,
                        timeout=None
                        )
                    response.raise_for_status()

                    response = requests.post(
                        url=url+"/uploadfile",
                        headers={"Authorization": token},
                        files=files,
                        timeout=None
                        )
                    response.raise_for_status()
                except requests.RequestException as err:
                    logger.error(
                        f"Project configuration failed for {executor}: {err}"
                        )
                    errors.append(f"{executor}: {err}")
            return errors

        def pc_modal_active() -> None:
            """
                Modal output when the Project Configuration button on the
                Side bar is pressed.  Collects the project, dataset and
                metadata file and pushes them to the executors on save.
            """
            pc_modal = st.session_state.pc_modal
            if pc_modal.is_open():
                with pc_modal.container():
                    project = st.text_input('Mention the GCP project name')
                    dataset = st.text_input(
                        'Specify the BigQuery dataset name'
                        )
                    uploaded_file = st.file_uploader(
                        "Choose the Metadata Cache file"
                        )
                    with open('sample_metadata.json', 'rb') as f:
                        st.download_button('Download Sample Metadata file',
                                           f,
                                           file_name='sample_metadata.json')
                    if st.button("Save configuration"):
                        if uploaded_file is None:
                            # Nothing can be saved without the metadata
                            # file; keep the modal open and say so.
                            st.warning("Please choose a Metadata Cache file "
                                       "before saving the configuration")
                        else:
                            errors = save_project_config(project, dataset,
                                                         uploaded_file)
                            if errors:
                                # Keep the modal open so the errors stay
                                # visible and the user can retry.
                                for error in errors:
                                    st.error(error)
                            else:
                                pc_modal.close()

        # Main page function calls, in rendering order
        main_page_functions = (
            logo_image,
            help_info,
            input_container,
            qa_msgs_container,
            disclaimer,
            sample_queries_modal_active,
            qa_modal_active,
            pc_modal_active,
        )

        for mp_function in main_page_functions:
            mp_function()

        st.session_state.add_question_status = False

    # Layout function calls, in rendering order
    layout_functions = (
        markdown_styles,
        sidebar_components,
        main_page,
    )

    for layout_function in layout_functions:
        layout_function()


def pre_initialize() -> None:
    """
        Prepare session state for the pre-login (or post-logout) view.

        Session variables are seeded only on the first run, when the
        ``init`` key is absent; the current login status is logged every
        time.
    """
    first_run = 'init' not in st.session_state
    if first_run:
        define_session_variables()
        st.session_state.init = False

    logger.info(f"Login status = {st.session_state.login_status}")


def initialize() -> None:
    """
        Build the post-login application context.

        Seeds session variables on first use, creates the modal dialogs,
        lays out the page and makes sure a chat history list exists.
    """
    if 'init' not in st.session_state:
        define_session_variables()
        st.session_state.init = False

    for setup_step in (define_modals, define_post_auth_layout):
        setup_step()

    if "messages" not in st.session_state:
        st.session_state.messages = []


def _render_visualisation(dataframe, cntr: int, choice) -> None:
    """
    Render the plot selected for chat message number ``cntr``.

    Args:
        dataframe: Result data attached to the chat message; passed straight
            to ``run_visualization``.
        cntr: 1-based position of the message in the chat history.
        choice: 'n' (default plotting), 'cp' (custom plotting) or 'clk'
            (Looker plotting); any other value renders nothing.
    """
    try:
        if choice == 'n':
            run_visualization(dataframe, False, cntr)
        elif choice == 'cp':
            run_visualization(dataframe, True, cntr)
        elif choice == 'clk':
            look, base_url = create_look_ui()
            embed_url = (f"{base_url}/embed/looks/"
                         f"{look.id}"
                         "?refresh=true")
            st.components.v1.iframe(embed_url, height=500)
    except Exception as e:  # UI boundary: show the error, keep the app alive
        logger.exception("Error loading plot")
        st.write(f"Error Loading Plot due to error: {str(e)}")


def _render_plot_controls(dataframe, cntr: int) -> None:
    """
    Draw the plotting buttons for one chat message and, while its modal is
    open and "Generate and Execute" is ticked, the selected plot.

    Args:
        dataframe: Result data attached to the chat message.
        cntr: 1-based position of the message; used to build unique widget
            and session-state keys per message.
    """
    modal_key = f'visualise_modal_{cntr}'
    choice_key = f'v_c_{cntr}'

    # Reuse the Modal kept in session state; only build one when missing.
    visualise_modal = st.session_state.get(modal_key)
    if visualise_modal is None:
        visualise_modal = Modal("Plot Results", key=f"vm_{cntr}")
    st.session_state[modal_key] = visualise_modal

    cols = st.columns([1, 1, 1, 7])
    with cols[0]:
        open_default = st.button("Default Plotting", key=f"vr_key_{cntr}")
    with cols[1]:
        open_custom = st.button("Custom Plotting", key=f"vru_key_{cntr}")
    with cols[2]:
        open_looker = st.button("Looker Plotting", key=f"vrlk_key_{cntr}")

    if open_default or open_custom or open_looker:
        # Record which plot was requested before opening the modal.
        if open_default:
            st.session_state[choice_key] = 'n'
        elif open_custom:
            st.session_state[choice_key] = 'cp'
        else:
            st.session_state[choice_key] = 'clk'
        visualise_modal.open()

    sql_exec_flag = st.session_state.execution
    if visualise_modal.is_open() and sql_exec_flag:
        with visualise_modal.container():
            _render_visualisation(dataframe, cntr,
                                  st.session_state.get(choice_key))
            if st.button("Close Modal", key=f"close_vm_key_{cntr}"):
                st.session_state[choice_key] = None
                visualise_modal.close()


def redraw() -> None:
    """
    Trigger the re-rendering of the UI.

    Replays every entry of ``st.session_state.messages`` into the message
    container stored by the main-page layout (``st.session_state.mc``).
    Messages that carry a non-None ``dataframe`` also get plotting buttons.
    """
    msg_container = st.session_state.mc

    with msg_container:
        for cntr, message in enumerate(st.session_state.messages, start=1):
            logger.info(f"message is: {message}")

            with st.chat_message(message["role"]):
                st.markdown(message["content"], unsafe_allow_html=True)

                # Messages without query results have nothing to plot.
                dataframe = message.get("dataframe")
                if dataframe is not None:
                    _render_plot_controls(dataframe, cntr)


def add_new_question() -> None:
    """
        Re-render the chat history and, when a new question is pending,
        send it to the executor selected on the Sidebar panel.

        Core: Linear Executor or Chain of Thought (RAG as the fallback).
        Lite: a single Lite generation call.
    """
    if 'new_question' not in st.session_state:
        return

    redraw()
    st.session_state.refresh = False

    if not st.session_state.new_question:
        return

    engine = st.session_state.generation_engine
    if engine == LITE:
        lite_gen_sql(st.session_state.question)
    elif engine == CORE:
        technique = st.session_state.model
        if technique == LINEAR:
            core_generator = linear_gen_sql
        elif technique == COT:
            core_generator = cot_gen_sql
        else:
            core_generator = rag_gen_sql
        core_generator(st.session_state.question)


def when_user_responded() -> None:
    """
        Capture the user feedback from the Thumbs up/down widget.

        Appends a feedback note to the last chat message and posts the
        feedback for ``st.session_state.result_id`` to the ``/userfb``
        endpoint of the backend that generated the SQL (Core when
        ``sql_generated_by`` equals GEN_BY_CORE, otherwise Lite).  A
        missing backend URL or a failed request is logged instead of
        aborting the page, so the refresh and feedback widget still run.
    """
    if not st.session_state.user_responded:
        return

    st.session_state.user_responded = False
    is_positive = st.session_state.user_response == 1
    user_feedback = 'True' if is_positive else 'False'
    if is_positive:
        info_text = ':green[👍 User feedback captured ]'
    else:
        info_text = ':red[👎 User feedback captured ]'

    last_message = st.session_state.messages[-1]
    last_message['content'] = last_message['content'] + " \n\n" + info_text

    generator_endpoint = GEN_BY_CORE \
        if st.session_state.sql_generated_by == GEN_BY_CORE \
        else GEN_BY_LITE
    data = {"result_id": st.session_state.result_id,
            "user_feedback": user_feedback}
    logger.info(f"User reposnse data to API {data}")

    base_url = os.getenv(generator_endpoint)
    if not base_url:
        logger.error(f"Environment variable {generator_endpoint} is not set;"
                     " user feedback was not sent")
    else:
        headers = {'Content-type': 'application/json',
                   'Accept': 'text/plain',
                   "Authorization": f"Bearer {st.session_state.access_token}"}
        try:
            api_response = requests.post(url=base_url + '/userfb',
                                         data=json.dumps(data),
                                         headers=headers,
                                         timeout=None)
            api_response.raise_for_status()
        except requests.RequestException as err:
            logger.error(f"Sending user feedback to {generator_endpoint} "
                         f"failed: {err}")

    st.session_state.refresh = True
    get_feedback()


def refresh() -> None:
    """
        Refresh the display.

        When the ``refresh`` flag is set the script is rerun; otherwise the
        flag is re-armed so that the next pass reruns.
    """
    if not st.session_state.refresh:
        st.session_state.refresh = True
        return
    st.rerun()


def app_load() -> None:
    """
        On application load, decide whether the user is logged in.

        Reads ``GOOGLE_OAUTH`` from the ``[DEFAULT]`` section of
        ``config.ini``.  When it is ``ENABLE``, an OAuth ``code`` query
        parameter is exchanged via ``view_auth_google`` for an ID token and
        an access token.  Without that parameter the user must log in.  Any
        other setting skips authentication and stores dummy tokens.

        Sets ``st.session_state.token``, ``access_token`` and
        ``login_status``.  Tokens and the authorization code are never
        written to the log.

        Raises:
            KeyError: if ``GOOGLE_OAUTH`` is missing from ``config.ini``.
    """
    logger.info("App loaders")
    config = configparser.ConfigParser()
    config.read('config.ini')
    google_oauth = config['DEFAULT']['GOOGLE_OAUTH']
    if google_oauth == 'ENABLE':
        # Code Block for 'With Google Authentication'
        # Only parameter names are logged: the authorization code is a
        # credential and must not end up in the application logs.
        logger.info(f"Query Parameters - {list(st.query_params.keys())}")
        if 'code' in st.query_params:
            # NOTE(review): render_view calls this on every script run, so
            # the same code is exchanged again until the query params are
            # cleared; confirm view_auth_google copes with that.
            id_token, access_token = view_auth_google(st.query_params['code'])
            logger.info("Google OAuth tokens received")
            st.session_state.token = id_token
            st.session_state.access_token = access_token
            st.session_state.login_status = True
        else:
            logger.info("Login required")
            st.session_state.token = None
            st.session_state.access_token = None
            st.session_state.login_status = False
        # Comment for 'With Google Authentication' ends
    else:
        # Code block for 'Without Google Authentication'
        st.session_state.login_status = True
        st.session_state.token = "dummy token"
        st.session_state.access_token = "dummy token"
        # Code block for without Google Auth ends

    logger.info(f"Login status = {st.session_state.login_status}")


def render_view() -> None:
    """
        Entry point Function called by Streamlit while rendering the UI.

        Resolves the login state with ``app_load``.  It then runs either
        the post-login pipeline (layout, question handling, feedback,
        refresh) or the pre-login pipeline (session setup and login page).
    """
    app_load()

    if st.session_state.login_status:
        pipeline = (initialize, add_new_question, when_user_responded,
                    refresh)
    else:
        pipeline = (pre_initialize, define_pre_auth_layout)

    for step in pipeline:
        step()


# Script entry point: build the UI each time this script is executed.
render_view()
