msm_plugin/lib/core.py (500 lines of code) (raw):
# Copyright (c) 2025, Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0,
# as published by the Free Software Foundation.
#
# This program is designed to work with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms, as
# designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an additional
# permission to link the program and your derivative works with the
# separately licensed software that they have either included with
# the program or referenced in the documentation.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See
# the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Sub-Module for core functions"""
# cSpell:ignore mysqlsh, msm
import datetime
from enum import IntEnum
import json
import os
import pathlib
import re
import threading
import traceback
import mysqlsh
import sys
from contextlib import contextmanager
# User-facing message for a failed attempt to take the schema metadata lock.
SCHEMA_METADATA_LOCK_ERROR = (
    "Failed to acquire schema metadata lock. Please ensure no other "
    "metadata update is running, then try again."
)
def get_msm_plugin_data_path() -> str:
    """Return the absolute path of the MSM plugin data folder.

    The folder is ``plugin_data/msm_plugin`` below the MySQL Shell user
    directory. It is created, including missing parents, when absent.
    """
    shell_user_dir = mysqlsh.plugin_manager.general.get_shell_user_dir(
        'plugin_data', 'msm_plugin')
    data_path = os.path.abspath(shell_user_dir)
    pathlib.Path(data_path).mkdir(parents=True, exist_ok=True)
    return data_path
def get_msm_schema_update_log_path() -> str:
    """Return the path of ``msm_schema_update_log.txt`` in the plugin data folder."""
    data_dir = get_msm_plugin_data_path()
    return os.path.join(data_dir, 'msm_schema_update_log.txt')
def write_to_msm_schema_update_log(type, message):
    """Append one timestamped line to the MSM schema update log.

    Lines have the form ``<local timestamp> - <type> - <message>``.

    Args:
        type (str): Entry category such as "INFO" or "ERROR". The name
            shadows the builtin but is kept for keyword-argument callers.
        message (str): Text to log; may contain SQL statements and database
            error messages.
    """
    # Explicit UTF-8: SQL text and server error messages may contain
    # non-ASCII characters, which the locale-dependent default encoding
    # (e.g. cp1252) cannot always encode. "a" creates the file if needed.
    with open(get_msm_schema_update_log_path(), "a", encoding="utf-8") as file:
        file.write(f"{datetime.datetime.now()} - {type} - {message}\n")
class ConfigFile:
    """Shared accessor for the MSM plugin's ``config.json`` settings.

    ``ConfigFile()`` always returns the same instance. ``__init__`` still runs
    on every call and re-reads the file, so each call reflects the current
    file content and drops unsaved in-memory changes.
    """

    def __new__(cls):
        # Create the shared instance on first use and cache it on the class.
        if not hasattr(cls, 'instance'):
            cls.instance = super(ConfigFile, cls).__new__(cls)
        return cls.instance

    def __init__(self) -> None:
        self._settings = {}
        self._filename = os.path.join(
            get_msm_plugin_data_path(), "config.json")
        try:
            with open(self._filename, "r") as f:
                loaded = json.load(f)
        except (OSError, ValueError):
            # Best effort: a missing, unreadable or malformed file (JSON
            # decode errors are ValueErrors) leaves the settings empty.
            return
        # Only a JSON object can serve as the settings mapping.
        if isinstance(loaded, dict):
            self._settings = loaded

    def store(self):
        """Write the current settings to ``config.json``.

        bytes values are stored as "0x"-prefixed hex strings (see _serialize).

        Raises:
            TypeError: If a setting cannot be serialized to JSON. In that
                case the existing file is left untouched.
        """
        # Serialize completely before opening the file. Opening with "w"
        # truncates it, so a failure in the middle of json.dump() would leave
        # a partial file that the next load silently discards.
        data = json.dumps(self._serialize(self._settings))
        os.makedirs(os.path.dirname(self._filename), exist_ok=True)
        with open(self._filename, "w") as f:
            f.write(data)

    @property
    def settings(self):
        """The settings dict; call store() to persist changes."""
        return self._settings

    def _serialize(self, value):
        """Return a JSON-friendly copy of value.

        bytes become "0x"-prefixed hex strings. Dicts and lists, and types
        whose class name contains "Dict"/"List" (presumably the shell's
        wrapper types), are converted recursively. Other values are returned
        unchanged.
        """
        if isinstance(value, bytes):
            return f"0x{value.hex()}"
        if isinstance(value, dict) or "Dict" in type(value).__name__:
            return {key: self._serialize(val) for key, val in value.items()}
        if isinstance(value, list) or "List" in type(value).__name__:
            return [self._serialize(val) for val in value]
        return value
