# Copyright (c) Alibaba Cloud PAI.
# SPDX-License-Identifier: Apache-2.0
# deling.sc

import json
import time
import os
from langchain.document_loaders import DirectoryLoader, UnstructuredFileLoader
from .CustomPrompt import CustomPrompt
from .EASAgent import EASAgent
from .VectorDB import VectorDB
from .TextSplitter import TextSplitter
import nltk
from .CustomLLM import CustomLLM
from .QuestionPrompt import *
from sentencepiece import SentencePieceProcessor

class LLMService:
    """Chat / retrieval-augmented QA service backed by an EAS-hosted LLM.

    The service keeps a running conversation history and can:
    * answer directly with the LLM (``query_only_llm``),
    * answer with documents retrieved from a vector store injected into the
      prompt (``query_retrieval_llm``),
    * return only the retrieved documents (``query_only_vectorstore``).

    Every answer is returned together with a SentencePiece token count of the
    accumulated conversation, which is summarised and reset once it grows past
    ``max_history_turns`` turns.

    ``init_with_cfg`` must be called before the query/upload methods: it
    stores ``self.cfg``/``self.args`` and creates ``self.vector_db``,
    ``self.llm`` and ``self.question_generator_chain``.
    """

    # Conversation turns kept before the history is automatically summarised.
    max_history_turns = 10
    # Retrieval depth used when the caller passes an empty / None ``topk``.
    default_topk = 3
    # Prompt sent to the LLM to condense the conversation ("Please summarise
    # our previous conversation."). Runtime text: keep verbatim.
    _SUMMARY_PROMPT = "请对我们之前的对话内容进行总结。"
    # UI prompt-type label -> value stored in ``args.prompt_engineering``
    # before ``CustomPrompt`` is built.
    _PROMPT_ENGINEERING_BY_TYPE = {
        "General": "general",
        "Extract URL": "extract_url",
        "Accurate Content": "accurate_content",
        "Customize": "customize",
    }

    def __init__(self, args):
        """Create an un-configured service (call ``init_with_cfg`` next).

        Args:
            args: accepted for interface compatibility; not used here.
        """
        # assert args.upload or args.user_query, "error: dose not set any action, please set '--upload' or '--query <user_query>'."
        # assert os.path.exists(args.config), f"error: config path {args.config} does not exist."
        self.langchain_chat_history = []
        self.input_tokens = []
        self.llm_chat_history = []
        self.sp = SentencePieceProcessor(model_file='./tokenizer.model')
        self.vector_db = None
        nltk_data_path = "/code/nltk_data"
        if os.path.exists(nltk_data_path):
            # Search the bundled NLTK data before the default locations.
            nltk.data.path = [nltk_data_path] + nltk.data.path

        # with open(args.config) as f:
        #     cfg = json.load(f)
        # self.init_with_cfg(cfg, args)

    def init_with_cfg(self, cfg, args):
        """Configure the vector store, the LLM client and the question chain.

        Args:
            cfg: configuration dict. ``cfg['EASCfg']['url']`` and
                ``cfg['EASCfg']['token']`` are read here; the whole dict is
                also handed to ``VectorDB``.
            args: options object handed to ``VectorDB`` and, later, to
                ``CustomPrompt``.
        """
        self.cfg = cfg
        self.args = args

        # self.prompt_template = PromptTemplate(self.args)
        # self.eas_agent = EASAgent(self.cfg)
        self.vector_db = VectorDB(self.args, self.cfg)

        self.llm = CustomLLM()
        self.llm.url = self.cfg['EASCfg']['url']
        self.llm.token = self.cfg['EASCfg']['token']
        self.question_generator_chain = get_standalone_question_ch(self.llm)

        # Legacy CLI dispatch, kept for reference:
        # if args.upload:
        #     self.upload_custom_knowledge()
        # if args.user_query:
        #     if args.query_type == "retrieval_llm":
        #         self.query_func = self.query_retrieval_llm
        #         self.query_type = "Retrieval-Augmented Generation"
        #     elif args.query_type == "only_llm":
        #         self.query_func = self.query_only_llm
        #         self.query_type = "Vanilla-LLM Generation"
        #     elif args.query_type == "only_vectorstore":
        #         self.query_func = self.query_only_vectorstore
        #         self.query_type = "Vector-Store Retrieval"
        #     else:
        #         raise ValueError(f'error: invalid query type of {args.query_type}')

        #     answer = self.query_func(args.user_query)
        #     print('='*20 + f' {self.query_type} ' + '='*20 + '\n', answer)

    def upload_custom_knowledge(self, docs_dir=None, chunk_size=200, chunk_overlap=0):
        """Load documents, split them into chunks and add them to the vector store.

        Args:
            docs_dir: a directory (files matching ``cfg['create_docs']['glob']``
                are loaded) or a single file path. Defaults to
                ``cfg['create_docs']['docs_dir']``.
            chunk_size: chunk size written into ``cfg['create_docs']`` before
                the ``TextSplitter`` is built.
            chunk_overlap: chunk overlap written into ``cfg['create_docs']``.

        Note: mutates ``self.cfg['create_docs']`` and ``self.text_splitter``.
        """
        create_docs_cfg = self.cfg['create_docs']
        if docs_dir is None:
            docs_dir = create_docs_cfg['docs_dir']
        create_docs_cfg['chunk_size'] = chunk_size
        create_docs_cfg['chunk_overlap'] = chunk_overlap
        self.text_splitter = TextSplitter(self.cfg)
        if os.path.isdir(docs_dir):
            docs = DirectoryLoader(docs_dir, glob=create_docs_cfg['glob'], show_progress=True).load()
            docs = self.text_splitter.split_documents(docs)
        else:
            loader = UnstructuredFileLoader(docs_dir, mode="elements")
            docs = loader.load_and_split(text_splitter=self.text_splitter)

        print('Uploading custom knowledge.')
        start_time = time.time()
        self.vector_db.add_documents(docs)
        end_time = time.time()
        print("Insert Success. Cost time: {} s".format(end_time - start_time))

    def _resolve_topk(self, topk):
        """Return *topk* as an int, falling back to ``default_topk`` for ''/None."""
        if topk == '' or topk is None:
            return self.default_topk
        return int(topk)

    def _require_vector_db(self):
        """Raise if no vector store has been configured yet."""
        if self.vector_db is None:
            # "Vector database is not connected!"
            raise Exception('未连接向量数据库！')

    def _count_tokens(self, texts):
        """Return the number of SentencePiece tokens in *texts* (a str or list of str)."""
        if isinstance(texts, str):
            # A bare str would make encode() return a flat list of pieces,
            # and summing their lengths would count characters, not tokens.
            texts = [texts]
        if not texts:
            return 0
        pieces_per_text = self.sp.encode(texts, out_type=str)
        return sum(len(pieces) for pieces in pieces_per_text)

    def _record_turn(self, query, ans):
        """Append a (query, answer) turn and return ``(token_count, summary)``.

        The token count covers the whole accumulated conversation and is taken
        before the (possible) summarisation reset; ``summary`` is "" when no
        summary was produced.
        """
        self.langchain_chat_history.append((query, ans))
        self.input_tokens.append(query)
        self.input_tokens.append(ans)
        lens = self._count_tokens(self.input_tokens)
        summary_res = self.checkout_history_and_summary()
        return lens, summary_res

    def create_user_query_prompt(self, query, topk, prompt_type, prompt=None):
        """Retrieve *topk* documents for *query* and build the LLM prompt.

        Args:
            query: the (already standalone) user question.
            topk: number of documents to retrieve; '' or None means
                ``default_topk``.
            prompt_type: UI label selecting ``args.prompt_engineering``
                ("General", "Extract URL", "Accurate Content", "Customize").
                Unknown labels leave the current setting unchanged.
            prompt: optional custom prompt passed through to ``CustomPrompt``.

        Raises:
            Exception: if no vector store is configured.
        """
        self._require_vector_db()
        docs = self.vector_db.similarity_search_db(query, topk=self._resolve_topk(topk))
        engineering = self._PROMPT_ENGINEERING_BY_TYPE.get(prompt_type)
        if engineering is not None:
            self.args.prompt_engineering = engineering
        self.prompt_template = CustomPrompt(self.args)
        return self.prompt_template.get_prompt(docs, query, prompt)

    def get_new_question(self, query):
        """Return *query*, rewritten by the question chain when history exists."""
        if not self.langchain_chat_history:
            print('result', query)
            return query
        result = self.question_generator_chain({"question": query, "chat_history": self.langchain_chat_history})
        print('result', result)
        return result['text']

    def checkout_history_and_summary(self, summary=False):
        """Summarise and reset the conversation when forced or when it is too long.

        Args:
            summary: force a summary regardless of history length.

        Returns:
            The LLM's summary text, or "" when no summary was made. After a
            summary, the history and token buffer contain only the summary turn.
        """
        if not summary and len(self.langchain_chat_history) <= self.max_history_turns:
            return ""
        print("start summary")
        self.llm.history = self.langchain_chat_history
        summary_res = self.llm(self._SUMMARY_PROMPT)
        print("请对我们之前的对话内容进行总结: ", summary_res)
        # Restart from the summary so the context passed to the LLM stays bounded.
        self.langchain_chat_history = [(self._SUMMARY_PROMPT, summary_res)]
        self.input_tokens = [self._SUMMARY_PROMPT, summary_res]
        return summary_res

    def query_retrieval_llm(self, query, topk, prompt_type, prompt=None):
        """Answer *query* with retrieved documents injected into the prompt.

        Returns:
            ``(answer, token_count, summary)`` — see ``_record_turn``.
        """
        new_query = self.get_new_question(query)
        user_prompt = self.create_user_query_prompt(new_query, topk, prompt_type, prompt)
        print("Post user query to EAS-LLM", user_prompt)
        self.llm.history = self.langchain_chat_history
        ans = self.llm(user_prompt)
        print("Get response from EAS-LLM.")
        lens, summary_res = self._record_turn(new_query, ans)
        return ans, lens, summary_res

    def query_only_llm(self, query):
        """Answer *query* with the LLM alone (no retrieval).

        Returns:
            ``(answer, token_count, summary)`` — see ``_record_turn``.
        """
        print("Post user query to EAS-LLM")
        start_time = time.time()
        self.llm.history = self.langchain_chat_history
        ans = self.llm(query)
        end_time = time.time()
        print("Get response from EAS-LLM. Cost time: {} s".format(end_time - start_time))
        lens, summary_res = self._record_turn(query, ans)
        return ans, lens, summary_res

    def query_only_vectorstore(self, query, topk):
        """Return the documents retrieved for *query*, formatted as one string.

        Args:
            query: the user question.
            topk: number of documents to retrieve; '' or None means
                ``default_topk``.

        Returns:
            ``(context_docs, token_count)`` where ``token_count`` is the number
            of SentencePiece tokens in ``context_docs``.

        Raises:
            Exception: if no vector store is configured.
        """
        print("Post user query to Vectore Store")
        self._require_vector_db()
        topk = self._resolve_topk(topk)
        start_time = time.time()
        print('query', query)
        docs = self.vector_db.similarity_search_db(query, topk=topk)

        rule = '=' * 20
        page_contents, ref_names = [], []
        for idx, doc in enumerate(docs, start=1):
            content = doc.page_content if hasattr(doc, "page_content") else "[Doc Content Lost]"
            page_contents.append(f'{rule} Doc [{idx}] {rule}\n{content}\n')
            ref = doc.metadata['filename'] if hasattr(doc, "metadata") and "filename" in doc.metadata else "[Doc Name Lost]"
            ref_names.append(f'[{idx}] {ref}')

        ref_title = f'{rule} Reference Sources {rule}'
        context_docs = '\n'.join(page_contents) + f'{ref_title}\n' + '\n'.join(ref_names)
        end_time = time.time()
        print("Get response from Vectore Store. Cost time: {} s".format(end_time - start_time))
        return context_docs, self._count_tokens(context_docs)
