holo-chatbot-webui/modules/UI.py (204 lines of code) (raw):
import gradio as gr
from modules.LLMService import LLMService
import time
import os
import json
import sys
import gradio
def html_path(filename):
script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
return os.path.join(script_path, "html", filename)
def html(filename):
path = html_path(filename)
if os.path.exists(path):
with open(path, encoding="utf8") as file:
return file.read()
return ""
def webpath(fn):
script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
if fn.startswith(script_path):
web_path = os.path.relpath(fn, script_path).replace('\\', '/')
else:
web_path = os.path.abspath(fn)
return f'file={web_path}?{os.path.getmtime(fn)}'
def css_html():
head = ""
def stylesheet(fn):
return f'<link rel="stylesheet" property="stylesheet" href="{webpath(fn)}">'
cssfile = "style.css"
if not os.path.isfile(cssfile):
print("cssfile not exist")
head += stylesheet(cssfile)
return head
def reload_javascript():
css = css_html()
GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse
def template_response(*args, **kwargs):
res = GradioTemplateResponseOriginal(*args, **kwargs)
res.body = res.body.replace(b'</body>', f'{css}</body>'.encode("utf8"))
res.init_headers()
return res
gradio.routes.templates.TemplateResponse = template_response
def create_ui(service,_global_args,_global_cfg):
reload_javascript()
def connect_holo(emb_model, emb_dim, eas_url, eas_token, pg_host, pg_port, pg_database, pg_user, pg_pwd):
cfg = {
'embedding': {
"embedding_model": emb_model,
"model_dir": "./embedding_model/",
"embedding_dimension": emb_dim
},
'EASCfg': {
"url": eas_url,
"token": eas_token
},
'HOLOCfg': {
"PG_HOST": pg_host,
"PG_DATABASE": pg_database,
"PG_PORT": int(pg_port),
"PG_USER": pg_user,
"PG_PASSWORD": pg_pwd
},
"create_docs":{
"chunk_size": 200,
"chunk_overlap": 0,
"docs_dir": "docs/",
"glob": "**/*"
}
}
_global_args.vectordb_type = "Hologres"
_global_cfg.update(cfg)
try:
service.init_with_cfg(_global_cfg, _global_args)
return "连接 Hologres 成功"
except Exception as e:
return str(e)
with gr.Blocks() as demo:
value_md = """
# <center> \N{fire} Hologres + 大模型搭建企业级问答知识库!
<center>
\N{rocket} [Hologres产品介绍](https://www.aliyun.com/product/bigdata/hologram)
/ \N{rocket} [Hologres向量计算](https://help.aliyun.com/zh/hologres/user-guide/vector-processing-based-on-proxima)
/ \N{rocket} [Hologres向量Python SDK](https://help.aliyun.com/zh/hologres/developer-reference/vector-computing-sdk)
/ \N{rocket} [PAI-EAS模型在线服务](https://pai.console.aliyun.com)
/ \N{rocket} [通义千问](https://modelscope.cn/models/qwen/Qwen-7B-Chat/summary)
\N{fire} 请加入[钉钉群聊](https://help.aliyun.com/zh/hologres/support/obtain-online-support-for-hologres)获取更多在线支持
\N{whale} [通过API调用](/docs)
"""
gr.Markdown(value=value_md)
with gr.Tab("\N{hammer} 设置"):
with gr.Row():
with gr.Column():
with gr.Column():
md_emb = gr.Markdown(value="**请选择 embedding 模型**")
emb_model = gr.Dropdown(["text2vec-base-chinese", "SGPT-125M-weightedmean-nli-bitfit"], label="Emebdding Model", value=_global_args.embed_model)
emb_dim = gr.Textbox(label="Emebdding Dimension", value=_global_args.embed_dim)
def change_emb_model(model):
if model == "SGPT-125M-weightedmean-nli-bitfit":
return {emb_dim: gr.update(value="768")}
if model == "text2vec-base-chinese":
return {emb_dim: gr.update(value="768")}
emb_model.change(fn=change_emb_model, inputs=emb_model, outputs=[emb_dim])
with gr.Column():
md_eas = gr.Markdown(value="**请填入模型在线服务PAI-EAS连接信息**")
eas_url = gr.Textbox(label="EAS Url", value=_global_cfg['EASCfg']['url'])
eas_token = gr.Textbox(label="EAS Token", value=_global_cfg['EASCfg']['token'])
with gr.Column():
md_vs = gr.Markdown(value="**请输入Hologres数据库用户名和密码**")
with gr.Column(visible=(_global_cfg['vector_store']=="Hologres")) as holo_col:
holo_host = gr.Textbox(label="Host",
value=_global_cfg['HOLOCfg']['PG_HOST'] if _global_cfg['vector_store']=="Hologres" else '')
holo_port = gr.Textbox(label="Host",
value=_global_cfg['HOLOCfg']['PG_PORT'] if _global_cfg['vector_store']=="Hologres" else '')
holo_database = gr.Textbox(label="Database",
value=_global_cfg['HOLOCfg']['PG_DATABASE'] if _global_cfg['vector_store']=="Hologres" else '')
holo_user = gr.Textbox(label="User",
value=_global_cfg['HOLOCfg']['PG_USER'] if _global_cfg['vector_store']=="Hologres" else '')
holo_pwd= gr.Textbox(label="Password",
value=_global_cfg['HOLOCfg']['PG_PASSWORD'] if _global_cfg['vector_store']=="Hologres" else '')
connect_btn = gr.Button("连接 Hologres", variant="primary")
con_state = gr.Textbox(label="连接信息: ")
connect_btn.click(fn=connect_holo, inputs=[emb_model, emb_dim, eas_url, eas_token, holo_host, holo_port, holo_database, holo_user, holo_pwd], outputs=con_state, api_name="connect_holo")
def change_ds_conn(radio):
return {holo_col: gr.update(visible=True)}
with gr.Tab("📃 上传"):
with gr.Row():
with gr.Column(scale=2):
chunk_size = gr.Textbox(label="\N{rocket} 块大小(将文档划分成的块的大小)",value='200')
chunk_overlap = gr.Textbox(label="\N{fire} 块重叠大小(相邻文档块彼此重叠的部分)",value='0')
with gr.Column(scale=8):
with gr.Tab("上传文件"):
upload_file = gr.File(label="上传知识库文件 (支持的文件类型: txt, md, doc, docx, pdf)",
file_types=['.txt', '.md', '.docx', '.pdf'], file_count="multiple")
connect_btn = gr.Button("上传", variant="primary")
state_hl_file = gr.Textbox(label="状态")
with gr.Tab("上传目录"):
upload_file_dir = gr.File(label="上传一个包含知识库文件的文件夹 (支持的文件类型: txt, md, docx, pdf)" , file_count="directory")
connect_dir_btn = gr.Button("上传", variant="primary")
state_hl_dir = gr.Textbox(label="状态")
def upload_knowledge(upload_file,chunk_size,chunk_overlap):
file_name = ''
for file in upload_file:
if file.name.lower().endswith(".txt") or file.name.lower().endswith(".md") or file.name.lower().endswith(".docx") or file.name.lower().endswith(".doc") or file.name.lower().endswith(".pdf"):
file_path = file.name
file_name += file.name.rsplit('/', 1)[-1] + ', '
service.upload_custom_knowledge(file_path,int(chunk_size),int(chunk_overlap))
return "成功上传 " + str(len(upload_file)) + " 个文件 [ " + file_name + "] ! \n \n 相关内容已成功编码并上传至向量数据库,您现在可以开始聊天了!"
def upload_knowledge_dir(upload_dir,chunk_size,chunk_overlap):
for file in upload_dir:
if file.name.lower().endswith(".txt") or file.name.lower().endswith(".md") or file.name.lower().endswith(".docx") or file.name.lower().endswith(".doc") or file.name.lower().endswith(".pdf"):
file_path = file.name
service.upload_custom_knowledge(file_path,chunk_size,chunk_overlap)
return "成功上传 " + str(len(upload_dir)) + " 个文件!"
connect_btn.click(fn=upload_knowledge, inputs=[upload_file,chunk_size,chunk_overlap], outputs=state_hl_file, api_name="upload_knowledge")
connect_dir_btn.click(fn=upload_knowledge_dir, inputs=[upload_file_dir,chunk_size,chunk_overlap], outputs=state_hl_dir, api_name="upload_knowledge_dir")
with gr.Tab("💬 聊天"):
with gr.Row():
with gr.Column(scale=2):
ds_radio = gr.Radio(
[ "向量数据库", "大语言模型", "向量数据库+大语言模型"], label="💬 选择聊天模式"
)
topk = gr.Textbox(label="查询最相关的k条语料",value='3')
with gr.Column():
prm_radio = gr.Radio(
[ "通用", "URL提取", "自定义"], label="\N{rocket} 请选择prompt模板"
)
prompt = gr.Textbox(label="Prompt", placeholder="在此处填入prompt模板,上下文和问题用{content}和{question}表示", lines=4)
def change_prompt_template(prm_radio):
if prm_radio == "通用":
return {prompt: gr.update(value="基于以下已知信息,简洁和专业的来回答用户的问题。如果无法从中得到答案,请说 \"根据已知信息无法回答该问题\" 或 \"没有提供足够的相关信息\",不允许在答案中添加编造成分,答案请使用中文。\n=====\n已知信息:\n{context}\n=====\n用户问题:\n{question}")}
elif prm_radio == "URL提取":
return {prompt: gr.update(value="你是一位智能小助手,请根据下面我所提供的相关知识,对我提出的问题进行回答。回答的内容必须包括其定义、特征、应用领域以及相关网页链接等等内容,同时务必满足下方所提的要求!\n=====\n 知识库相关知识如下:\n{context}\n=====\n 请根据上方所提供的知识库内容与要求,回答以下问题:\n {question}")}
elif prm_radio == "自定义":
return {prompt: gr.update(value="")}
prm_radio.change(fn=change_prompt_template, inputs=prm_radio, outputs=[prompt])
cur_tokens = gr.Textbox(label="\N{fire} 当前token总数")
with gr.Column(scale=8):
chatbot = gr.Chatbot(height=500)
msg = gr.Textbox(label="在此处提问")
with gr.Row():
submitBtn = gr.Button("提交", variant="primary")
summaryBtn = gr.Button("总结", variant="primary")
clear_his = gr.Button("清空对话", variant="secondary")
def respond(message, chat_history, ds_radio, topk, prm_radio, prompt):
summary_res = ""
if ds_radio == "向量数据库":
answer, lens = service.query_only_vectorstore(message,topk)
elif ds_radio == "大语言模型":
answer, lens, summary_res = service.query_only_llm(message)
else:
answer, lens, summary_res = service.query_retrieval_llm(message,topk, prm_radio, prompt)
bot_message = answer
chat_history.append((message, bot_message))
time.sleep(0.05)
return "", chat_history, str(lens) + "\n" + summary_res
def clear_hisoty(chat_history):
chat_history = []
service.langchain_chat_history = []
service.input_tokens = []
# chat_history.append(('Clear the chat history', bot_message))
time.sleep(0.05)
return chat_history, "0 \n 成功清空!"
def summary_hisoty(chat_history):
service.input_tokens = []
bot_message = service.checkout_history_and_summary(summary=True)
chat_history.append(('请对我们之前的对话内容进行总结。', bot_message))
tokens_len = service.sp.encode(service.input_tokens, out_type=str)
lens = sum(len(tl) for tl in tokens_len)
time.sleep(0.05)
return chat_history, str(lens) + "\n" + bot_message
submitBtn.click(respond, [msg, chatbot, ds_radio, topk, prm_radio, prompt], [msg, chatbot, cur_tokens])
clear_his.click(clear_hisoty,[chatbot],[chatbot, cur_tokens])
summaryBtn.click(summary_hisoty,[chatbot],[chatbot, cur_tokens])
footer = html("footer.html")
gr.HTML(footer, elem_id="footer")
return demo