# mds_plugin/mockchat.py
# Copyright (c) 2024, Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0,
# as published by the Free Software Foundation.
#
# This program is also distributed with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms, as
# designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an additional
# permission to link the program and your derivative works with the
# separately licensed software that they have included with MySQL.
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See
# the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
import numpy as np
import pickle
import json
from typing import Callable, Tuple
from sentence_transformers import SentenceTransformer
from transformers import pipeline
import cohere
import threading
import mysqlsh
# Message-type tags used by Chat.send_data() to label payloads forwarded
# to the GUI callback.
DATA_TYPE_INFO = "info"  # status/progress text
DATA_TYPE_TOKEN = "token"  # a single streamed generation token
DATA_TYPE_OPTIONS = "options"  # structured options/metadata payload (already a dict)
def interactive_mode_set():
    """Check whether the shell is running interactively on the main thread.

    Returns:
        True if the MySQL Shell is running in interactive (wizard) mode
        and the caller is the main thread, False otherwise.
    """
    if mysqlsh.globals.shell.options.useWizards:
        # FIX: comparing ct.__class__.__name__ against the private name
        # '_MainThread' is fragile; threading.main_thread() is the supported
        # way to identify the main thread.
        return threading.current_thread() is threading.main_thread()
    return False
# Lazily-initialized module-level singletons shared across chat() calls.
g_cohere_api_key = None  # Cohere API key, set via set_api_key()
g_chat = None  # cached Chat instance; reset when the API key changes
g_embedding_model = None  # cached SentenceTransformer embedding model
# template types:
# - find documents
# - with filter
# - with join
# - apply task to each
# - summarize
# - q&a
#
# - find documents
#
#
# Practical questions:
# - max size of context?
# - limit size of context for queries?
# - huge context (even if accidental) -> huge $$$
#
#
class DBInterface:
    """Thin wrapper that holds a reference to a database session."""

    def __init__(self, session) -> None:
        # Session object used for running SQL statements.
        self.session = session
class Template:
    """Abstract base describing how to search context and generate answers.

    Subclasses define the SQL used to build a temporary context table, how
    to embed and match query text, and how to call a generation backend.
    """

    # Tables to scan for document segments (set by the caller).
    source_tables = []
    # Maximum embedding distance for a segment to count as relevant.
    maximum_distance: float
    # Default number of top matches to retrieve.
    default_limit: int
    # Fully-qualified name of the temporary context table.
    context_table = "mysqlsh.context"
    # Extra, template-specific parameters.
    template_params = []

    def __init__(self) -> None:
        pass

    # BUG FIX (all methods below): the original used `raise NotImplemented()`.
    # NotImplemented is a non-callable sentinel, so those statements raised
    # TypeError instead of the intended NotImplementedError.

    def _query_topk(self, session, k, max_dist):
        """Return the k closest context rows (at most max_dist away)."""
        raise NotImplementedError()

    def make_create_context_table(self) -> Tuple[str, list]:
        """Return the DDL statement that creates the temporary context table."""
        raise NotImplementedError()

    def get_search_queries(self, query: str, options: dict) -> list:
        """Decompose a user prompt into individual search queries."""
        raise NotImplementedError()

    def make_select(self, text: str, params: dict) -> Tuple[str, list, str]:
        """Return (query, args, query_embedding) used to scan source tables."""
        raise NotImplementedError()

    def make_inserter(self, query_emb) -> Callable:
        """Return a callable that inserts one scanned row into the context table."""
        raise NotImplementedError()

    def make_generator(self, text: str, options: dict) -> Callable:
        """Return a callable that generates an answer from the context."""
        raise NotImplementedError()

    def format_result(self, result) -> str:
        """Convert a backend-specific result object to plain text (identity here)."""
        return result
