holo-chatbot-webui/modules/LLMService.py (134 lines of code) (raw):

# Copyright (c) Alibaba Cloud PAI. # SPDX-License-Identifier: Apache-2.0 # deling.sc import json import time import os from langchain.document_loaders import DirectoryLoader, UnstructuredFileLoader from .CustomPrompt import CustomPrompt from .EASAgent import EASAgent from .VectorDB import VectorDB from .TextSplitter import TextSplitter import nltk from .CustomLLM import CustomLLM from .QuestionPrompt import * from sentencepiece import SentencePieceProcessor class LLMService: def __init__(self, args): # assert args.upload or args.user_query, "error: dose not set any action, please set '--upload' or '--query <user_query>'." # assert os.path.exists(args.config), f"error: config path {args.config} does not exist." self.langchain_chat_history = [] self.input_tokens = [] self.llm_chat_history = [] self.sp = SentencePieceProcessor(model_file='./tokenizer.model') self.vector_db = None nltk_data_path = "/code/nltk_data" if os.path.exists(nltk_data_path): nltk.data.path = [nltk_data_path] + nltk.data.path # with open(args.config) as f: # cfg = json.load(f) # self.init_with_cfg(cfg, args) def init_with_cfg(self, cfg, args): self.cfg = cfg self.args = args # self.prompt_template = PromptTemplate(self.args) # self.eas_agent = EASAgent(self.cfg) self.vector_db = VectorDB(self.args, self.cfg) self.llm = CustomLLM() self.llm.url = self.cfg['EASCfg']['url'] self.llm.token = self.cfg['EASCfg']['token'] self.question_generator_chain = get_standalone_question_ch(self.llm) # if args.upload: # self.upload_custom_knowledge() # if args.user_query: # if args.query_type == "retrieval_llm": # self.query_func = self.query_retrieval_llm # self.query_type = "Retrieval-Augmented Generation" # elif args.query_type == "only_llm": # self.query_func = self.query_only_llm # self.query_type = "Vanilla-LLM Generation" # elif args.query_type == "only_vectorstore": # self.query_func = self.query_only_vectorstore # self.query_type = "Vector-Store Retrieval" # else: # raise ValueError(f'error: invalid query type of {args.query_type}') # answer = self.query_func(args.user_query) # print('='*20 + f' {self.query_type} ' + '='*20 + '\n', answer) def upload_custom_knowledge(self, docs_dir=None, chunk_size=200,chunk_overlap=0): if docs_dir is None: docs_dir = self.cfg['create_docs']['docs_dir'] self.cfg['create_docs']['chunk_size'] = chunk_size self.cfg['create_docs']['chunk_overlap'] = chunk_overlap self.text_splitter = TextSplitter(self.cfg) if os.path.isdir(docs_dir): docs = DirectoryLoader(docs_dir, glob=self.cfg['create_docs']['glob'], show_progress=True).load() docs = self.text_splitter.split_documents(docs) else: loader = UnstructuredFileLoader(docs_dir, mode="elements") docs = loader.load_and_split(text_splitter=self.text_splitter) print('Uploading custom knowledge.') start_time = time.time() self.vector_db.add_documents(docs) end_time = time.time() print("Insert Success. Cost time: {} s".format(end_time - start_time)) def create_user_query_prompt(self, query, topk, prompt_type, prompt=None): if topk == '' or topk is None: topk = 3 if self.vector_db is None: raise Exception('未连接向量数据库!') docs = self.vector_db.similarity_search_db(query, topk=int(topk)) if prompt_type == "General": self.args.prompt_engineering = 'general' elif prompt_type == "Extract URL": self.args.prompt_engineering = 'extract_url' elif prompt_type == "Accurate Content": self.args.prompt_engineering = 'accurate_content' elif prompt_type == "Customize": self.args.prompt_engineering = 'customize' self.prompt_template = CustomPrompt(self.args) user_prompt = self.prompt_template.get_prompt(docs, query, prompt) return user_prompt def get_new_question(self, query): if len(self.langchain_chat_history) == 0: print('result',query) return query else: result = self.question_generator_chain({"question": query, "chat_history": self.langchain_chat_history}) print('result',result) return result['text'] def checkout_history_and_summary(self, summary=False): if summary or len(self.langchain_chat_history) > 10: print("start summary") self.llm.history = self.langchain_chat_history summary_res = self.llm("请对我们之前的对话内容进行总结。") print("请对我们之前的对话内容进行总结: ", summary_res) self.langchain_chat_history = [] self.langchain_chat_history.append(("请对我们之前的对话内容进行总结。", summary_res)) self.input_tokens = [] self.input_tokens.append("请对我们之前的对话内容进行总结。") self.input_tokens.append(summary_res) return summary_res else: return "" def query_retrieval_llm(self, query, topk, prompt_type, prompt=None): new_query = self.get_new_question(query) user_prompt = self.create_user_query_prompt(new_query, topk, prompt_type, prompt) print("Post user query to EAS-LLM", user_prompt) self.llm.history = self.langchain_chat_history ans = self.llm(user_prompt) self.langchain_chat_history.append((new_query, ans)) print("Get response from EAS-LLM.") self.input_tokens.append(new_query) self.input_tokens.append(ans) tokens_len = self.sp.encode(self.input_tokens, out_type=str) lens = sum(len(tl) for tl in tokens_len) summary_res = self.checkout_history_and_summary() return ans, lens, summary_res def query_only_llm(self, query): print("Post user query to EAS-LLM") start_time = time.time() self.llm.history = self.langchain_chat_history ans = self.llm(query) self.langchain_chat_history.append((query, ans)) end_time = time.time() print("Get response from EAS-LLM. Cost time: {} s".format(end_time - start_time)) self.input_tokens.append(query) self.input_tokens.append(ans) tokens_len = self.sp.encode(self.input_tokens, out_type=str) lens = sum(len(tl) for tl in tokens_len) summary_res = self.checkout_history_and_summary() return ans, lens, summary_res def query_only_vectorstore(self, query, topk): print("Post user query to Vectore Store") if topk == '' or topk is None: topk = 3 start_time = time.time() print('query',query) docs = self.vector_db.similarity_search_db(query, topk=int(topk)) page_contents, ref_names = [], [] for idx, doc in enumerate(docs): content = doc.page_content if hasattr(doc, "page_content") else "[Doc Content Lost]" page_contents.append('='*20 + f' Doc [{idx+1}] ' + '='*20 + f'\n{content}\n') ref = doc.metadata['filename'] if hasattr(doc, "metadata") and "filename" in doc.metadata else "[Doc Name Lost]" ref_names.append(f'[{idx+1}] {ref}') ref_title = '='*20 + ' Reference Sources ' + '='*20 context_docs = '\n'.join(page_contents) + f'{ref_title}\n' + '\n'.join(ref_names) end_time = time.time() print("Get response from Vectore Store. Cost time: {} s".format(end_time - start_time)) tokens_len = self.sp.encode(context_docs, out_type=str) lens = sum(len(tl) for tl in tokens_len) return context_docs, lens