def get_working_dir():
    """Return the configured working directory with ``~`` expanded.

    Uses the ``workingDirectory`` setting and falls back to the user's
    home directory when it is not set.
    """
    configured_dir = ConfigFile().settings.get("workingDirectory", "~")
    return os.path.expanduser(configured_dir)
def get_interactive_default():
    """Return whether interactive (wizard) mode is the default.

    Interactive mode requires the shell's ``useWizards`` option to be set
    and the caller to run on the main thread.

    Returns:
        True if the interactive mode is enabled, False otherwise.
    """
    if not mysqlsh.globals.shell.options.useWizards:
        return False
    return type(threading.current_thread()).__name__ == "_MainThread"
# Log levels in increasing verbosity, valued 1 (NONE) through 8 (DEBUG3).
# Member names are what gets handed to the shell log, e.g.
# LogLevel.WARNING.name in MsmDbExec.exec().
LogLevel = IntEnum(
    "LogLevel",
    [
        "NONE",
        "INTERNAL_ERROR",
        "ERROR",
        "WARNING",
        "INFO",
        "DEBUG",
        "DEBUG2",
        "DEBUG3",
    ],
    module=__name__,
)
def print_exception(exc_type, exc_value, exc_traceback):
    """Print an exception for the interactive shell user.

    Meant as the exception handler of the MsmDbSession context manager and
    only for interactive mode. With shell verbosity above 1 the full
    traceback is printed, with literal backslash-n sequences turned into
    real line breaks. Otherwise only the exception message is printed.

    Returns:
        True, to signal that the exception was dealt with.
    """
    if mysqlsh.globals.shell.options.verbose > 1:
        trace_lines = traceback.format_exception(
            exc_type, exc_value, exc_traceback)
        print("".join(line.replace("\\n", "\n") for line in trace_lines))
    else:
        print(f"{exc_value}")
    return True
def get_interactive_result():
    """Return whether a plugin function should return a pretty-formatted result.

    This is the case when it is called from an interactive Shell session;
    the decision is delegated to get_interactive_default().
    """
    interactive = get_interactive_default()
    return interactive
def script_path(*suffixes):
    """Join suffixes onto the directory two levels above this module.

    With no suffixes, that directory itself is returned.
    """
    this_file = os.path.abspath(__file__)
    base_dir = os.path.dirname(os.path.dirname(this_file))
    return os.path.join(base_dir, *suffixes)
def get_current_session(session=None):
    """Return the database session to work with.

    An explicitly passed session is returned unchanged. Otherwise the MySQL
    Shell's global session is used, which must exist and be open.

    Args:
        session: Optional session object.

    Returns:
        The session to use.

    Raises:
        Exception: If no session was passed and there is no open global
            session.
    """
    if session is None:
        session = mysqlsh.globals.shell.get_session()
        if session is None or not session.is_open():
            raise Exception(
                "MySQL session not specified. Please either pass a session "
                "object when calling the function or open a database "
                "connection in the MySQL Shell first."
            )
    return session
def get_db_schema_version(session, schema_name):
    """Return the MSM schema version stored in schema_name.

    Reads the major, minor and patch columns of the first row of
    `<schema_name>`.`msm_schema_version`. NOTE(review): schema_name is only
    wrapped in backticks, not escaped, so it must come from trusted code.

    Returns:
        The version as a list [major, minor, patch].

    Raises:
        Exception: If no version row could be fetched.
    """
    query = select(
        table=f"`{schema_name}`.`msm_schema_version`",
        cols=["major", "minor", "patch"],
    )
    version_row = query.exec(session).first
    if not version_row:
        raise Exception(
            "Unable to fetch the MSM database schema version.")
    return [version_row[part] for part in ("major", "minor", "patch")]
