mrs_plugin/lib/core.py (804 lines of code) (raw):

# Copyright (c) 2021, 2025, Oracle and/or its affiliates. # # This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License, version 2.0, # as published by the Free Software Foundation. # # This program is designed to work with certain software (including # but not limited to OpenSSL) that is licensed under separate terms, as # designated in a particular file or component or in included license # documentation. The authors of MySQL hereby grant you an additional # permission to link the program and your derivative works with the # separately licensed software that they have either included with # the program or referenced in the documentation. # # This program is distributed in the hope that it will be useful, but # WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See # the GNU General Public License, version 2.0, for more details. # # You should have received a copy of the GNU General Public License # along with this program; if not, write to the Free Software Foundation, Inc., # 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA """Sub-Module for core functions""" # cSpell:ignore mysqlsh, mrs import traceback from mrs_plugin.lib import general import mysqlsh import os import re import json from enum import IntEnum import threading import base64 from typing import Optional MRS_METADATA_LOCK_ERROR = "Failed to acquire MRS metadata lock. Please ensure no other metadata update is running, then try again." 
class ConfigFile:
    """MRS plugin settings persisted as JSON in the shell user directory."""

    def __init__(self) -> None:
        self._settings = {}
        self._filename = os.path.abspath(
            mysqlsh.plugin_manager.general.get_shell_user_dir(
                "plugin_data", "mrs_plugin", "config.json"
            )
        )
        try:
            with open(self._filename, "r") as f:
                self._settings = json.load(f)
            for item in self._settings.get("current_objects", []):
                convert_ids_to_binary(["current_service_id"], item)
        except Exception:
            # Loading is best-effort: a missing or unreadable config file
            # simply leaves whatever settings were loaded so far.
            pass

    def store(self):
        """Write the settings to the config file, creating its directory."""
        # create a copy because we're changing the dict data
        settings_copy = self._settings.copy()
        os.makedirs(os.path.dirname(self._filename), exist_ok=True)
        with open(self._filename, "w") as f:
            json.dump(self._serialize(settings_copy), f)

    @property
    def settings(self):
        return self._settings

    def _serialize(self, value):
        """Recursively convert value into JSON-friendly data.

        bytes become "0x<hex>" strings; dict-like and list-like values
        (including types whose class name contains "Dict"/"List") are
        rebuilt as plain dicts/lists.
        """
        if isinstance(value, bytes):
            return f"0x{value.hex()}"
        if isinstance(value, dict) or "Dict" in type(value).__name__:
            return {key: self._serialize(val) for key, val in value.items()}
        if isinstance(value, list) or "List" in type(value).__name__:
            return [self._serialize(val) for val in value]
        return value


class Validations:
    @staticmethod
    def request_path(value, required=False, session=None):
        """Validate a request_path value.

        Raises:
            Exception: if the value is required but None, or if it is not a
                string starting with '/'.
        """
        if required and value is None:
            raise Exception("The request_path is missing.")
        if value is None:
            return
        if not isinstance(value, str) or not value.startswith("/"):
            raise Exception("The request_path has to start with '/'.")


class LogLevel(IntEnum):
    NONE = 1
    INTERNAL_ERROR = 2
    ERROR = 3
    WARNING = 4
    INFO = 5
    DEBUG = 6
    DEBUG2 = 7
    DEBUG3 = 8


def get_local_config():
    """Returns the settings dict loaded from the local config file."""
    return ConfigFile().settings


def script_path(*suffixes):
    """Joins the given suffixes onto the parent directory of this file's
    directory."""
    return os.path.join(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__))), *suffixes
    )


def print_exception(exc_type, exc_value, exc_traceback):
    """Exception handler for the MrsDbSession context manager.

    Should be used only in interactive mode. Prints just the exception value
    unless the shell verbosity is above 1, in which case the full traceback
    is printed.

    Returns:
        True to signal the exception was dealt with
    """
    if mysqlsh.globals.shell.options.verbose <= 1:
        print(f"{exc_value}")
    else:
        exc_str = "".join(
            [
                s.replace("\\n", "\n")
                for s in traceback.format_exception(
                    exc_type, exc_value, exc_traceback
                )
            ]
        )
        print(exc_str)
    return True


def set_current_objects(
    service_id: bytes = None,
    service=None,
    schema_id: bytes = None,
    schema=None,
    content_set_id: bytes = None,
    content_set=None,
    db_object_id: bytes = None,
    db_object=None,
):
    """Sets the current objects to the given ones

    Note that any of service, schema, db_object or content_set that is not
    specified (neither by id nor by dict) is reset to None if it was set.

    Args:
        service_id: The id of the service to set as the current service
        service (dict): The service to set as the current service
        schema_id: The id of the schema to set as the current schema
        schema (dict): The schema to set as the current schema
        content_set_id: The id of the content_set to set as the current
        content_set (dict): The content_set to set as the current
        db_object_id: The id of the db_object to set as the current db_object
        db_object (dict): The db_object to set as the current db_object

    Returns:
        None
    """
    # Get the global mrs_config
    mrs_config = get_current_config()

    targets = (
        ("current_service_id", service_id, service),
        ("current_schema_id", schema_id, schema),
        ("current_db_object_id", db_object_id, db_object),
        ("current_content_set_id", content_set_id, content_set),
    )
    for key, given_id, given_obj in targets:
        if given_id:
            mrs_config[key] = given_id
        if given_obj:
            # A dict takes precedence over an id passed alongside it
            mrs_config[key] = given_obj.get("id")
        # If the entry is currently set but nothing was passed in, clear it
        if mrs_config.get(key) and not (given_id or given_obj):
            mrs_config[key] = None


