# holo-chatbot-webui/main.py
import json
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.document_loaders import DirectoryLoader
from langchain.vectorstores import FAISS
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import AnalyticDB,Hologres,AlibabaCloudOpenSearch,AlibabaCloudOpenSearchSettings
from langchain.vectorstores import ElasticsearchStore
import os
import logging
import time
import requests
import sys
import argparse
import warnings
warnings.filterwarnings("ignore")
class LLMService:
    """RAG helper service.

    Connects to a configured vector store (AnalyticDB / Hologres /
    Elasticsearch / AlibabaCloud OpenSearch, falling back to local FAISS),
    uploads chunked documents as embeddings, and answers user queries by
    retrieving context and posting a filled prompt to an EAS-hosted LLM.
    """

    def __init__(self, cfg) -> None:
        """Store the parsed JSON config dict and open the vector-store connection.

        cfg: dict loaded from the JSON config file; expected top-level keys
        include 'embedding', 'EASCfg', 'prompt_template', 'query_topk',
        'create_docs', and at most one of the database sections.
        """
        self.cfg = cfg
        self.vector_db = self.connect_db()

    def post_to_chatglm2_eas(self, query_prompt):
        """POST `query_prompt` to the EAS LLM endpoint and return the raw response text."""
        url = self.cfg['EASCfg']['url']
        token = self.cfg['EASCfg']['token']
        headers = {
            "Authorization": token,
            'Accept': "*/*",
            "Content-Type": "application/x-www-form-urlencoded;charset=utf-8"
        }
        resp = requests.post(
            url=url,
            data=query_prompt.encode('utf8'),
            headers=headers,
            # NOTE(review): requests timeouts are in SECONDS; 10000 s is
            # effectively unbounded — confirm whether 10000 ms was intended.
            timeout=10000,
        )
        return resp.text

    def connect_db(self):
        """Build the embedding model and connect to the configured vector store.

        Returns the vector-store instance, or None when no database section is
        present in the config (FAISS is then created lazily from local files).
        Side effect: sets `self.embed` to a HuggingFaceEmbeddings instance.
        """
        embedding_model = self.cfg['embedding']['embedding_model']
        model_dir = self.cfg['embedding']['model_dir']
        self.embed = HuggingFaceEmbeddings(
            model_name=os.path.join(model_dir, embedding_model),
            model_kwargs={'device': 'cpu'},
        )
        # Fixed: previously read the module-global `cfg`, which raised
        # NameError when the class was used outside the __main__ script.
        emb_dim = self.cfg['embedding']['embedding_dimension']
        if 'ADBCfg' in self.cfg:
            start_time = time.time()
            connection_string_adb = AnalyticDB.connection_string_from_db_params(
                host=self.cfg['ADBCfg']['PG_HOST'],
                database='postgres',
                user=self.cfg['ADBCfg']['PG_USER'],
                password=self.cfg['ADBCfg']['PG_PASSWORD'],
                driver='psycopg2cffi',
                port=5432,
            )
            vector_db = AnalyticDB(
                embedding_function=self.embed,
                embedding_dimension=emb_dim,
                connection_string=connection_string_adb,
                # pre_delete_collection=True,
            )
            end_time = time.time()
            print("Connect AnalyticDB success. Cost time: {} s".format(end_time - start_time))
        elif 'HOLOCfg' in self.cfg:
            start_time = time.time()
            connection_string_holo = Hologres.connection_string_from_db_params(
                host=self.cfg['HOLOCfg']['PG_HOST'],
                port=self.cfg['HOLOCfg']['PG_PORT'],
                database=self.cfg['HOLOCfg']['PG_DATABASE'],
                user=self.cfg['HOLOCfg']['PG_USER'],
                password=self.cfg['HOLOCfg']['PG_PASSWORD']
            )
            vector_db = Hologres(
                embedding_function=self.embed,
                ndims=emb_dim,
                connection_string=connection_string_holo,
            )
            end_time = time.time()
            print("Connect Hologres success. Cost time: {} s".format(end_time - start_time))
        elif 'ElasticSearchCfg' in self.cfg:
            start_time = time.time()
            vector_db = ElasticsearchStore(
                es_url=self.cfg['ElasticSearchCfg']['ES_URL'],
                index_name=self.cfg['ElasticSearchCfg']['ES_INDEX'],
                es_user=self.cfg['ElasticSearchCfg']['ES_USER'],
                es_password=self.cfg['ElasticSearchCfg']['ES_PASSWORD'],
                embedding=self.embed
            )
            end_time = time.time()
            print("Connect ElasticsearchStore success. Cost time: {} s".format(end_time - start_time))
        elif 'OpenSearchCfg' in self.cfg:
            start_time = time.time()
            print("Start Connect AlibabaCloudOpenSearch ")
            settings = AlibabaCloudOpenSearchSettings(
                endpoint=self.cfg['OpenSearchCfg']['endpoint'],
                instance_id=self.cfg['OpenSearchCfg']['instance_id'],
                datasource_name=self.cfg['OpenSearchCfg']['datasource_name'],
                username=self.cfg['OpenSearchCfg']['username'],
                password=self.cfg['OpenSearchCfg']['password'],
                embedding_index_name=self.cfg['OpenSearchCfg']['embedding_index_name'],
                field_name_mapping={
                    "id": self.cfg['OpenSearchCfg']['field_name_mapping']['id'],
                    "document": self.cfg['OpenSearchCfg']['field_name_mapping']['document'],
                    "embedding": self.cfg['OpenSearchCfg']['field_name_mapping']['embedding'],
                    "source": self.cfg['OpenSearchCfg']['field_name_mapping']['source'],
                },
            )
            vector_db = AlibabaCloudOpenSearch(
                embedding=self.embed, config=settings
            )
            end_time = time.time()
            print("Connect AlibabaCloudOpenSearch success. Cost time: {} s".format(end_time - start_time))
        else:
            # No database configured: FAISS index is created/loaded on demand
            # in upload_custom_knowledge / create_user_query_prompt.
            print("Not config any database, use FAISS-cpu default.")
            vector_db = None
        return vector_db

    def upload_custom_knowledge(self):
        """Load documents from the configured directory, chunk them, and insert
        their embeddings into the vector store (or a local FAISS index)."""
        docs_dir = self.cfg['create_docs']['docs_dir']
        docs = DirectoryLoader(docs_dir, glob=self.cfg['create_docs']['glob'], show_progress=True).load()
        text_splitter = CharacterTextSplitter(
            chunk_size=int(self.cfg['create_docs']['chunk_size']),
            # Cast for consistency with chunk_size — config values may be strings.
            chunk_overlap=int(self.cfg['create_docs']['chunk_overlap']),
        )
        docs = text_splitter.split_documents(docs)
        print('Uploading custom knowledge.')
        start_time = time.time()
        if all(item not in self.cfg for item in ['ADBCfg', 'HOLOCfg', 'ElasticSearchCfg', 'OpenSearchCfg']):
            self.vector_db = FAISS.from_documents(docs, self.embed)
            self.vector_db.save_local("faiss_index")
        else:
            self.vector_db.add_documents(docs)
        end_time = time.time()
        print("Insert Success. Cost time: {} s".format(end_time - start_time))

    def create_user_query_prompt(self, query):
        """Retrieve top-k similar chunks for `query` and fill the prompt template.

        Returns the prompt string with {context} and {question} substituted.
        """
        if all(item not in self.cfg for item in ['ADBCfg', 'HOLOCfg', 'ElasticSearchCfg', 'OpenSearchCfg']):
            self.vector_db = FAISS.load_local("faiss_index", self.embed)
        # Fixed: previously read the module-global `cfg` instead of self.cfg.
        docs = self.vector_db.similarity_search(query, k=int(self.cfg['query_topk']))
        context_docs = ""
        for idx, doc in enumerate(docs):
            context_docs += "-----\n\n" + str(idx + 1) + ".\n" + doc.page_content
        context_docs += "\n\n-----\n\n"
        user_prompt_template = self.cfg['prompt_template']
        user_prompt_template = user_prompt_template.format(context=context_docs, question=query)
        return user_prompt_template

    def user_query(self, query):
        """Answer `query`: build a context-augmented prompt and post it to the EAS LLM."""
        user_prompt_template = self.create_user_query_prompt(query)
        print("Post user query to EAS-LLM")
        start_time = time.time()
        ans = self.post_to_chatglm2_eas(user_prompt_template)
        end_time = time.time()
        print("Get response from EAS-LLM. Cost time: {} s".format(end_time - start_time))
        return ans
if __name__ == "__main__":
    # CLI entry point: --upload inserts the knowledge base, --query asks a question.
    parser = argparse.ArgumentParser(description='Command line argument parser')
    parser.add_argument('--config', type=str, help='json配置文件输入', default='config.json')
    # store_true implies default=False; the redundant default was dropped.
    parser.add_argument('--upload', action='store_true', help='上传知识库')
    parser.add_argument('--query', help='用户请求查询')
    args = parser.parse_args()
    # --config always has a value (defaults to config.json), so the former
    # "config must be set" branch was unreachable and has been removed.
    if not args.upload and not args.query:
        print('Not any operation is set.')
    elif not os.path.exists(args.config):
        print(f"{args.config} does not exist.")
    else:
        with open(args.config) as f:
            # Kept at module scope intentionally: other code in this file may
            # still reference the global `cfg`.
            cfg = json.load(f)
        solver = LLMService(cfg)
        if args.upload:
            solver.upload_custom_knowledge()
        if args.query:
            answer = solver.user_query(args.query)
            print("The answer is: ", answer)