def get_db_schema_version_int(session, schema_name):
    """Return the MSM schema version of schema_name encoded as one integer.

    ``major.minor.patch`` is encoded as ``major * 10000 + minor * 100 +
    patch`` (e.g. 1.2.3 -> 10203). This orders correctly as long as minor
    and patch stay below 100.

    Args:
        session: The database session used for the query.
        schema_name (str): The MSM schema whose version table is read.
            Previously this was not accepted, and the call to
            get_db_schema_version() always failed with a TypeError.

    Returns:
        The encoded version as an int.

    Raises:
        Exception: If the version row cannot be fetched.
    """
    major, minor, patch = get_db_schema_version(session, schema_name)
    return major * 10000 + minor * 100 + patch
def db_schema_exists(session, schema_name):
    """Return whether a schema named schema_name exists on the server.

    schema_name is passed as a bound parameter.

    Returns:
        The ``schema_exists`` value of the result row, i.e. the server's
        evaluation of ``COUNT(*) > 0``.
    """
    query = MsmDbExec("""
SELECT COUNT(*) > 0 AS schema_exists
FROM INFORMATION_SCHEMA.SCHEMATA
WHERE SCHEMA_NAME = ?
""")
    result_row = query.exec(session, [schema_name]).first
    return result_row["schema_exists"]
def _get_item_name(item, item_name_property):
    """Return the name used to display and match item."""
    if item_name_property is None:
        return item
    if isinstance(item, dict):
        return item.get(item_name_property)
    return getattr(item, item_name_property)


def _find_item_by_name(item_list, name, item_name_property):
    """Return the first item whose name matches the lower-cased name, else None."""
    for item in item_list:
        item_name = _get_item_name(item, item_name_property)
        # Items without a name (e.g. a dict lacking the property) never match.
        if item_name is not None and str(item_name).lower() == name:
            return item
    return None


def _resolve_prompt_entry(item_list, entry, item_name_property):
    """Map one user entry (1-based index or item name) to an item, else None."""
    try:
        nr = int(entry)
    except ValueError:
        return _find_item_by_name(item_list, entry, item_name_property)
    if 0 < nr <= len(item_list):
        return item_list[nr - 1]
    return None


def prompt_for_list_item(
    item_list,
    prompt_caption,
    prompt_default_value="",
    item_name_property=None,
    given_value=None,
    print_list=False,
    allow_multi_select=False,
):
    """Let the user choose an item from a list.

    When prompted, the user can enter either the 1-based index of the item or
    its name (case-insensitive). If given_value is provided, it is matched
    against the item names instead of prompting the user.

    Args:
        item_list (list): The list of items to choose from.
        prompt_caption (str): The caption of the prompt that is displayed.
        prompt_default_value (str): The default value for the prompt.
        item_name_property (str): The dict key / attribute holding the item
            name. If None, the items themselves are compared.
        given_value (str): Value that the user provided beforehand.
        print_list (bool): Whether to print the numbered list of items first.
        allow_multi_select (bool): Whether several items separated by ","
            may be entered, and "*" selects the whole list.

    Returns:
        The selected item, or the list of selected items when
        allow_multi_select is True, or None when the user cancelled by
        entering nothing.
    """
    # A value given up front is matched by name only; the user is not asked.
    if given_value:
        return _find_item_by_name(
            item_list, given_value.lower(), item_name_property)

    if print_list:
        for i, item in enumerate(item_list, start=1):
            print(f"{i:>4} {_get_item_name(item, item_name_property)}")
        print()

    while True:
        user_input = (
            mysqlsh.globals.shell.prompt(
                prompt_caption, {"defaultValue": prompt_default_value}
            )
            .strip()
            .lower()
        )
        # Empty input cancels the selection.
        if user_input == "":
            return None

        # If the user typed '*', return the full list.
        if allow_multi_select and user_input == "*":
            return item_list

        entries = user_input.split(",") if allow_multi_select else [user_input]
        selected_items = []
        invalid_entry = None
        for entry in entries:
            entry = entry.strip()
            if not entry:
                # Tolerate stray separators such as "1,,2" or a trailing ",".
                continue
            item = _resolve_prompt_entry(item_list, entry, item_name_property)
            if item is None:
                invalid_entry = entry
                break
            selected_items.append(item)

        # Accept the input only if every entry resolved. A partial selection
        # is discarded, so the user really is asked again as the message says.
        if invalid_entry is None and selected_items:
            return selected_items if allow_multi_select else selected_items[0]

        shown = invalid_entry if invalid_entry is not None else user_input
        msg = f"The item {shown} was not found. Please try again"
        if prompt_default_value == "":
            msg += " or leave empty to cancel the operation.\n"
        else:
            msg += ".\n"
        print(msg)