class GenericDocumentTableTemplate(Template):
    """Template that scans generic document tables using local embeddings."""

    def __init__(self) -> None:
        super().__init__()
        self.query = ""
        self.template_params = [""]
        # Columns every source document table is expected to expose.
        self.table_columns = [
            "document_name",
            "segment",
            "metadata",
            "segment_embedding",
        ]
        # Shared SentenceTransformer used to embed the query text.
        self._model = g_embedding_model

    def make_create_context_table(self):
        """DDL for the temporary table that accumulates search hits."""
        return (
            f"create temporary table {self.context_table} "
            "(id varchar(256), dist double, metadata json, segment longtext)"
        )

    def _query_topk(self, session, k, max_dist):
        """Fetch the k nearest segments from the context table (max_dist unused here)."""
        sql = f"select id, dist, segment from {self.context_table} order by dist asc limit ?"
        return session.run_sql(sql, [k])

    def make_select(self, text: str, params: dict) -> Tuple[str, list, str]:
        """Build the UNION scan over all source tables plus the query embedding."""
        columns = ", ".join(self.table_columns)
        selects = []
        # NOTE(review): tables are accessed via attributes here, but
        # Chat._scan_tables() produces plain dicts — confirm callers pass
        # objects with schema_name/table_name attributes.
        for table in self.source_tables:
            selects.append(
                f"""SELECT {columns} FROM `{table.schema_name}`.`{table.table_name}`"""
            )
        query = " UNION ".join(selects)
        # NOTE: in the real version, query_emb should be kept as a uservar and
        # doesn't need to be fetched and passed around
        return query, [], self._model.encode(text)

    def make_inserter(self, query_emb):
        """Return a closure that inserts one scanned row with its L2 distance."""
        insert_sql = (
            f"""INSERT INTO {self.context_table} (id, dist, segment) VALUES (?, ?, ?)"""
        )

        def insert_one(session, row):
            # NOTE(review): pickle.loads on database column content — only
            # safe if the embeddings column is fully trusted.
            embedding = pickle.loads(row[-1])
            dist = np.linalg.norm(query_emb - embedding)
            session.run_sql(insert_sql, [row[0], float(dist), row[1]])

        return insert_one
class CohereTemplate(GenericDocumentTableTemplate):
    """Document-table template that uses the Cohere chat API for generation."""

    def __init__(self, api_key):
        super().__init__()
        # Cohere API client authenticated with the caller-supplied key.
        self.co = cohere.Client(api_key)

    def get_search_queries(self, query: str, options: dict):
        """Ask Cohere to decompose the prompt into individual search queries."""
        preamble = options.get("preamble") if options is not None else ""
        response = self.co.chat(
            message=query,
            model="command",
            preamble_override=preamble,
            search_queries_only=True,
        )
        return response.search_queries

    def make_generator(self, query: str):
        """Return a closure that streams a Cohere chat completion for `query`."""

        def generate(session, hint: str, context: list, options: dict):
            preamble = options.get("preamble") if options is not None else ""
            return self.co.chat(
                message=query,
                model="command",
                documents=context,
                preamble_override=preamble,
                stream=True,
            )

        return generate

    def format_result(self, result) -> str:
        """Extract the plain-text answer from a Cohere response object."""
        return result.text
class ContextManager:
    """Builds and queries the temporary context table for a chat turn.

    Args:
        session: shell session used to run SQL.
        template: Template instance describing the search SQL and embeddings.
    """

    def __init__(self, session, template: Template):
        self.session = session
        self.template = template

    def __build_context_table(self, query: str, args: list, query_emb):
        """Create the context table and fill it with distance-scored rows."""
        self.session.run_sql(self.template.make_create_context_table())
        res = self.session.run_sql(query, args).fetch_all()
        inserter = self.template.make_inserter(query_emb)
        for row in res:
            inserter(self.session, row)

    def search(self, text: str, params: dict = None):
        """Search all source tables for segments relevant to `text`.

        Returns:
            The result of the template's top-k query over the context table.
        """
        # BUG FIX: the original declared `params: dict = {}` — a mutable
        # default argument shared across calls. Use None and create a fresh
        # dict per invocation.
        if params is None:
            params = {}
        self.reset()
        query, args, query_emb = self.template.make_select(text, params)
        self.__build_context_table(query, args, query_emb)
        return self.template._query_topk(
            self.session, self.template.default_limit, self.template.maximum_distance
        )

    def refine(self, text: str):
        """Placeholder for iterative refinement of the context (not implemented)."""
        pass

    def reset(self):
        """Drop any previous context table and ensure its schema exists."""
        self.session.run_sql(
            f"drop table if exists {self.template.context_table}")
        self.session.run_sql(
            f'create schema if not exists {self.template.context_table.split(".")[0]}',
        )
class TaskManager:
    """Runs the answer-generation step of a chat turn through the template."""

    def __init__(self, session, template) -> None:
        self.session = session
        self.template = template
        # Reserved for caching the query embedding; not used yet.
        self._query_embedding = None

    def generate(self, query: str, hint: str, context: list, options: dict):
        """Build the template's generator for `query` and invoke it."""
        generator = self.template.make_generator(query)
        return generator(self.session, hint, context, options)

    def format(self, result):
        """Convert a backend response object into display text."""
        return self.template.format_result(result)
