# model-gallery/deploy/mllm/webui_client.py
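"""Gradio web UI client for a multimodal LLM (MLLM) service.

The client talks to a deployed model (e.g. an Alibaba Cloud PAI-EAS endpoint)
through its OpenAI-compatible API and supports text, image, and video inputs.

Usage:
    python webui_client.py --eas_endpoint <ENDPOINT_URL> --eas_token <TOKEN>
"""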
import base64
import copy
import re
from argparse import ArgumentParser

import gradio as gr


def _get_args():
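    """Parse command-line arguments for the demo client."""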
parser = ArgumentParser()
parser.add_argument("--eas_endpoint", type=str, required=True)
parser.add_argument("--eas_token", type=str, required=True)
parser.add_argument(
"--share",
action="store_true",
default=False,
help="Create a publicly shareable link for the interface.",
)
parser.add_argument(
"--inbrowser",
action="store_true",
default=False,
help="Automatically launch the interface in a new tab on the default browser.",
)
parser.add_argument(
"--server-port", type=int, default=7860, help="Demo server port."
)
parser.add_argument(
"--server-name", type=str, default="127.0.0.1", help="Demo server name."
)
args = parser.parse_args()
return args


def _parse_text(text):
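    """Convert model output into chatbot-safe HTML.

    Fenced code blocks become <pre><code> sections, and text inside an open
    code block is HTML-escaped so it renders literally.
    """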
lines = text.split("\n")
lines = [line for line in lines if line != ""]
count = 0
for i, line in enumerate(lines):
if "```" in line:
count += 1
items = line.split("`")
if count % 2 == 1:
lines[i] = f'<pre><code class="language-{items[-1]}">'
else:
lines[i] = "<br></code></pre>"
else:
if i > 0:
if count % 2 == 1:
line = line.replace("`", r"\`")
                    line = line.replace("<", "&lt;")
                    line = line.replace(">", "&gt;")
                    line = line.replace(" ", "&nbsp;")
                    line = line.replace("*", "&ast;")
                    line = line.replace("_", "&lowbar;")
                    line = line.replace("-", "&#45;")
                    line = line.replace(".", "&#46;")
                    line = line.replace("!", "&#33;")
                    line = line.replace("(", "&#40;")
                    line = line.replace(")", "&#41;")
                    line = line.replace("$", "&#36;")
lines[i] = "<br>" + line
text = "".join(lines)
return text


def _remove_image_special(text):
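    """Strip image-grounding markup (<ref>/<box> spans) from model output."""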
text = text.replace("<ref>", "").replace("</ref>", "")
return re.sub(r"<box>.*?(</box>|$)", "", text)


def is_video_file(filename):
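    """Return True if the filename carries a common video extension."""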
video_extensions = [
".mp4",
".avi",
".mkv",
".mov",
".wmv",
".flv",
".webm",
".mpeg",
]
return any(filename.lower().endswith(ext) for ext in video_extensions)


def encode_base64_content_from_file(file_path):
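    """Read a file and return its contents as a base64-encoded string."""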
with open(file_path, "rb") as file:
encoded_string = base64.b64encode(file.read()).decode("utf-8")
return encoded_string


def transform_messages(original_messages):
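    """Convert internal message items to the OpenAI chat-completions content
    schema, e.g. {"text": ...} -> {"type": "text", "text": ...}."""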
transformed_messages = []
for message in original_messages:
new_content = []
for item in message["content"]:
if "image_url" in item:
new_item = {
"type": "image_url",
"image_url": {"url": item["image_url"]},
}
elif "text" in item:
new_item = {"type": "text", "text": item["text"]}
elif "video_url" in item:
new_item = {
"type": "video_url",
"video_url": {"url": item["video_url"]},
}
else:
continue
new_content.append(new_item)
new_message = {"role": message["role"], "content": new_content}
transformed_messages.append(new_message)
return transformed_messages


def _launch_ui(model, client, args):
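    """Build the Gradio Blocks interface and launch the demo server."""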

    def call_local_model(messages, max_tokens, temperature, timeout, use_stream):
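        """Call the deployed model through the OpenAI-compatible API and yield
        the response text (incrementally when streaming)."""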
try:
messages = transform_messages(messages)
print("Messages:", messages)
print(
f"Arguments: max_tokens={max_tokens}, temperature={temperature}, timeout={timeout}, stream={use_stream}"
)
gen = client.chat.completions.create(
messages=messages,
model=model,
max_tokens=max_tokens,
temperature=temperature,
timeout=timeout,
stream=use_stream,
)
if use_stream:
generated_text = ""
                for chunk in gen:
                    # The final streamed chunk may carry no delta content;
                    # guard against concatenating None.
                    if chunk.choices and chunk.choices[0].delta.content is not None:
                        generated_text += chunk.choices[0].delta.content
                    yield generated_text
else:
generated_text = gen.choices[0].message.content
yield generated_text
except Exception as e:
print(e)
            raise gr.Error(str(e))

    def create_predict_fn():
def predict(
_chatbot, task_history, max_tokens, temperature, timeout, use_stream
):
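            """Send the accumulated history to the model and stream the reply
            into the chatbot."""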
            chat_query = _chatbot[-1][0]
            # task_history entries are (type, query, response) triples.
            query = task_history[-1][1]
if len(chat_query) == 0:
_chatbot.pop()
task_history.pop()
return _chatbot
print("User: " + _parse_text(query))
history_cp = copy.deepcopy(task_history)
full_response = ""
messages = []
content = []
            for cls, q, a in history_cp:
                if cls == "image":
                    content.append({"image_url": f"data:image/jpeg;base64,{q}"})
                elif cls == "video":
                    content.append({"video_url": f"data:video/mp4;base64,{q}"})
                else:
                    # A text entry closes one user turn: flush the accumulated
                    # media items together with the query, then reset.
                    content.append({"text": q})
                    messages.append({"role": "user", "content": content})
                    messages.append({"role": "assistant", "content": [{"text": a}]})
                    content = []
            # Drop the placeholder assistant message of the pending turn.
            messages.pop()
for response in call_local_model(
messages, max_tokens, temperature, timeout, use_stream
):
_chatbot[-1] = (
_parse_text(chat_query),
_remove_image_special(_parse_text(response)),
)
yield _chatbot
full_response = _parse_text(response)
task_history[-1] = ("text", query, full_response)
print("VL-Chat: " + _parse_text(full_response))
yield _chatbot
return predict

    def create_regenerate_fn():
def regenerate(
_chatbot, task_history, max_tokens, temperature, timeout, use_stream
):
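            """Drop the last response and re-run predict on the same query."""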
print("Regenerating...")
if not task_history:
return _chatbot
            item = task_history[-1]
            # Entries are (type, query, response) triples; regenerate only if
            # the last turn already has a response.
            if item[2] is None:
                return _chatbot
            task_history[-1] = (item[0], item[1], None)
            chatbot_item = _chatbot.pop(-1)
            if chatbot_item[0] is None:
                _chatbot[-1] = (_chatbot[-1][0], None)
            else:
                _chatbot.append((chatbot_item[0], None))
_chatbot_gen = predict(
_chatbot, task_history, max_tokens, temperature, timeout, use_stream
)
for _chatbot in _chatbot_gen:
yield _chatbot
return regenerate

    predict = create_predict_fn()
regenerate = create_regenerate_fn()

    def add_text(history, task_history, text):
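        """Append a text query to both the visible and the task history."""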
task_text = text
history = history if history is not None else []
task_history = task_history if task_history is not None else []
history = history + [(_parse_text(text), None)]
task_history = task_history + [("text", task_text, None)]
return history, task_history

    def add_file(history, task_history, file):
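        """Base64-encode an uploaded image/video and record it in both histories."""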
history = history if history is not None else []
task_history = task_history if task_history is not None else []
base64_encoded_content = encode_base64_content_from_file(file.name)
history = history + [((file.name,), None)]
if is_video_file(file.name):
task_history = task_history + [("video", base64_encoded_content, None)]
else:
task_history = task_history + [("image", base64_encoded_content, None)]
return history, task_history

    def reset_user_input():
return gr.update(value="")

    def reset_state(task_history):
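        """Clear the task history and empty the chatbot display."""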
print("Clear history.")
task_history.clear()
return []

    with gr.Blocks() as demo:
        gr.Markdown("""<center><font size=8>MLLM-WebUI</font></center>""")
chatbot = gr.Chatbot(elem_classes="control-height", height=480)
query = gr.Textbox(lines=2, label="Input")
task_history = gr.State([])
with gr.Row():
max_tokens = gr.Slider(
minimum=10, maximum=10240, step=10, label="max_tokens", value=1024
)
temperature = gr.Slider(
minimum=0.0, maximum=2.0, step=0.01, label="temperature", value=0.0
)
timeout = gr.Slider(
minimum=0, maximum=600, step=10, label="timeout(seconds)", value=120
)
use_stream_chat = gr.Checkbox(label="use_stream_chat", value=True)
with gr.Row():
            addfile_btn = gr.UploadButton(
                "📁 Upload (上传文件)", file_types=["image", "video"]
            )
            submit_btn = gr.Button("🚀 Submit (发送)")
            regen_btn = gr.Button("🤔 Regenerate (重试)")
            empty_bin = gr.Button("🧹 Clear History (清除历史)")
submit_btn.click(
add_text, [chatbot, task_history, query], [chatbot, task_history]
).then(
predict,
[chatbot, task_history, max_tokens, temperature, timeout, use_stream_chat],
[chatbot],
show_progress=True,
)
submit_btn.click(reset_user_input, [], [query])
empty_bin.click(reset_state, [task_history], [chatbot], show_progress=True)
regen_btn.click(
regenerate,
[chatbot, task_history, max_tokens, temperature, timeout, use_stream_chat],
[chatbot],
show_progress=True,
)
addfile_btn.upload(
add_file,
[chatbot, task_history, addfile_btn],
[chatbot, task_history],
show_progress=True,
)
demo.queue().launch(
share=args.share,
inbrowser=args.inbrowser,
server_port=args.server_port,
server_name=args.server_name,
)


def main():
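    """Create an OpenAI client against the EAS endpoint and launch the web UI."""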
from openai import OpenAI
args = _get_args()
openai_api_key = args.eas_token
if not args.eas_endpoint.endswith("/"):
args.eas_endpoint += "/"
openai_api_base = f"{args.eas_endpoint}v1"
client = OpenAI(api_key=openai_api_key, base_url=openai_api_base)
models = client.models.list()
model = models.data[0].id
_launch_ui(model, client, args)


if __name__ == "__main__":
main()