def get_sql_result_as_dict_list(res, binary_formatter=None):
    """Return the result set as a list of dicts keyed by column label.

    Values are converted per column type:
    - BIT(1): True/False (SQL NULL stays None instead of becoming False)
    - SET: list of members ([] for empty or NULL)
    - JSON: parsed with json.loads (None for empty or NULL)
    - bytes of any other type: passed to binary_formatter, if given
    - anything else: unchanged

    Args:
        res (object): The SQL result set; a falsy value yields [].
        binary_formatter (callable): Function receiving binary data and
            returning the formatted value.

    Returns:
        A list of dicts, one per row.
    """
    if not res:
        return []

    # Column metadata is the same for every row, so read it once.
    columns = []
    for col in res.get_columns():
        # The right way to get the column type is "get_type().data": plain
        # get_type() may return "Constant" or the data type depending on
        # whether the shell is started with --json or not.
        col_type = col.get_type().data
        is_bool_bit = col_type == "BIT" and col.get_length() == 1
        columns.append((col.get_column_label(), col_type, is_bool_bit))

    dict_list = []
    for row in res.fetch_all():
        item = {}
        for col_name, col_type, is_bool_bit in columns:
            field_val = row.get_field(col_name)
            if is_bool_bit:
                # Keep SQL NULL distinguishable from false.
                item[col_name] = None if field_val is None else field_val == 1
            elif col_type == "SET":
                item[col_name] = field_val.split(",") if field_val else []
            elif col_type == "JSON":
                item[col_name] = json.loads(field_val) if field_val else None
            elif binary_formatter is not None and isinstance(field_val, bytes):
                item[col_name] = binary_formatter(field_val)
            else:
                item[col_name] = field_val
        dict_list.append(item)
    return dict_list
def prompt(message, options=None) -> str:
    """Ask the user for input through the MySQL Shell prompt.

    Args:
        message (str): The message shown to the user.
        options (dict): Optional settings passed to ``shell.prompt``
            unchanged. They may include:
            - defaultValue: a str returned if the user provides no data.
            - type: the prompt type; "password" means the user input is not
              echoed on the screen.

    Returns:
        A string containing the user's input.
    """
    shell = mysqlsh.globals.shell
    return shell.prompt(message, options)
def convert_json(value) -> dict:
    """Convert value into plain JSON-compatible Python data.

    Values that json can serialize are round-tripped through json.
    Otherwise (e.g. dicts holding bytes), str(value) is rewritten from
    Python repr syntax (single-quoted strings, ``b'...'`` bytes literals)
    into JSON text and parsed. The rewrite is a best-effort heuristic:
    strings that themselves contain quotes, commas or colons may still
    produce invalid JSON.

    Returns:
        The parsed value (usually a dict).

    Raises:
        json.JSONDecodeError: If the rewritten text is not valid JSON.
    """
    try:
        value_str = json.dumps(value)
    except (TypeError, ValueError):
        value_str = str(value)
        # Order matters: the specific patterns must run before the generic
        # "': " and ", '" rewrites.
        for old, new in (
            ("{'", '{"'),
            # Closing quote before "}" (this was "'}'", which never matched
            # a plain "'}" and left e.g. {'a': 'b'} unconvertible).
            ("'}", '"}'),
            ("', '", '", "'),
            ("': '", '": "'),
            ("': [", '": ['),
            ("], '", '], "'),
            ("['", '["'),
            ("']", '"]'),
            ("': ", '": '),
            (", '", ', "'),
            (": b'", ': "'),
        ):
            value_str = value_str.replace(old, new)
    return json.loads(value_str)