def get_interactive_default():
    """Returns the default of the interactive mode

    Returns:
        True if the shell's useWizards option is set and the current thread
        is the main thread, False otherwise
    """
    if mysqlsh.globals.shell.options.useWizards:
        ct = threading.current_thread()
        if ct.__class__.__name__ == "_MainThread":
            return True
    return False


def get_interactive_result():
    """
    To be used in plugin functions that may return pretty formatted result when
    called in an interactive Shell session
    """
    return get_interactive_default()


def get_current_session(session=None):
    """Returns the current database session

    If a session is provided, it will be returned instead of the current one.
    If there is no active session, then an exception will be raised.

    Returns:
        The current database session
    """
    if session is not None:
        return session

    # Check if the user provided a session or there is an active global session
    session = mysqlsh.globals.shell.get_session()
    if session is None or not session.is_open():
        raise Exception(
            "MySQL session not specified. Please either pass a session "
            "object when calling the function or open a database "
            "connection in the MySQL Shell first."
        )
    return session


def get_mrs_schema_version(session):
    """Returns the MRS metadata schema version as [major, minor, patch].

    Raises:
        Exception: if no version row could be fetched.
    """
    version_cols = [
        "major",
        "minor",
        "patch",
        "CONCAT(major, '.', minor, '.', patch) AS version",
    ]
    try:
        # As of MRS metadata schema version 4.0.0 the version VIEW has been
        # renamed to msm_schema_version, so try that one first
        row = select(table="msm_schema_version", cols=version_cols).exec(session).first
    except Exception:
        row = select(table="schema_version", cols=version_cols).exec(session).first

    if not row:
        raise Exception("Unable to fetch MRS metadata database schema version.")

    return [row["major"], row["minor"], row["patch"]]


def get_mrs_schema_version_int(session):
    """Returns the schema version as major * 10000 + minor * 100 + patch."""
    row = get_mrs_schema_version(session)
    return row[0] * 10000 + row[1] * 100 + row[2]


def mrs_metadata_schema_exists(session):
    """Returns the number of schemas named mysql_rest_service_metadata."""
    row = (
        MrsDbExec(
            " SELECT COUNT(*) AS schema_exists"
            " FROM INFORMATION_SCHEMA.SCHEMATA"
            " WHERE SCHEMA_NAME = 'mysql_rest_service_metadata' "
        )
        .exec(session)
        .first
    )
    return row["schema_exists"]


def get_mrs_enabled(session):
    """Returns True if config.service_enabled is 1; False otherwise or on
    any error while querying."""
    try:
        row = (
            select(
                table="config",
                cols=[
                    "service_enabled",
                ],
            )
            .exec(session)
            .first
        )
        if not row:
            return False
        return int(row["service_enabled"]) == 1
    except Exception:
        return False


def _get_list_item_name(item, item_name_property):
    """Returns the value an item is identified by: the item itself when no
    property is given, else the dict entry / attribute of that name."""
    if item_name_property is None:
        return item
    if isinstance(item, dict):
        return item.get(item_name_property)
    return getattr(item, item_name_property)


def _find_list_item_by_name(item_list, lowered_name, item_name_property):
    """Returns the first item whose lower-cased name equals lowered_name,
    or None if there is no such item."""
    for item in item_list:
        if _get_list_item_name(item, item_name_property).lower() == lowered_name:
            return item
    return None