class Chat:
    """Orchestrates one retrieval-augmented chat turn.

    Searches the configured vector tables for context relevant to the prompt
    (via ContextManager), then asks the template's generation backend for an
    answer (via TaskManager), optionally streaming tokens and status updates
    to a GUI through `send_gui_message`.
    """

    # Runtime flags; refreshed from `options` on every run() by _apply_options().
    debug = 0
    report_status = 0
    streaming = 0
    re_run = 0

    def __init__(self, session, templ: Template, send_gui_message) -> None:
        # assert session.uri.startswith(
        #     "mysql://"
        # ), "Connection must use classic protocol"
        self.templ = templ
        self.session = session
        # Optional callback used to push data to a GUI frontend; when None
        # (or in interactive shell mode) output is printed instead.
        self.send_gui_message = send_gui_message
        self.context = ContextManager(session, templ)
        self.task = TaskManager(session, templ)

    def dump(self, query: str):
        """Debug helper: run only the context search for `query` and return it."""
        return self.context.search(query)

    def send_data(self, data, type, end="\n"):
        """Deliver `data` to the user.

        In interactive mode (or with no GUI callback) the data is printed;
        otherwise it is wrapped according to `type` and forwarded to the GUI.

        Args:
            data: payload (string, or dict for DATA_TYPE_OPTIONS).
            type: one of DATA_TYPE_INFO / DATA_TYPE_TOKEN / DATA_TYPE_OPTIONS.
                (NOTE: parameter shadows the `type` builtin.)
            end: line terminator used when printing (e.g. "" for tokens).
        """
        if interactive_mode_set() or self.send_gui_message is None:
            print(data, end=end)
        else:
            if type == DATA_TYPE_INFO:
                data_to_send = {"info": data}
            elif type == DATA_TYPE_TOKEN:
                data_to_send = {"token": data}
            else:
                # DATA_TYPE_OPTIONS payloads are already dicts; send as-is.
                data_to_send = data
            self.send_gui_message("data", data_to_send)

    def run(self, query: str, options: dict):
        """Execute one full chat turn for `query`, honoring `options`.

        Returns:
            The response dict produced by run_().
        """
        self._apply_options(options)
        if self.debug:
            print("Query: ", query, f"({self.templ.__class__.__name__})")
        r = self.run_(query, options)
        if self.debug:
            print("Result:", r)
            print()
        return r

    def run_(self, query: str, options: dict):
        """Core chat pipeline: decompose the prompt into search queries,
        gather matching document segments, then generate an answer.

        Mutates `options` in place ("documents", "request_completed", plus
        the response fields added by _make_response) and returns
        {"data": <options>}.
        """
        if self.debug or self.report_status:
            self.send_data("Decomposing prompt ...", DATA_TYPE_INFO)
        search_queries = self.templ.get_search_queries(query, options)
        if self.debug:
            print("\tqueries=", search_queries)
        # Fall back to the raw prompt if decomposition produced nothing.
        if not search_queries:
            queries = [query]
        else:
            queries = [q["text"] for q in search_queries]
        context = []
        for q in queries:
            if self.debug:
                print("* searching matches for:", q)
            context += self.context.search(q).fetch_all()
        # Context rows are (id, dist, segment); only id and segment are
        # passed to the generation backend.
        docs = [{"id": row[0], "snippet": row[2]} for row in context]
        if self.debug:
            print("* context:")
            for row in context:
                print("\t", row[0], row[1], row[2]
                      [:80].replace("\n", "\n\t") + "...")
        options["documents"] = [
            {"id": row[0], "title": row[0], "snippet": row[2][:80], "pinned": False} for row in context]
        # Send the documents as soon as available, when streaming is enabled
        if self.streaming:
            if self.debug or self.report_status:
                self.send_data({"documents": options["documents"], "info": "Generating answer ..."}, DATA_TYPE_OPTIONS)
            else:
                self.send_data({"documents": options["documents"]}, DATA_TYPE_OPTIONS)
        result = ""
        # NOTE(review): skip_generate is only assigned in _apply_options();
        # calling run_() directly without run() would raise AttributeError.
        if not self.skip_generate:
            response = self.task.generate(query, "", docs, options)
            # Forward each streamed token as it arrives.
            # NOTE(review): the accumulated `result` string is never read
            # after this loop.
            for event in response:
                if event.event_type == "text-generation":
                    result += event.text
                    self.send_data(event.text, DATA_TYPE_TOKEN, end="")
            if self.streaming:
                # Tokens were already streamed; blank the full text so it is
                # not sent again in the final response.
                response.text = ""
        else:
            response = None
        options["request_completed"] = True
        return { "data": self._make_response(response, options) }

    def _apply_options(self, options: dict):
        """Copy run-time settings from `options` onto the template and self.

        Also resolves which vector tables to search (scanning the instance
        if the caller provided none) and, when streaming, reports them.
        """
        self.templ.default_limit = options.get(
            "maximum_document_segment_count", 3)
        self.templ.maximum_distance = options.get("maximum_distance", 0.3)
        options["tables"] = options.get("tables", None) or self._scan_tables()
        self.templ.source_tables = options["tables"]
        # Send the tables as soon as available, when streaming is used
        # NOTE(review): self.streaming still holds the previous run's value
        # here; the "stream" option is only applied a few lines below —
        # confirm this ordering is intended.
        if self.streaming:
            self.send_data({
                "tables": options["tables"],
                "info": f'{"No " if len(options["tables"]) == 0 else ""}Vector tables found.'
            }, DATA_TYPE_OPTIONS)
        self.debug = options.get("debug", False)
        self.report_status = options.get("report_status", False)
        self.skip_generate = options.get("skip_generate", False)
        self.streaming = options.get("stream", False)
        self.re_run = options.get("re_run", False)

    def _make_response(self, result, options: dict):
        """Attach the formatted answer and Cohere metadata to `options`.

        Args:
            result: the backend response object, or None when generation
                was skipped; all response fields are set to None in that case.
        """
        if result:
            options["response"] = self.task.format(result)
            options["cohere_meta"] = result.meta
            options["cohere_citations"] = result.citations
        else:
            options["response"] = None
            options["cohere_meta"] = None
            options["cohere_citations"] = None
        return options

    def _scan_tables(self):
        """Find candidate vector tables: every table whose CREATE_OPTIONS
        mention SECONDARY_ENGINE (i.e. loaded into HeatWave/LakeHouse).

        Returns:
            A list of {"schema_name", "table_name", "vector_embeddings"} dicts.

        Raises:
            Exception: if no matching tables exist.
        """
        rows = self.session.run_sql(
            """
            SELECT concat(sys.quote_identifier(table_schema), '.', sys.quote_identifier(table_name)),
                table_schema as schema_name,
                table_name as table_name,
                create_options
            FROM information_schema.tables WHERE CREATE_OPTIONS LIKE '%SECONDARY_ENGINE%'
            """
        ).fetch_all()
        tables = []
        for row in rows:
            tables.append({
                "schema_name": row[1],
                "table_name": row[2],
                "vector_embeddings": True
            })
        if not tables:
            raise Exception("No LakeHouse tables found")
        if self.debug:
            print("* tables to be scanned:", tables)
        return tables