class MsmDbExec:
    """Fluent wrapper around ``session.run_sql`` for one SQL statement.

    Typical use: ``MsmDbExec(sql, params).exec(session).items``.

    NOTE: this class is defined twice in this module. The later definition
    replaces this one once the module has loaded. This one must stay because
    the ``-> MsmDbExec`` return annotations of select()/update()/delete() are
    evaluated when those functions are defined.
    """

    def __init__(self, sql: str, params=None, binary_formatter=None) -> None:
        """
        Args:
            sql (str): The statement, using "?" placeholders.
            params (list): Initial placeholder values (default: none).
            binary_formatter (callable): Applied to bytes values when rows
                are converted to dicts.
        """
        self._sql = sql
        self._result = None
        # Copy so no instance shares (and can mutate) a default or caller list.
        self._params = list(params) if params is not None else []
        self._binary_formatter = binary_formatter

    def _convert_to_database(self, var):
        """Return lists as comma-joined strings and dicts as JSON text."""
        if isinstance(var, list):
            return ",".join(var)
        if isinstance(var, dict):
            return json.dumps(dict(var))
        return var

    @property
    def dump(self) -> "MsmDbExec":
        """Print the SQL and current params (property access) and return self."""
        print(f"sql: {self._sql}\nparams: {self._params}")
        return self

    def exec(self, session, params=None) -> "MsmDbExec":
        """Run the statement via ``session.run_sql(sql, params)``.

        Args:
            session: The session to run the statement on.
            params (list): Extra placeholder values, appended after the
                constructor's params. The combined list replaces the stored
                params, so a second exec() on the same object appends again.

        Returns:
            self, so result properties can be chained.

        Raises:
            Whatever run_sql (or parameter conversion) raises, after logging
            it as a warning to the shell log.
        """
        self._params = self._params + (list(params) if params is not None else [])
        try:
            # Convert lists and dicts to values storable in the database.
            self._params = [self._convert_to_database(
                param) for param in self._params]
            self._result = session.run_sql(self._sql, self._params)
        except Exception as e:
            # Keep the replacement fields on one line: a replacement field
            # spanning lines is a SyntaxError before Python 3.12.
            mysqlsh.globals.shell.log(
                LogLevel.WARNING.name,
                f"[{e}\nsql: {self._sql}\nparams: {self._params}"
            )
            raise
        return self

    def __str__(self):
        return self._sql

    @property
    def items(self):
        """All result rows as a list of dicts."""
        return get_sql_result_as_dict_list(self._result, self._binary_formatter)

    @property
    def first(self):
        """The first result row as a dict, or None if there is none."""
        result = get_sql_result_as_dict_list(
            self._result, self._binary_formatter)
        if not result:
            return None
        return result[0]

    @property
    def success(self):
        """True if the executed statement affected at least one row."""
        return self._result.get_affected_items_count() > 0

    @property
    def id(self):
        """The auto-increment value reported by the result."""
        return self._result.auto_increment_value

    @property
    def affected_count(self):
        """The number of rows affected by the executed statement."""
        return self._result.get_affected_items_count()
def _generate_where(where):
if where:
if isinstance(where, list):
return " WHERE " + " AND ".join(where)
else:
return " WHERE " + where
return ""
def select(
    table: str, cols=["*"], where=[], order=None, binary_formatter=None
) -> MsmDbExec:
    """Build (but do not run) a SELECT statement.

    table, cols, where and order are inserted into the SQL verbatim; bind
    data values through "?" placeholders and exec() params instead. The list
    defaults are only read, never mutated.

    Args:
        table (str): The table expression.
        cols (str | list): A column string, or a list joined with ",".
        where (str | list): Condition(s); list items are AND-combined.
        order (str | list): Optional ORDER BY expression(s); lists are
            joined with ",".
        binary_formatter (callable): Forwarded to MsmDbExec.

    Returns:
        An MsmDbExec that still has to be executed.
    """
    column_sql = cols if isinstance(cols, str) else ",".join(cols)
    order_sql = order
    if order_sql is not None and not isinstance(order_sql, str):
        order_sql = ",".join(order_sql)

    sql = f"""
SELECT {column_sql}
FROM {table}
{_generate_where(where)}"""
    if order_sql:
        sql += f" ORDER BY {order_sql}"
    return MsmDbExec(sql, binary_formatter=binary_formatter)