def prompt_for_list_item(
    item_list,
    prompt_caption,
    prompt_default_value="",
    item_name_property=None,
    given_value=None,
    print_list=False,
    allow_multi_select=False,
):
    """Lets the user choose an item from a list

    When prompted, the user can either provide the index of the item or the
    name of the item.

    If given_value is provided, it will be checked against the items in the
    list instead of prompting the user for a new value

    Args:
        item_list (list): The list of items to choose from
        prompt_caption (str): The caption of the prompt that will be displayed
        prompt_default_value (str): The default_value for the prompt
        item_name_property (str): The name of the property that is used to
            compare with the user input
        given_value (str): Value that the user provided beforehand.
        print_list (bool): Specifies whether the list of items should be printed
        allow_multi_select (bool): Whether multiple items can be entered,
            separated by ',' and the string '*' is allowed

    Returns:
        The selected item or the selected item list when allow_multi_select is
        True or None when the user cancelled the selection
    """
    # If a given_value was provided, check this first instead of prompting the
    # user
    if given_value:
        return _find_list_item_by_name(
            item_list, given_value.lower(), item_name_property
        )

    if print_list:
        for i, item in enumerate(item_list, start=1):
            # An empty property name is treated like "no property" here
            item_caption = _get_list_item_name(item, item_name_property or None)
            print(f"{i:>4} {item_caption}")
        print()

    selected_items = []

    # Let the user choose from the list
    while len(selected_items) == 0:
        # Prompt the user for specifying an item
        answer = (
            mysqlsh.globals.shell.prompt(
                prompt_caption, {"defaultValue": prompt_default_value}
            )
            .strip()
            .lower()
        )

        if answer == "":
            return None

        # If the user typed '*', return full list
        if allow_multi_select and answer == "*":
            return item_list

        if allow_multi_select:
            entries = answer.split(",")
        else:
            entries = [answer]

        try:
            for entry in entries:
                try:
                    # If the user provided an index, try to map that to an item
                    nr = int(entry)
                    if nr > 0 and nr <= len(item_list):
                        selected_items.append(item_list[nr - 1])
                    else:
                        raise IndexError
                except ValueError:
                    # Search by name
                    selected_item = _find_list_item_by_name(
                        item_list, entry, item_name_property
                    )
                    if selected_item is None:
                        raise ValueError
                    selected_items.append(selected_item)
        except (ValueError, IndexError):
            # Drop entries matched before the failing one so the user really
            # gets to try again instead of receiving a partial selection.
            selected_items = []
            msg = f"The item {entry} was not found. Please try again"
            if prompt_default_value == "":
                msg += " or leave empty to cancel the operation.\n"
            else:
                msg += ".\n"
            print(msg)

    if allow_multi_select:
        return selected_items if len(selected_items) > 0 else None
    elif len(selected_items) > 0:
        return selected_items[0]


def prompt_for_comments():
    """Prompts the user for comments

    Returns:
        The comments as str
    """
    return prompt("Comments: ").strip()


def create_json_binary_decoder(binary_formatter=None):
    """Creates a json.loads object_hook that formats "base64:" values.

    Args:
        binary_formatter (callback): called with the text following the
            "base64:" label of each matching string value; its return value
            replaces the original value. If None, objects are left unchanged.

    Returns:
        The object_hook function
    """
    binary_label = "base64:"

    def decoder(obj):
        if binary_formatter is None:
            return obj
        for key, value in obj.items():
            if isinstance(value, str) and value.startswith(binary_label):
                # Slice the label off; str.lstrip() would strip a *set* of
                # characters and eat into the payload.
                obj[key] = binary_formatter(value[len(binary_label):])
        return obj

    return decoder


def get_sql_result_as_dict_list(res, binary_formatter=None):
    """Returns the result set as a list of dicts

    Args:
        res: (object): The sql result set
        binary_formatter (callback): function receiving binary data and
            returning formatted value

    Returns:
        A list of dicts
    """
    if not res:
        return []

    cols = res.get_columns()
    rows = res.fetch_all()
    dict_list = []

    for row in rows:
        item = {}
        for col in cols:
            col_name = col.get_column_label()
            field_val = row.get_field(col_name)
            # The right way to get the column type is with "get_type().data". Using
            # get_type() may return "Constant" or the data type depending if the shell
            # is started in with --json or not.
            col_type = col.get_type().data
            if col_type == "BIT" and col.get_length() == 1:
                item[col_name] = field_val == 1
            elif col_type == "SET":
                item[col_name] = field_val.split(",") if field_val else []
            elif col_type == "JSON":
                item[col_name] = (
                    json.loads(
                        field_val,
                        object_hook=create_json_binary_decoder(binary_formatter),
                    )
                    if field_val
                    else None
                )
            elif binary_formatter is not None and isinstance(field_val, bytes):
                item[col_name] = binary_formatter(field_val)
            else:
                item[col_name] = field_val

        dict_list.append(item)

    return dict_list


def get_current_config(mrs_config=None):
    """Gets the active config dict

    If no config dict is given as parameter, the global config dict will be used

    Args:
        mrs_config (dict): The config to be used or None

    Returns:
        The active config dict
    """
    if mrs_config is None:
        # Check if global object 'mrs_config' has already been registered
        if "mrs_config" in dir(mysqlsh.globals):
            mrs_config = getattr(mysqlsh.globals, "mrs_config")
        else:
            mrs_config = {}
            setattr(mysqlsh.globals, "mrs_config", mrs_config)

    return mrs_config


def prompt(message, options=None) -> str:
    """Prompts the user for input

    Args:
        message (str): A string with the message to be shown to the user.
        options (dict): Dictionary with options that change the function
            behavior.

    The options dictionary may contain the following options:
        - defaultValue: a str value to be returned if the user provides no
          data.
        - type: a str value to define the prompt type.

    The type option supports the following values:
        - password: the user input will not be echoed on the screen.

    Returns:
        A string value containing the input from the user.
    """
    return mysqlsh.globals.shell.prompt(message, options)


