holo-llm-deepseek/main.py (161 lines of code) (raw):

import argparse
import json
import os
import sys
import time
from typing import Any, List

import requests
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import CSVLoader
from langchain_community.embeddings import ModelScopeEmbeddings
from langchain_community.vectorstores import Hologres


class LLMChatbot:
    """Retrieval-augmented chatbot: Hologres vector store + PAI-EAS DeepSeek service.

    CSV documents are embedded with a ModelScope embedding model and stored in
    a Hologres vector table. At query time the most similar chunks are placed
    into ``config['prompt_template']`` and the prompt is sent to the EAS
    endpoint described by ``config['eas_config']``.
    """

    # Defaults used when the config does not override them.
    DEFAULT_NDIMS = 768
    DEFAULT_TABLE_NAME = 'langchain_embedding'
    # Upper bound (in characters) on retrieved page content put into a prompt.
    DEFAULT_MAX_CONTEXT_LENGTH = 1800
    # Marker after which the useful text of a retrieved chunk starts.
    CONTENT_MARKER = 'content:'

    def __init__(self, config, clear_db) -> None:
        """Build the embedding model and connect to the Hologres vector store.

        Args:
            config: parsed JSON configuration dict.
            clear_db: forwarded to ``connect_hologres`` (the ``--clear`` flag).
        """
        self.config = config
        self.embeddings = ModelScopeEmbeddings(
            model_id=self.config['embedding']['model_id'])
        self.vectorstore = self.connect_hologres(clear_db)

    def connect_hologres(self, clear_db):
        """Create the Hologres vector store described by ``config['holo_config']``.

        Args:
            clear_db: passed as ``pre_delete_table``; the CLI uses it to clear
                all data in the vector store.

        Returns:
            A ``Hologres`` vector store that embeds with ``self.embeddings``.
        """
        print("start connecting")
        holo = self.config['holo_config']
        connection_string = Hologres.connection_string_from_db_params(
            holo['HOLO_ENDPOINT'],
            int(holo['HOLO_PORT']),
            holo['HOLO_DATABASE'],
            holo['HOLO_USER'],
            holo['HOLO_PASSWORD'])
        return Hologres(
            connection_string=connection_string,
            embedding_function=self.embeddings,
            # Must equal the output dimension of the configured embedding model.
            ndims=int(self.config['embedding'].get('ndims', self.DEFAULT_NDIMS)),
            table_name=holo.get('table_name', self.DEFAULT_TABLE_NAME),
            pre_delete_table=clear_db)

    def load_db(self, files: List[str]) -> None:
        """Load CSV files, split them into chunks and store their embeddings.

        Args:
            files: paths of the CSV files to load.
        """
        documents = []
        for fname in files:
            documents.extend(CSVLoader(fname).load())

        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000, chunk_overlap=100)
        documents = text_splitter.split_documents(documents)
        if not documents:
            print("No documents to load.")
            return

        start_time = time.time()
        self.vectorstore.add_documents(documents)
        end_time = time.time()
        print(
            "Store embedding into Hologres Success.Cost Time: {:.2f}s".format(
                end_time - start_time))

    def generate_context(self, question: str, max_context_length: int) -> List[str]:
        """Return page contents of the top-k similar documents, capped in total length.

        Documents are visited in the order returned by ``similarity_search``
        (k = ``config['query_topk']``). A document that would push the running
        total past ``max_context_length`` characters is skipped, but later,
        shorter ones may still fit.
        """
        docs = self.vectorstore.similarity_search(
            question, k=self.config['query_topk'])
        selected = []
        current_context_length = 0
        for doc in docs:
            content_length = len(doc.page_content)
            if current_context_length + content_length > max_context_length:
                continue
            current_context_length += content_length
            selected.append(doc.page_content)
        return selected

    @classmethod
    def _extract_content(cls, page_content: str) -> str:
        """Return the text after ``CONTENT_MARKER``, or the whole chunk if absent."""
        pos = page_content.find(cls.CONTENT_MARKER)
        if pos == -1:
            return page_content.strip()
        return page_content[pos + len(cls.CONTENT_MARKER):].strip()

    @staticmethod
    def _consume_stream(response) -> str:
        """Print and accumulate ``delta.content`` pieces from an SSE chat stream."""
        pieces = []
        for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False):
            msg = chunk.decode("utf-8")
            # Skip keep-alive blank lines and non-data SSE fields.
            if not msg.startswith("data:"):
                continue
            info = msg[len("data:"):].strip()
            if not info:
                continue
            if info == "[DONE]":
                break
            choices = json.loads(info).get("choices") or []
            if not choices:
                continue
            # Role-only / final chunks may carry no content (or null).
            content = (choices[0].get("delta") or {}).get("content") or ""
            pieces.append(content)
            print(content, end="", flush=True)
        return "".join(pieces)

    def post_requests_to_deepseek_eas(self, query_prompt: str) -> str:
        """Send ``query_prompt`` to the EAS chat endpoint, print and return the answer.

        Reads URL, token and sampling parameters from ``config['eas_config']``.
        With ``stream_mode == 1`` the answer is read as a server-sent-event
        stream and printed incrementally; otherwise it is printed at once.
        An optional ``timeout`` (seconds) in ``eas_config`` is passed to requests.

        Returns:
            The full answer text.

        Raises:
            requests.HTTPError: if the service responds with an error status.
        """
        eas = self.config['eas_config']
        stream = eas['stream_mode'] == 1
        headers = {
            "Content-Type": "application/json",
            "Authorization": eas['token'],
        }
        req = {
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": query_prompt},
            ],
            "stream": stream,
            "temperature": eas['temperature'],
            "top_p": eas['top_p'],
            "top_k": eas['top_k'],
            "max_tokens": eas['max_tokens'],
        }
        response = requests.post(
            eas['url'],
            json=req,
            headers=headers,
            stream=stream,
            timeout=eas.get('timeout'),
        )
        # Closing the response releases the pooled connection (matters for streams).
        with response:
            response.raise_for_status()
            if stream:
                return self._consume_stream(response)
            answer = json.loads(response.text)["choices"][0]["message"]["content"]
            print(answer)
            return answer

    def query(self, question: str, use_holo: bool = True,
              max_context_length: int = DEFAULT_MAX_CONTEXT_LENGTH) -> str:
        """Answer ``question``, optionally grounded in retrieved Hologres context.

        Args:
            question: the user's question.
            use_holo: when False, no similarity search is run and the prompt
                template's ``{context}`` is filled with an empty string.
            max_context_length: cap (in characters) on retrieved page content.

        Returns:
            The model's answer text.
        """
        context = ''
        if use_holo:
            # Only query the vector store when its results are actually used.
            message_list = self.generate_context(question, max_context_length)
            context = '\n'.join(
                self._extract_content(message) for message in message_list)
        prompt_query = self.config['prompt_template'].format(
            context=context, question=question)
        start_time = time.time()
        answer = self.post_requests_to_deepseek_eas(prompt_query)
        end_time = time.time()
        print("\nGet response from PAI-EAS cost {:.2f} seconds\n".format(
            end_time - start_time))
        return answer


