in mds_plugin/genai.py [0:0]
def chat(prompt, **kwargs):
    """Process a chat request and return the generated answer.

    If no options are passed, an empty options dict is sent and the server
    is left to derive the request settings (presumably from the prompt).

    When the server reports HeatWave support, the request is executed via
    ``CALL sys.heatwave_chat(?)`` with the options passed in the
    ``@chat_options`` session variable. The prompt and the response are
    optionally translated according to ``options["language_options"]``.
    Otherwise the request is delegated to ``mockchat.chat``.

    Args:
        prompt (str): The question of the user.
        **kwargs: Additional options.

    Keyword Args:
        options (dict): The options that store information about the
            request. NOTE: this dict may be modified in place (e.g.
            "documents", "request_completed", "tables" and
            "language_options" are removed).
        session (object): The database session to use. Falls back to the
            global MySQL Shell session when not given.
        send_gui_message (object): The function to send a message to the
            GUI, called as ``send_gui_message("data", payload)``. Optional.

    Returns:
        A dict with the generated answer and the options used. On the
        HeatWave path this is the content of ``@chat_options`` after the
        call (or the last options seen, if it could not be read). On the
        fallback path it is whatever ``mockchat.chat`` returns.

    Raises:
        Exception: If no database session is available, or the server
            supports neither HeatWave nor local models.
    """
    # Renamed on import so the builtin globals() is not shadowed.
    from mysqlsh import globals as shell_globals

    session = kwargs.get("session")
    options = kwargs.get("options")
    send_gui_message = kwargs.get("send_gui_message")

    # Options are optional; without this the .get() calls below crash.
    if options is None:
        options = {}

    def notify(payload):
        """Send a "data" message to the GUI if a callback was provided."""
        if send_gui_message is not None:
            send_gui_message("data", payload)

    model_options = options.get("model_options") or {}
    model_language = model_options.get("language", "en")
    model_language_name = languages.GenerativeAILanguage(model_language).name

    # Clean up per-response fields (presumably left over from a previous answer)
    options.pop("documents", None)
    options.pop("request_completed", None)

    # "default" as model_id means: remove it and let the server pick the model
    if model_options.get("model_id", "") == "default":
        model_options.pop("model_id")
        options["model_options"] = model_options

    # Clear the table list unless lock_table_list is set to true
    if not options.get("lock_table_list", False):
        options.pop("tables", None)

    if not session:
        session = shell_globals.session
    if not session:
        raise Exception("No database session specified.")

    notify({"info": "Checking chat engine status ..."})
    status = get_status(session=session)
    if status.get("heatwave_support") is False and status.get("local_model_support") is False:
        raise Exception(
            "GenAI support is not available. Please connect to a HeatWave 9.0 instance or higher.")

    # Remove the language if the server does not support it
    if status.get("language_support") is False and "language" in model_options:
        model_options.pop("language")
        options["model_options"] = model_options

    if status.get("heatwave_support") is not True:
        # No HeatWave: delegate to mockchat, configuring a stored Cohere API
        # key if one exists. The original callback (possibly None) is passed on.
        from . import mockchat
        api_key_path = os.path.join(
            get_shell_user_dir(), "plugin_data", "mds_plugin", "cohere_api_key.txt")
        if os.path.exists(api_key_path):
            with open(api_key_path) as key_file:
                cohere_api_key = key_file.read().strip()
            mockchat.set_api_key(cohere_api_key)
        return mockchat.chat(prompt, options, session, send_gui_message)

    # language_options is consumed locally and not sent to the server
    lang_opts = options.pop("language_options", None) or {}
    language = lang_opts.get("language")
    needs_translation = language is not None and language != model_language_name

    # If a language has been selected for translation, translate the prompt
    if needs_translation and lang_opts.get("translate_user_prompt") is not False:
        notify({"info": f"Translating prompt from {language} to {model_language_name} ..."})
        prompt = translate_string(
            session, prompt,
            target_language=model_language_name,
            model_id=lang_opts.get("model_id"),
            source_language=language)

    notify({"info": "Generating answer ..."})
    session.run_sql("SET @chat_options = ?", [json.dumps(options)])
    res = session.run_sql("CALL sys.heatwave_chat(?)", [prompt])

    notify({"info": "Processing results ..."})
    # Walk every result set returned by the procedure. Result sets whose first
    # column is "chat_options" are forwarded to the GUI as they arrive. A final
    # "response" result set is ignored; the complete state is read from the
    # @chat_options session variable afterwards instead.
    while True:
        rows = res.fetch_all()
        if rows:
            cols = res.get_column_names()
            if len(cols) > 0 and cols[0] == "chat_options" and len(rows[0]) > 0:
                options = json.loads(rows[0][0])
                notify(options)
        if not res.next_result():
            break

    res = session.run_sql("SELECT @chat_options")
    rows = res.fetch_all()
    # Guard against a NULL @chat_options, which json.loads cannot parse
    if rows and rows[0][0] is not None:
        options = json.loads(rows[0][0])

        response = options.get("response")
        if needs_translation and response and \
                lang_opts.get("translate_response") is not False:
            notify({"info": f"Translating response from {model_language_name} to {language} ..."})
            options["response"] = translate_string(
                session, response,
                target_language=language,
                model_id=lang_opts.get("model_id"),
                source_language=model_language_name)

        notify(options)

    return options