def check_request_path(session, request_path):
    """Checks if the given request_path is valid and unique

    Args:
        session (object): The database session to use
        request_path (str): The request_path to check

    Raises:
        Exception: if no request_path is given or it is already in use.

    Returns:
        None
    """
    if not request_path:
        raise Exception("No request_path specified.")

    # Check if the request_path already exists for a service, a schema, a
    # db_object or a content_set
    res = session.run_sql(
        " SELECT CONCAT(COALESCE(se.in_development->>'$.developers', ''),"
        " h.name, se.url_context_root) as full_request_path"
        " FROM `mysql_rest_service_metadata`.service se"
        " LEFT JOIN `mysql_rest_service_metadata`.url_host h"
        " ON se.url_host_id = h.id"
        " WHERE CONCAT(h.name, se.url_context_root) = ?"
        " UNION"
        " SELECT CONCAT(COALESCE(se.in_development->>'$.developers', ''),"
        " h.name, se.url_context_root, sc.request_path) as full_request_path"
        " FROM `mysql_rest_service_metadata`.db_schema sc"
        " LEFT OUTER JOIN `mysql_rest_service_metadata`.service se"
        " ON se.id = sc.service_id"
        " LEFT JOIN `mysql_rest_service_metadata`.url_host h"
        " ON se.url_host_id = h.id"
        " WHERE CONCAT(h.name, se.url_context_root, sc.request_path) = ?"
        " UNION"
        " SELECT CONCAT(COALESCE(se.in_development->>'$.developers', ''),"
        " h.name, se.url_context_root, sc.request_path, o.request_path)"
        " as full_request_path"
        " FROM `mysql_rest_service_metadata`.db_object o"
        " LEFT OUTER JOIN `mysql_rest_service_metadata`.db_schema sc"
        " ON sc.id = o.db_schema_id"
        " LEFT OUTER JOIN `mysql_rest_service_metadata`.service se"
        " ON se.id = sc.service_id"
        " LEFT JOIN `mysql_rest_service_metadata`.url_host h"
        " ON se.url_host_id = h.id"
        " WHERE CONCAT(h.name, se.url_context_root, sc.request_path,"
        " o.request_path) = ?"
        " UNION"
        " SELECT CONCAT(COALESCE(se.in_development->>'$.developers', ''),"
        " h.name, se.url_context_root, co.request_path) as full_request_path"
        " FROM `mysql_rest_service_metadata`.content_set co"
        " LEFT OUTER JOIN `mysql_rest_service_metadata`.service se"
        " ON se.id = co.service_id"
        " LEFT JOIN `mysql_rest_service_metadata`.url_host h"
        " ON se.url_host_id = h.id"
        " WHERE CONCAT(h.name, se.url_context_root, co.request_path) = ? ",
        [request_path, request_path, request_path, request_path],
    )
    row = res.fetch_one()

    if row and row.get_field("full_request_path") != "":
        raise Exception(f"The request_path {request_path} is already " "in use.")


def check_mrs_object_name(session, db_schema_id, obj_id, obj_name):
    """Checks if the given mrs object name is valid and unique

    Raises:
        Exception: if another object of the same REST schema already uses
            the name (compared case-insensitively).
    """
    res = session.run_sql(
        " SELECT o.name"
        " FROM mysql_rest_service_metadata.object o"
        " LEFT JOIN mysql_rest_service_metadata.db_object dbo"
        " ON o.db_object_id = dbo.id"
        " WHERE dbo.db_schema_id = ? AND UPPER(o.name) = UPPER(?)"
        " AND o.id <> ? ",
        [
            id_to_binary(db_schema_id, "db_schema_id"),
            obj_name,
            id_to_binary(obj_id, "object.id"),
        ],
    )
    row = res.fetch_one()

    if row and row.get_field("name") != "":
        raise Exception(
            f"The object name {obj_name} is already " "in use on this REST schema."
        )


def check_mrs_object_names(session, db_schema_id, objects):
    """Checks if the given mrs object names are valid and unique

    Raises:
        Exception: if a name occurs more than once in objects or is already
            in use on the REST schema.
    """
    if objects is None:
        return

    assigned_names = []
    for obj in objects:
        if obj.get("name") in assigned_names:
            raise Exception(
                f'The object name {obj.get("name")} has been used more than once.'
            )

        check_mrs_object_name(
            session=session,
            db_schema_id=db_schema_id,
            obj_id=obj.get("id"),
            obj_name=obj.get("name"),
        )
        assigned_names.append(obj.get("name"))


def convert_json(value) -> dict:
    """Round-trips value through JSON text.

    If json.dumps fails, str(value) is used instead and a series of textual
    replacements turns the Python repr quoting into JSON quoting before
    parsing.
    """
    try:
        value_str = json.dumps(value)
    except Exception:
        value_str = str(value)
        # The order of the replacements matters; keep it as is.
        value_str = value_str.replace("{'", '{"')
        value_str = value_str.replace("'}'", '"}')
        value_str = value_str.replace("', '", '", "')
        value_str = value_str.replace("': '", '": "')
        value_str = value_str.replace("': [", '": [')
        value_str = value_str.replace("], '", '], "')
        value_str = value_str.replace("['", '["')
        value_str = value_str.replace("']", '"]')
        value_str = value_str.replace("': ", '": ')
        value_str = value_str.replace(", '", ', "')
        value_str = value_str.replace(": b'", ': "')

    return json.loads(value_str)