def _default_data_files() -> List[str]:
    """Return the sorted ``.csv`` files in the ``data`` directory beside this script."""
    data_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data')
    return sorted(
        os.path.join(data_dir, name)
        for name in os.listdir(data_dir)
        if name.lower().endswith('.csv')
        and os.path.isfile(os.path.join(data_dir, name)))


def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line argument parser."""
    parser = argparse.ArgumentParser(
        prog='chatbot', description='holo chatbot command line interface')
    parser.add_argument('-l', '--load', action='store_true',
                        help='generate embeddings and update the vector database.')
    parser.add_argument('-f', '--files', nargs='*', default=[],
                        help='specify the csv data file to update. If left empty, '
                             'all .csv files in ./data will be updated. '
                             'Only valid when --load is set.')
    parser.add_argument('--clear', action='store_true',
                        help='clear all data in vector store')
    parser.add_argument('-n', '--no-vector-store', action='store_true',
                        help='run pure PAI-LLM without vector store')
    parser.add_argument(
        '--config', help='input configuration json file',
        default='./config/config.json')
    return parser


def main() -> None:
    """CLI entry point: load data into Hologres or run the interactive Q&A loop."""
    args = _build_parser().parse_args()
    if not args.config:
        print("The config json file must be set.")
        sys.exit(1)
    if not os.path.exists(args.config):
        print(f"{args.config} is not existed.")
        sys.exit(1)
    with open(args.config, encoding='utf-8') as f:
        config = json.load(f)

    bot = LLMChatbot(config, args.clear)
    if args.load:
        files = args.files or _default_data_files()
        print(f'start loading files: {files}')
        bot.load_db(files)
        return

    # Start Question
    while True:
        print("Please enter a Question: ")
        try:
            question = input()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C end the session cleanly instead of a traceback.
            print()
            break
        if not question.strip():
            continue
        if args.no_vector_store:
            print('PAI-LLM answer:\n ')
        else:
            print('PAI-LLM + Hologres answer:\n ')
        try:
            bot.query(question, not args.no_vector_store)
        except requests.RequestException as err:
            # Keep the interactive session alive on network / HTTP errors.
            print(f"Request to PAI-EAS failed: {err}")


if __name__ == '__main__':
    main()