# holo-chatbot-webui/webui.py
# Copyright (c) Alibaba Cloud PAI.
# SPDX-License-Identifier: Apache-2.0
# deling.sc
from fastapi import FastAPI, File, UploadFile
import gradio as gr
from modules.LLMService import LLMService
import time
import os
from pydantic import BaseModel
import json
from args import parse_args
from modules.UI import *
def init_args(args):
    """Overwrite the parsed CLI arguments with the web UI defaults (in place).

    The UI always runs with the general prompt style, the Chinese text2vec
    embedding model (768-dim) and uploads disabled at startup.
    """
    ui_defaults = {
        "prompt_engineering": 'general',
        "embed_model": "text2vec-base-chinese",
        "embed_dim": 768,
        "upload": False,
    }
    for name, value in ui_defaults.items():
        setattr(args, name, value)
def get_env(key):
    """Return the value of the required environment variable *key*.

    Raises:
        RuntimeError: if the variable is not set.  (The original used
        ``assert``, which is silently stripped under ``python -O`` and would
        let a missing variable slip through to a confusing KeyError later.)
    """
    try:
        return os.environ[key]
    except KeyError:
        raise RuntimeError(f"please set the environment variable {key}") from None
# ---- Module-level bootstrap: parse CLI args, create the LLM service ----
_global_args = parse_args()
init_args(_global_args)
service = LLMService(_global_args)

# Default service configuration.  EAS and Hologres connection details come
# from required environment variables.  The Hologres user/password may be
# supplied via HOLO_USER / HOLO_PASSWORD; they fall back to the previous
# hard-coded values so existing deployments keep working, but credentials
# should live in the environment, not in source control.
_global_cfg = {
    "embedding": {
        "model_dir": "embedding_model/",
        "embedding_model": "text2vec-base-chinese",
        "embedding_dimension": 768
    },
    "vector_store": "Hologres",
    "EASCfg": {
        "url": get_env('EAS_URL'),
        "token": get_env('EAS_TOKEN'),
    },
    "HOLOCfg": {
        "PG_HOST": get_env('HOLO_HOST'),
        "PG_PORT": get_env('HOLO_PORT'),
        "PG_DATABASE": get_env('HOLO_DATABASE'),
        "PG_USER": os.environ.get('HOLO_USER', "BASIC$chatbot"),
        "PG_PASSWORD": os.environ.get('HOLO_PASSWORD', "xxx"),
    },
    "query_topk": 4,
    "prompt_template": "基于以下已知信息,简洁和专业的来回答用户的问题。如果无法从中得到答案,请说 \"根据已知信息无法回答该问题\" 或 \"没有提供足够的相关信息\",不允许在答案中添加编造成分,答案请使用中文。\n=====\n已知信息:\n{context}\n=====\n用户问题:\n{question}"
}
print("_global_cfg:", _global_cfg)
class Query(BaseModel):
    """Request payload shared by the /chat/* endpoints."""
    # The user's question (required).
    question: str
    # How many retrieved chunks to use; service-side default applies when None.
    topk: int | None = None
    # Optional prompt-template override for /chat/langchain.
    prompt: str | None = None
# NOTE(review): FastAPI has no `host` constructor parameter for binding; the
# kwarg is presumably kept as extra metadata and the actual bind address is
# chosen by the ASGI server (e.g. uvicorn) at launch — confirm before relying
# on this value.
host_ = "127.0.0.1"
app = FastAPI(host=host_)
@app.post("/chat/llm")
async def query_by_llm(query: Query):
    """Answer the question with the LLM alone (no vector-store retrieval)."""
    result = service.query_only_llm(query.question)
    answer, token_count = result[0], result[1]
    return {"response": answer, "tokens": token_count}
@app.post("/chat/vectorstore")
async def query_by_vectorstore(query: Query):
    """Answer the question from the vector store only (no LLM call)."""
    answer, token_count = service.query_only_vectorstore(query.question, query.topk)
    return {"response": answer, "tokens": token_count}
@app.post("/chat/langchain")
async def query_by_langchain(query: Query):
    """Answer the question via retrieval-augmented generation (RAG)."""
    result = service.query_retrieval_llm(query.question, query.topk, query.prompt)
    answer, token_count = result[0], result[1]
    return {"response": answer, "tokens": token_count}
@app.post("/uploadfile")
async def create_upload_file(file: UploadFile | None = None):
    """Receive a knowledge file, save it under ./file/ and index it.

    Fixes over the original: the client-supplied filename is reduced to its
    basename (path-traversal guard), the directory is created race-free with
    ``makedirs(exist_ok=True)``, and the file handle is managed with ``with``
    so it cannot leak if the write raises.
    """
    if not file:
        return {"message": "No upload file sent"}
    # basename() strips any directory components an attacker could smuggle in.
    fn = os.path.basename(file.filename)
    save_path = './file/'
    os.makedirs(save_path, exist_ok=True)
    save_file = os.path.join(save_path, fn)
    data = await file.read()
    with open(save_file, 'wb') as f:
        f.write(data)
    # chunk_size=200, chunk_overlap=0 — same splitting parameters as before.
    service.upload_custom_knowledge(save_file, 200, 0)
    return {"response": "success"}
@app.post("/config")
async def create_config_json_file(file: UploadFile | None = None):
    """Receive a JSON config upload, fill in defaults and re-init the service.

    Fixes over the original: client filename reduced to its basename
    (path-traversal guard), race-free directory creation, ``with`` for the
    write handle (no leak on error), defaults applied via ``setdefault``, the
    prompt-template default reuses ``_global_cfg`` instead of duplicating the
    literal, and the unused ``connect_time`` local is dropped.
    """
    if not file:
        return {"message": "No upload config json file sent"}
    fn = os.path.basename(file.filename)
    save_path = './config/'
    os.makedirs(save_path, exist_ok=True)
    save_file = os.path.join(save_path, fn)
    data = await file.read()
    with open(save_file, 'wb') as f:
        f.write(data)
    with open(save_file) as c:
        cfg = json.load(c)
    # Mirror the embedding model / vector-store choice onto the global args
    # that the rest of the app (UI, service) reads.
    _global_args.embed_model = cfg['embedding']['embedding_model']
    _global_args.vectordb_type = cfg['vector_store']
    cfg.setdefault('query_topk', 4)
    # Same default template as the startup config.
    cfg.setdefault('prompt_template', _global_cfg['prompt_template'])
    if cfg.get('create_docs') is None:
        cfg['create_docs'] = {}
    cfg['create_docs']['chunk_size'] = 200
    cfg['create_docs']['chunk_overlap'] = 0
    cfg['create_docs']['docs_dir'] = 'docs/'
    cfg['create_docs']['glob'] = "**/*"
    service.init_with_cfg(cfg, _global_args)
    return {"response": "success"}
# Build the Gradio UI and mount it at the root path of the FastAPI app, so
# the JSON endpoints above and the web UI are served by the same ASGI app.
ui = create_ui(service,_global_args,_global_cfg)
app = gr.mount_gradio_app(app, ui, path='')