def cut_last_comma(fields):
    """Removes a trailing ",\\n" if present, otherwise the last character."""
    # Cut the last , away if present
    if fields.endswith(",\n"):
        return fields[:-2]

    # Otherwise, just cut the last \n
    return fields[:-1]


def format_json_entry(key: str, value: dict, advance: int = 1):
    """Formats value as indented JSON prefixed by key.

    Args:
        key (str): The text placed in front of the JSON value
        value (dict): The value to dump; None or "" yields an empty string
        advance (int): How many times the indent unit is prepended to each
            line of the dumped JSON

    Returns:
        The formatted entry as str
    """
    if value is None or value == "":
        return ""

    dumped = json.dumps(value, indent=4)
    lines = [" " * advance + line for line in dumped.splitlines()]
    # Joined outside the f-string: quotes/backslashes inside a replacement
    # field are a syntax error before Python 3.12.
    body = "\n".join(lines).lstrip()
    return f" {key} {body}"


def id_to_binary(id: str, context: str, allowNone=False):
    """Converts an id given as bytes, "0x<hex>" or base64 ("...==") to bytes.

    Args:
        id: The id as bytes or str
        context (str): Name used in error messages
        allowNone (bool): If True, a None id is returned as None

    Raises:
        RuntimeError: on an unsupported type or format, an undecodable
            string, or when the decoded id is not 16 bytes long.

    Returns:
        The id as bytes
    """
    if allowNone and id is None:
        return None
    if isinstance(id, bytes):
        return id
    elif isinstance(id, str):
        if id.startswith("0x"):
            try:
                result = bytes.fromhex(id[2:])
            except Exception:
                raise RuntimeError(f"Invalid hexadecimal string '{id}' for '{context}'.")
        elif id.endswith("=="):
            try:
                result = base64.b64decode(id, validate=True)
            except Exception:
                raise RuntimeError(f"Invalid base64 string '{id}' for '{context}'.")
        else:
            raise RuntimeError(f"Invalid id format '{id}' for '{context}'.")

        if len(result) != 16:
            raise RuntimeError(f"The '{context}' has an invalid size.")
        return result

    raise RuntimeError(f"Invalid id type for '{context}'.")


def convert_id_to_base64_string(id) -> str:
    return base64.b64encode(id).decode('ascii')


def convert_ids_to_binary(id_options, kwargs):
    """Converts, in place, every non-None kwargs entry named in id_options
    to a binary id."""
    for id_option in id_options:
        id = kwargs.get(id_option)
        if id is not None:
            kwargs[id_option] = id_to_binary(id, id_option)


def try_convert_ids_to_binary(id_options, kwargs):
    """
    Try to convert the kwargs ID entries, but don't fail if it's an invalid ID type.
    The entry may or may not be an ID, but it needs not to fail if it's not.

    An use case for this is when we want a parameter, for example 'service', that can be
    one of the following:
    - 'localhost@myService'
    - '0x11EF8496143CFDEC969C7413EA499D96'
    - 'Ee+ElhQ8/eyWnHQT6kmdlg=='
    """
    for id_option in id_options:
        id = kwargs.get(id_option)
        if id is not None:
            try:
                kwargs[id_option] = id_to_binary(id, id_option)
            except RuntimeError as e:
                # Only "not an id at all" errors are tolerated; malformed
                # hex/base64 and wrong sizes still propagate.
                if str(e) in [
                    f"Invalid id type for '{id_option}'.",
                    f"Invalid id format '{kwargs[id_option]}' for '{id_option}'.",
                ]:
                    continue
                raise


def convert_id_to_string(id) -> str:
    return f"0x{id.hex()}"


def convert_dict_to_json_string(dic) -> Optional[str]:
    """Returns dic dumped as a JSON string, or None if dic is None."""
    if dic is None:
        return None

    return json.dumps(dict(dic))


def _generate_where(where):
    """Builds a " WHERE ..." clause from a condition string or a list of
    conditions (joined with AND); empty input yields an empty string."""
    if where:
        if isinstance(where, list):
            return " WHERE " + " AND ".join(where)
        else:
            return " WHERE " + where
    return ""


def _generate_table(table):
    """Qualifies table with the MRS metadata schema unless it contains '.'."""
    if "." in table:
        return table
    return f"`mysql_rest_service_metadata`.`{table}`"


def _generate_qualified_name(name):
    """Qualifies name with the MRS metadata schema unless it contains '.'."""
    if "." in name:
        return name
    parts = name.split("(")
    result = f"`mysql_rest_service_metadata`.`{parts[0]}`"
    if len(parts) == 2:
        # it's a function call so add the parameters
        result = f"{result}({parts[1]}"
    return result


