def chat()

in mds_plugin/genai.py [0:0]


def chat(prompt, **kwargs):
    """Processes a chat request and return a generated answer

    If no options are passed, they are generated from the prompt

    Args:
        prompt (str): The user's question
        **kwargs: Additional options

    Keyword Args:
        options (dict): The options that store information about the request.
        session (object): The database session to use.
        send_gui_message (object): The function used to send a message to the GUI.

    Returns:
        A dict with the generated answer and the options used. When HeatWave
        handles the request, the results are sent via send_gui_message instead
        and None is returned.
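
    Example:
        An illustrative call; the option values are placeholders and an open
        global Shell session is assumed:

            chat("Which tables contain sales data?",
                 options={"model_options": {"language": "en"}},
                 send_gui_message=lambda msg_type, msg: print(msg_type, msg))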
    """
    from mysqlsh import globals

    session = kwargs.get("session")
    options = kwargs.get("options", None)
    send_gui_message = kwargs.get("send_gui_message")

    model_options = options.get("model_options", {})
    model_language = model_options.get("language", "en")
    model_language_name = languages.GenerativeAILanguage(model_language).name

    # Clean up options
    options.pop("documents", None)
    options.pop("request_completed", None)

    # If the user selected default as the model_id, remove it from the model_options and let the server pick
    if model_options.get("model_id", "") == "default":
        model_options.pop("model_id")
        options["model_options"] = model_options

    # Clear table list if lock_table_list is not set to true
    if options.get("lock_table_list", False) == False:
        options.pop("tables", None)

    if not session:
        session = globals.session
        if not session:
            raise Exception("No database session specified.")

    if send_gui_message is None:
        send_gui_message = lambda msg_type, msg: None  # no-op, later calls need no None checks

    send_gui_message("data", {"info": "Checking chat engine status ..."})

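    # The engine status determines whether the HeatWave GenAI routines or the
    # local mock chat backend is used to handle the request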
    status = get_status(session=session)
    if status.get("heatwave_support") is False and status.get("local_model_support") is False:
        raise Exception("GenAI support is not available. Please connect to a HeatWave 9.0 instance or higher.")

    # Remove language if not supported
    if status.get("language_support") is False and "language" in model_options:
        model_options.pop("language")
        options["model_options"] = model_options

    if status.get("heatwave_support") is True:
        lang_opts = options.pop("language_options", {})
        language = lang_opts.get("language")

        # If a language has been selected for translation, do the translation
        if language is not None and lang_opts.get("translate_user_prompt") is not False and \
            language != model_language_name:
            send_gui_message(
                "data", {"info": f"Translating prompt from {language} to {model_language_name} ..."})

            # Translate the prompt
            prompt = translate_string(
                session, prompt,
                target_language=model_language_name,
                model_id=lang_opts.get("model_id"),
                source_language=language)

        send_gui_message("data", {"info": "Generating answer ..."})

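        # sys.heatwave_chat reads its configuration from the @chat_options
        # session variable and updates it with the generated response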
        session.run_sql("SET @chat_options = ?", [json.dumps(options)])

        res = session.run_sql("CALL sys.heatwave_chat(?)", [prompt])

        send_gui_message("data", {"info": "Processing results ..."})
        next_result = True
        while next_result:
            rows = res.fetch_all()

            if len(rows) == 0:
                next_result = res.next_result()
                continue

            cols = res.get_column_names()

            # The first column of a result set is either named "chat_options" ...
            if len(cols) > 0 and cols[0] == "chat_options" and len(rows[0]) > 0:
                options = json.loads(rows[0][0])
                send_gui_message("data", options)
            # ... or "response" for the final response that contains all tokens at once.
            # Note: For now the final response is ignored, since the @chat_options
            # session variable is fetched below instead.
            # elif len(cols) > 0 and cols[0] == "response" and len(rows[0]) > 0:
            #     options = { "token": json.loads(rows[0][0]) }
            #     send_gui_message("data", options)

            next_result = res.next_result()

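        # Fetch the final chat state (generated response, updated context, etc.)
        # from the @chat_options session variable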
        res = session.run_sql("SELECT @chat_options")
        rows = res.fetch_all()
        if len(rows) > 0:
            options = json.loads(rows[0][0])

            if language is not None and lang_opts.get("translate_response") is not False and \
                language != model_language_name:
                send_gui_message(
                    "data", {"info": f"Translating response from {model_language_name} to {language} ..."})

                # Translate the response
                response = translate_string(
                    session, options.get("response"),
                    target_language=language,
                    model_id=lang_opts.get("model_id"),
                    source_language=model_language_name)
                options["response"] = response

            send_gui_message("data", options)

    else:
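        # No HeatWave GenAI support available, fall back to the bundled mock chat backend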
        from . import mockchat

        api_key_path = os.path.join(
            get_shell_user_dir(), "plugin_data", "mds_plugin", "cohere_api_key.txt")
        if os.path.exists(api_key_path):
            # Read the stored Cohere API key and pass it to the mock backend
            with open(api_key_path) as key_file:
                mockchat.set_api_key(key_file.read().strip())

        return mockchat.chat(prompt, options, session, send_gui_message)