def update(table: str, sets, where=[]) -> MsmDbExec:
    """Build (but do not run) an UPDATE statement.

    Args:
        table (str): The table expression, inserted verbatim.
        sets (str | list | dict): The SET clause. A list is joined with ",".
            A dict becomes "key=?" assignments whose values are bound as
            parameters. Anything else is used as-is.
        where (str | list): Condition(s); list items are AND-combined.
            Values for "?" in where are passed to exec(), which appends
            them after the dict values.

    Returns:
        An MsmDbExec carrying the dict values (if any) as parameters.
    """
    bound_values = []
    if isinstance(sets, dict):
        bound_values = list(sets.values())
        set_sql = ",".join(f"{key}=?" for key in sets)
    elif isinstance(sets, list):
        set_sql = ",".join(sets)
    else:
        set_sql = sets

    sql = f"""
UPDATE {table}
SET {set_sql}
{_generate_where(where)}"""
    return MsmDbExec(sql, bound_values)
def delete(table: str, where=[]) -> MsmDbExec:
    """Build (but do not run) a DELETE statement for table.

    where follows the same rules as in select(). With no condition the
    statement has no WHERE clause at all.
    """
    where_sql = _generate_where(where)
    sql = f"""
DELETE FROM {table}
{where_sql}"""
    return MsmDbExec(sql)
def insert(table, values={}):
    """Build (but do not run) an INSERT statement.

    Args:
        table (str): The table expression, inserted verbatim.
        values (list | dict): For a list of column names, one "?" per column
            is generated and the values must be passed to exec(). For a
            dict, its keys are the columns and its values become the bound
            parameters. The default dict is only read, never mutated.

    Returns:
        An MsmDbExec for the INSERT.
    """
    params = []
    cols = []
    place_holders = []
    if isinstance(values, dict):
        column_names = [str(col) for col in values]
        cols = ",".join(column_names)
        place_holders = ",".join("?" * len(column_names))
        params = list(values.values())
    elif isinstance(values, list):
        cols = ",".join(values)
        place_holders = ",".join("?" * len(values))

    sql = f"""
INSERT INTO {table}
({cols})
VALUES
({place_holders})
"""
    return MsmDbExec(sql, params)
class MsmDbTransaction:
    """Context manager that runs its block inside a database transaction.

    On entry it sends START TRANSACTION. It sends COMMIT when the block
    finishes normally and ROLLBACK when the block raises. Exceptions are
    never suppressed.
    """

    def __init__(self, session) -> None:
        # Session whose run_sql() issues the transaction statements.
        self._session = session

    def __enter__(self) -> "MsmDbTransaction":
        self._session.run_sql("START TRANSACTION")
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback) -> bool:
        if exc_type is not None:
            self._session.run_sql("ROLLBACK")
            # A false return value lets the exception propagate.
            return False
        self._session.run_sql("COMMIT")
        return None
def uppercase_first_char(s):
    """Return s with its first character upper-cased; "" for empty input."""
    if len(s) == 0:
        return ""
    head, tail = s[0], s[1:]
    return head.upper() + tail
def convert_path_to_camel_case(path):
    """Turn a path such as "/my/path_name" into camelCase ("myPathName").

    One leading "/" is dropped, "/" and "_" both separate words, and any
    non-alphanumeric characters are removed from the result.
    """
    head, *tail = path.removeprefix("/").replace("/", "_").split("_")
    camel = head + "".join(part[:1].upper() + part[1:] for part in tail)
    # Keep only alphanumeric characters.
    return "".join(filter(str.isalnum, camel))
def convert_path_to_pascal_case(path):
    """Like convert_path_to_camel_case(), but with the first letter upper-cased."""
    camel = convert_path_to_camel_case(path)
    return uppercase_first_char(camel)
def convert_snake_to_camel_case(snake_str):
    """Convert a snake_case string to camelCase.

    The input is lower-cased first, so "MY_VAR" becomes "myVar". An empty
    string (or one made only of underscores) yields "".
    """
    pascal = "".join(part.capitalize() for part in snake_str.lower().split("_"))
    # Slicing instead of indexing so that an empty result does not raise
    # IndexError.
    return pascal[:1].lower() + pascal[1:]