class MrsDbExec:
    """A SQL statement with its parameters and, once executed, its result."""

    def __init__(self, sql: str, params=[], binary_formatter=None) -> None:
        self._sql = sql
        self._result = None
        self._params = params
        self._binary_formatter = binary_formatter

    def _convert_to_database(self, var):
        """Lists become comma separated strings, dicts JSON strings."""
        if isinstance(var, list):
            return ",".join(var)
        if isinstance(var, dict):
            return json.dumps(dict(var))
        return var

    @property
    def dump(self) -> "MrsDbExec":
        print(f"sql: {self._sql}\nparams: {self._params}")
        return self

    def exec(self, session, params=[]) -> "MrsDbExec":
        """Runs the statement on session with the stored plus given params.

        Failures are logged to the shell log and re-raised.
        """
        # Builds a new list, so neither the stored nor the given list is
        # modified in place
        self._params = self._params + params
        try:
            # convert lists and dicts to store in the database
            self._params = [self._convert_to_database(param) for param in self._params]

            self._result = session.run_sql(self._sql, self._params)
        except Exception as e:
            mysqlsh.globals.shell.log(
                LogLevel.WARNING.name, f"[{e}\nsql: {self._sql}\nparams: {self._params}"
            )
            raise
        return self

    def __str__(self):
        return self._sql

    @property
    def items(self):
        return get_sql_result_as_dict_list(self._result, self._binary_formatter)

    @property
    def first(self):
        result = get_sql_result_as_dict_list(self._result, self._binary_formatter)
        if not result:
            return None
        return result[0]

    @property
    def success(self):
        return self._result.get_affected_items_count() > 0

    @property
    def id(self):
        return self._result.auto_increment_value

    @property
    def affected_count(self):
        return self._result.get_affected_items_count()


def select(
    table: str, cols=["*"], where=[], order=None, binary_formatter=None
) -> MrsDbExec:
    """Builds a SELECT statement; cols/order may be a str or a list."""
    if not isinstance(cols, str):
        cols = ",".join(cols)
    if order is not None and not isinstance(order, str):
        order = ",".join(order)

    sql = (
        f" SELECT {cols}"
        f" FROM {_generate_table(table)}"
        f" {_generate_where(where)}"
    )
    if order:
        sql = f"{sql} ORDER BY {order}"

    return MrsDbExec(sql, binary_formatter=binary_formatter)


def update(table: str, sets, where=[]) -> MrsDbExec:
    """Builds an UPDATE statement.

    sets may be a SQL string, a list of assignments, or a dict whose keys
    become "key=?" assignments and whose values become the parameters.
    """
    params = []
    if isinstance(sets, list):
        sets = ",".join(sets)
    elif isinstance(sets, dict):
        params = [value for value in sets.values()]
        sets = ",".join([f"{key}=?" for key in sets.keys()])

    sql = (
        f" UPDATE {_generate_table(table)}"
        f" SET {sets}"
        f" {_generate_where(where)}"
    )

    return MrsDbExec(sql, params)


def delete(table: str, where=[]) -> MrsDbExec:
    """Builds a DELETE statement."""
    sql = f" DELETE FROM {_generate_table(table)} {_generate_where(where)}"

    return MrsDbExec(sql)


def insert(table, values={}):
    """Builds an INSERT statement.

    values may be a list of column names (placeholders only) or a dict of
    column name to value (values become the parameters).
    """
    params = []
    place_holders = []
    cols = []
    if isinstance(values, list):
        cols = ",".join(values)
        place_holders = ",".join(["?" for val in values])
    elif isinstance(values, dict):
        cols = ",".join([str(col) for col in values.keys()])
        place_holders = ",".join(["?" for val in values.values()])
        params = [val for val in values.values()]

    sql = (
        f" INSERT INTO {_generate_table(table)}"
        f" ({cols})"
        f" VALUES ({place_holders}) "
    )

    return MrsDbExec(sql, params)


def get_sequence_id(session):
    """Returns the value of the metadata schema's get_sequence_id()."""
    return (
        MrsDbExec(f"SELECT {_generate_qualified_name('get_sequence_id()')} as id")
        .exec(session)
        .first["id"]
    )


class MrsDbSession:
    """Context manager yielding the database session to use.

    Keyword Args:
        session (object): The session to use instead of the current one
        exception_handler (callback): called with the exception info on exit
            when in interactive mode
        check_version (bool): whether to check the metadata schema version
            (default True)
    """

    def __init__(self, **kwargs) -> None:
        self._session = get_current_session(kwargs.get("session"))
        self._exception_handler = kwargs.get("exception_handler")
        check_version = kwargs.get("check_version", True)

        if check_version and mrs_metadata_schema_exists(self._session):
            current_db_version = get_mrs_schema_version(self._session)
            if current_db_version[0] < 2:
                raise Exception(
                    "This MySQL Shell version requires a new major version of the MRS metadata schema, "
                    f"{general.DB_VERSION_STR}. The currently deployed schema version is "
                    f"{'%d.%d.%d' % tuple(current_db_version)}. Please downgrade the MySQL Shell version "
                    "or drop the MRS metadata schema and run `mrs.configure()`."
                )

    def __enter__(self):
        return self._session

    def __exit__(self, exc_type, exc_value, exc_traceback):
        if exc_type is None:
            return

        if get_interactive_default() and self._exception_handler:
            return self._exception_handler(exc_type, exc_value, exc_traceback)
        # Not handled here: let the exception propagate
        return False

    @property
    def session(self):
        return self._session