def set_api_key(key: str):
    """Store the Cohere API key and invalidate cached chat state.

    Also eagerly loads the sentence-transformer embedding model on first
    use so the initial chat() call does not pay the load cost.
    """
    global g_cohere_api_key
    global g_chat
    global g_embedding_model

    g_cohere_api_key = key
    # Drop the cached Chat so the next call picks up the new key.
    g_chat = None
    if not g_embedding_model:
        g_embedding_model = SentenceTransformer(
            "sentence-transformers/all-MiniLM-L12-v2"
        )
def chat(prompt, options: dict = None, session=None, send_gui_message=None):
    """Run one retrieval-augmented chat turn against the LakeHouse tables.

    Args:
        prompt: the user's natural-language question.
        options: optional dict of chat options (debug, stream, tables, ...);
            it is mutated with result data and returned inside the response.
        session: shell session used for SQL; required on the first call,
            when the Chat singleton is created.
        send_gui_message: optional callback for streaming data to a GUI.

    Returns:
        A dict with a "data" key holding the response/options payload.
    """
    global g_chat
    global g_cohere_api_key
    global g_embedding_model
    # BUG FIX: the original declared `options: dict = {}` — a mutable default
    # argument. The `options or {}` at the call site masked the sharing, but
    # None is the safe, idiomatic default.
    if not g_embedding_model:
        g_embedding_model = SentenceTransformer(
            "sentence-transformers/all-MiniLM-L12-v2"
        )
    if not g_chat:
        g_chat = Chat(session, CohereTemplate(
            g_cohere_api_key), send_gui_message)
    return g_chat.run(prompt, options or {})
# import mysqlsh
# options = {"debug": 1}
# print(chat("how to disable redo log", options, session=mysqlsh.globals.session))