def convert_to_snake_case(str):
    """Insert "_" before every A-Z letter except at the start, then lower-case.

    The parameter name shadows the builtin but is kept for keyword callers.
    """
    with_separators = re.sub(r"(?<!^)(?=[A-Z])", "_", str)
    return with_separators.lower()
def escape_str(s):
    """Backslash-escape backslashes, double quotes and single quotes in s."""
    # One pass over s; equivalent to escaping "\" first and the quotes after.
    escapes = str.maketrans({"\\": "\\\\", '"': '\\"', "'": "\\'"})
    return s.translate(escapes)
def quote_str(s):
    """Return s escaped with escape_str() and wrapped in double quotes."""
    return f'"{escape_str(s)}"'
def unescape_str(s):
    """Undo the escaping produced by escape_str()."""
    # Applied one after another in this exact order.
    for escaped, plain in (("\\'", "'"), ('\\"', '"'), ("\\\\", "\\")):
        s = s.replace(escaped, plain)
    return s
def unquote_str(s):
    """Remove one pair of matching surrounding quotes from s and unescape it.

    If s is enclosed in matching single or double quotes, the quotes are
    stripped and the content is passed through unescape_str(). Otherwise,
    including for a lone quote character, s is returned unchanged.
    """
    # len(s) >= 2: a single "'" both starts and ends with a quote but has no
    # closing partner, and slicing it would wrongly yield "".
    if len(s) >= 2 and s[0] == s[-1] and s[0] in ("'", '"'):
        return unescape_str(s[1:-1])
    return s
def quote_ident(s):
    """Return s quoted as an identifier by mysqlsh.mysql.quote_identifier()."""
    quoted = mysqlsh.mysql.quote_identifier(s)
    return quoted
def unquote_ident(s):
    """Return s unquoted by mysqlsh.mysql.unquote_identifier()."""
    unquoted = mysqlsh.mysql.unquote_identifier(s)
    return unquoted
def convert_version_str_to_list(version: str) -> list[int]:
    """Parse the leading "major.minor.patch" of version into three ints.

    Only the start of the string must match, so "8.0.33-extra" gives
    [8, 0, 33].

    Raises:
        ValueError: If version does not start with "<int>.<int>.<int>".
    """
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if not match:
        raise ValueError(
            "The version needs to be specified using the following format: major.minor.patch")
    return [int(part) for part in match.groups()]
class MsmDbExec:
    """Fluent wrapper around ``session.run_sql`` for one SQL statement.

    Typical use: ``MsmDbExec(sql, params).exec(session).items``.

    NOTE: this duplicates an earlier definition in this module. Because it
    comes later, this is the class the name MsmDbExec refers to at runtime.
    """

    def __init__(self, sql: str, params=None, binary_formatter=None) -> None:
        """
        Args:
            sql (str): The statement, using "?" placeholders.
            params (list): Initial placeholder values (default: none).
            binary_formatter (callable): Applied to bytes values when rows
                are converted to dicts.
        """
        self._sql = sql
        self._result = None
        # Copy so no instance shares (and can mutate) a default or caller list.
        self._params = list(params) if params is not None else []
        self._binary_formatter = binary_formatter

    def _convert_to_database(self, var):
        """Return lists as comma-joined strings and dicts as JSON text."""
        if isinstance(var, list):
            return ",".join(var)
        if isinstance(var, dict):
            return json.dumps(dict(var))
        return var

    @property
    def dump(self) -> "MsmDbExec":
        """Print the SQL and current params (property access) and return self."""
        print(f"sql: {self._sql}\nparams: {self._params}")
        return self

    def exec(self, session, params=None) -> "MsmDbExec":
        """Run the statement via ``session.run_sql(sql, params)``.

        Args:
            session: The session to run the statement on.
            params (list): Extra placeholder values, appended after the
                constructor's params. The combined list replaces the stored
                params, so a second exec() on the same object appends again.

        Returns:
            self, so result properties can be chained.

        Raises:
            Whatever run_sql (or parameter conversion) raises, after logging
            it as a warning to the shell log.
        """
        self._params = self._params + (list(params) if params is not None else [])
        try:
            # Convert lists and dicts to values storable in the database.
            self._params = [self._convert_to_database(
                param) for param in self._params]
            self._result = session.run_sql(self._sql, self._params)
        except Exception as e:
            mysqlsh.globals.shell.log(
                LogLevel.WARNING.name,
                f"[{e}\nsql: {self._sql}\nparams: {self._params}"
            )
            raise
        return self

    def __str__(self):
        return self._sql

    @property
    def items(self):
        """All result rows as a list of dicts."""
        return get_sql_result_as_dict_list(self._result, self._binary_formatter)

    @property
    def first(self):
        """The first result row as a dict, or None if there is none."""
        result = get_sql_result_as_dict_list(
            self._result, self._binary_formatter)
        if not result:
            return None
        return result[0]

    @property
    def success(self):
        """True if the executed statement affected at least one row."""
        return self._result.get_affected_items_count() > 0

    @property
    def id(self):
        """The auto-increment value reported by the result."""
        return self._result.auto_increment_value

    @property
    def affected_count(self):
        """The number of rows affected by the executed statement."""
        return self._result.get_affected_items_count()