class MrsDbTransaction:
    """Context manager that starts a transaction, commits on a clean exit
    and rolls back (without suppressing the exception) otherwise."""

    def __init__(self, session) -> None:
        self._session = session

    def __enter__(self) -> "MrsDbTransaction":
        self._session.run_sql("START TRANSACTION")
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback) -> bool:
        if exc_type is None:
            self._session.run_sql("COMMIT")
            return

        self._session.run_sql("ROLLBACK")
        return False


def create_identification_conditions(id, name, id_context, name_col):
    """
    Creates the necessary SQL WHERE conditions to identify an MRS object
    based given id and identification string.

    Returns a dict mapping column name to value; id_context is not used.
    """
    conditions = {}
    if id is not None:
        conditions["id"] = id

    if name is not None:
        conditions[name_col] = name

    return conditions


def identify_target_object(
    session, service_conditions, schema_conditions, object_conditions
):
    """
    Uses the given identification conditions for service, schema and object
    to uniquely identify a specific object, either service, schema or object.

    The function throws an error if either:
    - The conditions identify no object.
    - The conditions identify more than one object.

    Returns the type of identified object and its id.
    """
    tables = []
    target_object = ""
    id_field = ""
    conditions = []
    params = []

    # service table is included in the query either when service conditions are given
    # or when no conditions are given
    if service_conditions or (not schema_conditions and not object_conditions):
        tables.append("mysql_rest_service_metadata.service se")
        target_object = "service"
        id_field = "se.id"

    # schema table is included in the query whenever schema or object conditions are
    # given
    if schema_conditions or object_conditions:
        tables.append("mysql_rest_service_metadata.db_schema sc")
        if service_conditions:
            conditions.append("sc.service_id = se.id")
        target_object = "schema"
        id_field = "sc.id"

    # object table is included on the query when object conditions are given
    if object_conditions:
        tables.append("mysql_rest_service_metadata.db_object ob")
        if service_conditions or schema_conditions:
            conditions.append("ob.db_schema_id = sc.id")
        target_object = "object"
        id_field = "ob.id"

    for alias, given in (
        ("se", service_conditions),
        ("sc", schema_conditions),
        ("ob", object_conditions),
    ):
        if given:
            for column, value in given.items():
                conditions.append(f"{alias}.{column}=?")
                params.append(value)

    where = ""
    if conditions:
        cond_string = " AND ".join(conditions)
        where = f"WHERE {cond_string}"

    joined_tables = " INNER JOIN ".join(tables)
    # LIMIT 2 is enough to tell "exactly one" from "more than one"
    sql = f"SELECT {id_field} FROM {joined_tables} {where} LIMIT 2"

    result = session.run_sql(sql, params)
    rows = result.fetch_all()
    if len(rows) != 1:
        raise RuntimeError(
            f"Unable to identify a unique {target_object} for the operation."
        )

    return target_object, rows[0][0]


def get_session_uri(session):
    """Returns the session URI with any '?...' part removed."""
    if "shell.Object" in str(type(session)):
        uri = session.get_uri()
    else:
        uri = session.session.get_uri()
    uri = uri.split("?")[0]
    return uri


def uppercase_first_char(s):
    if len(s) > 0:
        return s[0].upper() + s[1:]
    return ""


def convert_path_to_camel_case(
    path: str, allowed_special_characters: Optional[set[str]] = None
):
    """Converts a path like "/my_path/sub" to camel case ("myPathSub").

    '/' and '_' act as word separators; characters that are neither
    alphanumeric nor in allowed_special_characters are dropped.
    """
    if not allowed_special_characters:
        allowed_special_characters = set()
    if path.startswith("/"):
        path = path[1:]
    parts = path.replace("/", "_").split("_")
    s = parts[0] + "".join(uppercase_first_char(x) for x in parts[1:])

    # Only return alphanumeric characters or those in the allow list
    return "".join(e for e in s if e.isalnum() or e in allowed_special_characters)


def convert_path_to_pascal_case(
    path: str, allowed_special_characters: Optional[set[str]] = None
):
    return uppercase_first_char(
        convert_path_to_camel_case(path, allowed_special_characters)
    )


def convert_snake_to_camel_case(snake_str):
    """Converts "snake_case" to "snakeCase"; input without any word
    characters (e.g. "" or "_") yields ""."""
    camel = "".join(x.capitalize() for x in snake_str.lower().split("_"))
    if not camel:
        return ""
    return camel[0].lower() + camel[1:]


def convert_to_snake_case(str):
    # Insert '_' before every upper-case letter that is not at the start
    return re.sub(r"(?<!^)(?=[A-Z])", "_", str).lower()


def unquote(name):
    # TODO- remove this, it doesn't work
    if name.startswith("`"):
        return name.strip("`")
    elif name.startswith('"'):
        return name.strip('"')

    return name


def escape_str(s):
    return s.replace("\\", "\\\\").replace('"', '\\"').replace("'", "\\'")


def quote_str(s):
    return '"' + escape_str(s) + '"'


