# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import streamlit as st
import vertexai
import argparse
from vertexai.generative_models import GenerativeModel
from langchain.chains import ConversationChain
from langchain.memory import ConversationBufferMemory
from langchain_core.prompts.prompt import PromptTemplate
from typing import Any, List, Mapping, Optional
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM

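# Minimal LangChain LLM wrapper around the Vertex AI Gemini SDK so the model
# can be plugged into the ConversationChain defined below.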
class GeminiProLLM(LLM):
    @property
    def _llm_type(self) -> str:
        return "gemini-pro"

    def _call(self,
              prompt: str,
              stop: Optional[List[str]] = None,
              run_manager: Optional[CallbackManagerForLLMRun] = None,
              **kwargs: Any,) -> str:

        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        
        # Generation parameters (temperature, top_p, top_k, max_output_tokens)
        # are read from the module-level Streamlit sidebar controls defined
        # below; they exist by the time the chain invokes this method.
        gemini_pro_model = GenerativeModel("gemini-1.5-pro")

        model_response = gemini_pro_model.generate_content(
            prompt,
            generation_config={"temperature": temperature,
                               "top_p": top_p,
                               "top_k": top_k,
                               "max_output_tokens": max_output_tokens}
        )

        if args.debug:
            print(model_response)

        if len(model_response.candidates[0].content.parts) > 0:
            return model_response.candidates[0].content.parts[0].text
        else:
            return "There was an issue with returning a response. Please try again."

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        # Report the temperature currently selected in the sidebar rather than a hard-coded value.
        return {"model_id": "gemini-1.5-pro", "temperature": temperature}
    

parser = argparse.ArgumentParser(description='CoopBot: a Streamlit chat app for learning prompt design with Gemini.')
parser.add_argument('--project', required=True, help='Specify Google Cloud project')
parser.add_argument('--debug', action='store_true', help='Print raw model responses to stdout')

args = parser.parse_args()
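# Example invocation (script name assumed; Streamlit forwards flags after `--` to the script):
#   streamlit run app.py -- --project my-gcp-project --debug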

# Initialize Vertex AI
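# (vertexai.init picks up Application Default Credentials; pass location=... as well
# if a specific region is required)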
vertexai.init(project=args.project)

# Setting page title and header
st.set_page_config(page_title="CoopBot - Powered by Gemini Pro", page_icon=":dog:",
                   initial_sidebar_state="collapsed")

st.markdown("<h1 style='text-align: center;'>CoopBot - Powered by Gemini Pro</h1>", unsafe_allow_html=True)

template = """
    You are a chatbot named CoopBot whose role is to help students understand principles of prompt design when working with Gemini Pro. Keep a friendly and light tone, avoid complex language where possible, and keep responses brief and to the point. When asked to think step-by-step, or to explain your reasoning, please do so.

    If you are not asked to analyze a prompt, then return just the output for the prompt. Do not analyze the prompt unless you are specifically asked to, even if asked to think step by step.

    A well-written prompt should contain three main components: the *task* to be performed, *context* that provides background information for completing the task, and *examples* that show how the task should be accomplished.

    If you are asked to analyze a prompt, break the prompt into the components discussed above and then return the output at the end. If any of the components are missing, please inform the user. You should also give suggestions on how to improve the prompt based on best practices of prompt design for Gemini. This response should be less than 300 words.
    
    \n\nCurrent conversation:\n{history}\nHuman: {input} \nAI:
"""

st.sidebar.title("Options")
clear_button = st.sidebar.button("Clear Conversation", key="clear")

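# Generation parameters; GeminiProLLM._call reads these module-level values on
# every request, so changes take effect on the next message.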
temperature = st.sidebar.slider("Temperature", 0.0, 1.0, 0.0, 0.1)
top_p = st.sidebar.slider("Top P", 0.0, 1.0, 0.95, 0.05)
top_k = st.sidebar.number_input("Top K", 1, 100, 20)
max_output_tokens = st.sidebar.number_input("Max Output Tokens", 1, 2048, 500)

# Load chat model (cached with st.cache_resource so the chain and its memory persist across reruns)
@st.cache_resource
def load_chain():
    llm = GeminiProLLM()
    memory = ConversationBufferMemory()
    chain = ConversationChain(llm=llm, memory=memory,
                              prompt=PromptTemplate(input_variables=['history', 'input'],      
                                                    template=template))
    return chain

chatchain = load_chain()

# Initialize session state variables
if 'messages' not in st.session_state:
    st.session_state['messages'] = []

st.markdown("""**About me**: I am a virtual assistant, powered by Gemini and Streamlit, with the goal of helping people learn the fundamentals of prompt design. I am named after the author's dog, who has all knowledge known to dogs about prompt design.""", unsafe_allow_html=False)

# Reset conversation
if clear_button:
    st.session_state['messages'] = []
    chatchain.memory.clear()

# Display previous messages
for message in st.session_state['messages']:
    role = message["role"]
    content = message["content"]
    with st.chat_message(role):
        st.markdown(content)

# Chat input
prompt = st.chat_input("You:")
if prompt:
    st.session_state['messages'].append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    response = chatchain(prompt)["response"]
    
    st.session_state['messages'].append({"role": "assistant", "content": response})
    with st.chat_message("assistant"):
        st.markdown(response)
        # Clear memory after every turn so previous prompts are not carried into future prompts.
        # Remove the call below to enable conversation memory.
        chatchain.memory.clear()