holo-chatbot-webui/webui.py (108 lines of code) (raw):

# Copyright (c) Alibaba Cloud PAI.
# SPDX-License-Identifier: Apache-2.0
# deling.sc
"""FastAPI + Gradio web UI for the Hologres-backed chatbot LLM service.

Exposes JSON endpoints (/chat/llm, /chat/vectorstore, /chat/langchain,
/uploadfile, /config) and mounts the Gradio UI at the root path.
"""
from fastapi import FastAPI, File, UploadFile
import gradio as gr
from modules.LLMService import LLMService
import time
import os
from pydantic import BaseModel
import json
from args import parse_args
from modules.UI import *

# Single source of truth for the fallback prompt; the original duplicated
# this literal in _global_cfg and in the /config handler.
_DEFAULT_PROMPT_TEMPLATE = "基于以下已知信息,简洁和专业的来回答用户的问题。如果无法从中得到答案,请说 \"根据已知信息无法回答该问题\" 或 \"没有提供足够的相关信息\",不允许在答案中添加编造成分,答案请使用中文。\n=====\n已知信息:\n{context}\n=====\n用户问题:\n{question}"


def init_args(args):
    """Overlay webui-specific defaults on the parsed CLI namespace."""
    args.prompt_engineering = 'general'
    args.embed_model = "text2vec-base-chinese"
    args.embed_dim = 768
    # args.vectordb_type = 'Elasticsearch'
    args.upload = False
    # args.user_query = None
    # args.query_type = "retrieval_llm"


def get_env(key):
    """Return the required environment variable *key*.

    Raises RuntimeError (instead of the original ``assert``, which is
    silently stripped under ``python -O``) when the variable is unset.
    """
    try:
        return os.environ[key]
    except KeyError:
        raise RuntimeError(f"please set the environment variable {key}") from None


_global_args = parse_args()
init_args(_global_args)
service = LLMService(_global_args)

_global_cfg = {
    "embedding": {
        "model_dir": "embedding_model/",
        "embedding_model": "text2vec-base-chinese",
        "embedding_dimension": 768
    },
    "vector_store": "Hologres",
    "EASCfg": {
        "url": get_env('EAS_URL'),
        "token": get_env('EAS_TOKEN'),
    },
    "HOLOCfg": {
        "PG_HOST": get_env('HOLO_HOST'),
        "PG_PORT": get_env('HOLO_PORT'),
        "PG_DATABASE": get_env('HOLO_DATABASE'),
        "PG_USER": "BASIC$chatbot",
        # NOTE(review): placeholder credential kept from the original code;
        # should come from the environment like the other HOLO settings.
        "PG_PASSWORD": "xxx",
    },
    "query_topk": 4,
    "prompt_template": _DEFAULT_PROMPT_TEMPLATE
}
print("_global_cfg:", _global_cfg)


class Query(BaseModel):
    """Request body for the /chat/* endpoints."""

    question: str
    # Optional overrides; the service applies its own defaults when None.
    topk: int | None = None
    prompt: str | None = None


host_ = "127.0.0.1"
# NOTE(review): FastAPI does not interpret a ``host`` kwarg (binding is
# uvicorn's job); kept for compatibility with the original code.
app = FastAPI(host=host_)


@app.post("/chat/llm")
async def query_by_llm(query: Query):
    """Answer a question with the LLM alone (no retrieval)."""
    ans, lens, _ = service.query_only_llm(query.question)
    return {"response": ans, "tokens": lens}


@app.post("/chat/vectorstore")
async def query_by_vectorstore(query: Query):
    """Answer a question from the vector store alone (no LLM)."""
    ans, lens = service.query_only_vectorstore(query.question, query.topk)
    return {"response": ans, "tokens": lens}


@app.post("/chat/langchain")
async def query_by_langchain(query: Query):
    """Answer a question via retrieval-augmented generation."""
    ans, lens, _ = service.query_retrieval_llm(query.question, query.topk, query.prompt)
    return {"response": ans, "tokens": lens}


async def _save_upload(file: UploadFile, save_dir):
    """Persist an uploaded file under *save_dir* and return its path."""
    os.makedirs(save_dir, exist_ok=True)  # race-free, unlike exists()+mkdir()
    save_file = os.path.join(save_dir, file.filename)
    data = await file.read()
    with open(save_file, 'wb') as f:  # context manager: no leaked handle on error
        f.write(data)
    return save_file


@app.post("/uploadfile")
async def create_upload_file(file: UploadFile | None = None):
    """Ingest an uploaded document into the knowledge base.

    Chunking parameters (size 200, overlap 0) match the original defaults.
    """
    if not file:
        return {"message": "No upload file sent"}
    save_file = await _save_upload(file, './file/')
    service.upload_custom_knowledge(save_file, 200, 0)
    return {"response": "success"}


@app.post("/config")
async def create_config_json_file(file: UploadFile | None = None):
    """Accept a JSON config upload and (re)initialize the service with it."""
    if not file:
        return {"message": "No upload config json file sent"}
    save_file = await _save_upload(file, './config/')
    with open(save_file) as c:
        cfg = json.load(c)
    _global_args.embed_model = cfg['embedding']['embedding_model']
    _global_args.vectordb_type = cfg['vector_store']
    # Backfill optional keys so the service always sees a complete config.
    if 'query_topk' not in cfg:
        cfg['query_topk'] = 4
    if 'prompt_template' not in cfg:
        cfg['prompt_template'] = _DEFAULT_PROMPT_TEMPLATE
    if cfg.get('create_docs') is None:
        cfg['create_docs'] = {}
    cfg['create_docs']['chunk_size'] = 200
    cfg['create_docs']['chunk_overlap'] = 0
    cfg['create_docs']['docs_dir'] = 'docs/'
    cfg['create_docs']['glob'] = "**/*"
    service.init_with_cfg(cfg, _global_args)
    return {"response": "success"}


ui = create_ui(service, _global_args, _global_cfg)
# Mount the Gradio UI at the root path; rebinds ``app`` as in the original.
app = gr.mount_gradio_app(app, ui, path='')