def unescape_str(s):
    return s.replace("\\'", "'").replace('\\"', '"').replace("\\\\", "\\")


def unquote_str(s):
    if (s.startswith("'") and s.endswith("'")) or (
        s.startswith('"') and s.endswith('"')
    ):
        return unescape_str(s[1:-1])
    return s


def quote_ident(s):
    return mysqlsh.mysql.quote_identifier(s)


def unquote_ident(s):
    return mysqlsh.mysql.unquote_identifier(s)


def squote_str(s):
    return "'" + escape_str(s) + "'"


path_re = re.compile("^(/[a-zA-Z_0-9]*?)+?$")

quote_text = squote_str

quote_user = quote_ident

quote_auth_app = quote_ident

quote_role = quote_ident

# full_service_path
quote_fsp = lambda s: s  # TODO review


# request_path
def quote_rpath(s):
    """Returns s unquoted if it is a plain path matching path_re, otherwise
    quoted as an identifier."""
    if not s or "*" in s or "?" in s or s[0] != "/":
        return quote_ident(s)
    if path_re.match(s):
        return s
    return quote_ident(s)


def escape_wildcards(text: str) -> str:
    "escape * and ? wildcards with \\"
    return text.replace("\\", "\\\\").replace("*", "\\*").replace("?", "\\?")


def unescape_wildcards(text: str) -> str:
    return text.replace("\\*", "*").replace("\\?", "?").replace("\\\\", "\\")


def contains_wildcards(text: str) -> bool:
    """Returns True if text contains an unescaped '*' or '?'."""
    stripped = text.replace("\\\\", "").replace("\\?", "").replace("\\*", "")
    return "?" in stripped or "*" in stripped


def get_enabled_status_caption(enabledState):
    if enabledState == 2:
        return "PRIVATE"
    if enabledState == 1 or enabledState is True:
        return "ENABLED"
    return "DISABLED"


def format_result(result):
    """Formats a list of dicts as a text grid.

    The columns are taken from the keys of the first row. Field values
    containing line breaks extend the grid over several lines.

    Returns:
        The grid as str, or an empty string if there are no rows
    """
    if len(result) > 0:
        columns = list(result[0].keys())

        # Get max_col_lengths, starting from the column name lengths
        max_lengths = {col: len(col) for col in columns}

        # Loop over all rows and check if a field length is bigger
        for row in result:
            for col in columns:
                field = str(row.get(col))
                # If the field contains linebreaks, consider the longest line
                length = max(len(ln) for ln in field.split("\n"))
                if length > max_lengths.get(col, 0):
                    max_lengths[col] = length

        h_sep = "+"
        for col in columns:
            h_sep += "-" * (max_lengths[col] + 2) + "+"

        formatted_res = h_sep + "\n" + "|"
        for col in columns:
            formatted_res += f' {col}{" " * (max_lengths[col] - len(col))} |'
        formatted_res += "\n" + h_sep + "\n"

        for row in result:
            formatted_res += "|"
            for index, col in enumerate(columns):
                f = str(row.get(col))
                # If there are linebreaks in the field, add each line and extend the grid
                if "\n" in f:
                    # pre/post are the empty cells left/right of this column
                    # that pad the continuation lines
                    pre = "| "
                    post = " "
                    for i, c in enumerate(columns):
                        if i < index:
                            pre += f'{" " * max_lengths[c]} | '
                        elif i > index:
                            post += f'{" " * max_lengths[c]} | '
                    pre = pre[:-1]
                    post = post[:-1]
                    lines = f.split("\n")
                    for ln_i, ln in enumerate(lines):
                        if ln_i > 0:
                            formatted_res += pre
                        formatted_res += f' {ln}{" " * (max_lengths[col] - len(ln))} |'
                        if ln_i < len(lines) - 1:
                            formatted_res += post + "\n"
                else:
                    formatted_res += f' {f}{" " * (max_lengths[col] - len(f))} |'
            formatted_res += "\n"

        return formatted_res + h_sep

    # To make handling easier, return an empty string if there are no rows so
    # the result does not have to be checked for None
    return ""


def is_text(data: bytes) -> bool:
    """Heuristically decides whether data is text.

    str input is encoded first. Data containing a null byte, or where at
    least 30% of the bytes are outside printable ASCII plus \\n, \\r, \\t
    and \\b, is considered binary.
    """
    if isinstance(data, str):
        data = data.encode()

    valid_text__chars = "".join(list(map(chr, range(32, 127))) + list("\n\r\t\b"))

    data_without_text = data.translate(None, valid_text__chars.encode())

    # If there's a null character, then it's not a text string
    if 0 in data_without_text:
        return False

    # Check how many bytes are available after removing the ones that
    # are considered as text.
    if len(data_without_text) >= len(data) * 0.3:
        # if more then 30% if the characters are binary, then
        # take the data as binary
        return False

    return True


def is_number(s):
    try:
        float(s)
    except ValueError:
        return False
    return True


class _NotSet:
    # used to differentiate None (NULL) vs argument not set
    def __bool__(self):
        return False


NotSet = _NotSet()