def execute_msm_sql_script(
        session, sql_script: str = None, script_name: str = None,
        sql_file_path: str = None):
    """Run a multi-statement SQL script while holding MSM_METADATA_LOCK.

    The script comes from sql_script or, if sql_file_path is given, is read
    from that file (the file wins when both are given). It is split with
    mysqlsh.mysql.split_script() and each non-empty statement is run in
    order. Start, success and failure are written to the MSM schema update
    log. The named lock MSM_METADATA_LOCK is requested with a 1 second
    timeout and released again whenever it was obtained. Statements that
    already ran are not undone when a later one fails.

    Args:
        session: The database session to run the script on.
        sql_script (str): The script text.
        script_name (str): Name used in log messages; required with
            sql_script and defaults to sql_file_path otherwise.
        sql_file_path (str): Path of a file containing the script.

    Raises:
        ValueError: If neither a script nor a file path is given, or if
            sql_script is given without script_name.
        Exception: If the lock cannot be acquired, or if a statement fails.
            In the failure case the message contains the failing statement
            and the database error, which is chained as the cause.
    """
    if sql_script is None and sql_file_path is None:
        raise ValueError("No script or sql_file_path specified.")
    if sql_script is not None and script_name is None:
        raise ValueError("No script_name specified.")
    if script_name is None and sql_file_path is not None:
        script_name = sql_file_path

    write_to_msm_schema_update_log(
        "INFO", f"Running SQL script `{script_name}` ...")

    if sql_file_path is not None:
        with open(sql_file_path) as f:
            sql_script = f.read()

    commands = mysqlsh.mysql.split_script(sql_script)

    msm_lock = 0
    try:
        # Acquire MSM_METADATA_LOCK. GET_LOCK returns 1 when obtained, 0 on
        # timeout and NULL (None) on error; only 1 means the lock is held.
        msm_lock = (
            MsmDbExec('SELECT GET_LOCK("MSM_METADATA_LOCK", 1) AS msm_lock')
            .exec(session)
            .first["msm_lock"]
        )
        if msm_lock != 1:
            raise Exception(
                "Failed to acquire MSM schema update lock. Please ensure no "
                "other MSM schema update is running, then try again.")

        # Execute all commands, remembering the current one for error output.
        current_cmd = ""
        try:
            for cmd in commands:
                current_cmd = cmd.strip()
                if current_cmd:
                    session.run_sql(current_cmd)
            write_to_msm_schema_update_log(
                "INFO", f"SQL script {script_name} executed successfully.")
        except mysqlsh.DBError as e:
            # Log the failing statement, then re-raise with it attached.
            write_to_msm_schema_update_log(
                "ERROR", f"Failed to run the the SQL script `{script_name}`.\n{current_cmd}\n{e}")
            raise Exception(
                f"Failed to run the SQL script.\n{current_cmd}\n{e}"
            ) from e
    finally:
        if msm_lock == 1:
            MsmDbExec('SELECT RELEASE_LOCK("MSM_METADATA_LOCK")').exec(session)