unittest/scripts/setup_py/setup.py (617 lines of code) (raw):

from __future__ import print_function

import difflib
import hashlib
import inspect
import os.path
import queue
import random
import re
import string
import sys
import threading
import time

import mysqlsh

# NOTE(review): __os_type, __bin_dir, __test_context, __libmysql_version_id,
# testutil, shell and mysql are used below but never defined in this script;
# presumably they are injected by the environment that executes it - confirm.

mysqlshrec = "mysqlshrec"
if __os_type == "windows":
    mysqlshrec = mysqlshrec + ".exe"
mysqlshrec = os.path.join(__bin_dir, mysqlshrec)

k_cmdline_password_insecure_msg = "Using a password on the command line interface can be insecure."


def get_members(object):
    """Return the names from dir(object) that do not start with '__'."""
    return [member for member in dir(object) if not member.startswith('__')]


def defined(cb):
    """Verify if a variable is defined, returning True or False accordingly.

    cb is an anonymous function that simply evaluates the variable to be
    verified, for example:

        defined(lambda: myVar)

    returns True if myVar is defined or False if not (any exception raised
    by cb is treated as "not defined").
    """
    try:
        cb()
        return True
    except Exception:
        return False


def create_root_from_anywhere(session):
    """Create root@'%' (password 'root') with all privileges.

    Binary logging is switched off for the session while the account is
    created and switched back on afterwards.
    """
    session.run_sql("SET SQL_LOG_BIN=0")
    session.run_sql("CREATE USER root@'%' IDENTIFIED BY 'root'")
    session.run_sql("GRANT ALL ON *.* to root@'%' WITH GRANT OPTION")
    session.run_sql("SET SQL_LOG_BIN=1")


def has_ssh_environment():
    """Return True when all SSH related environment variables are set."""
    required = ("SSH_URI", "MYSQL_OVER_SSH_URI", "SSH_USER_URI")
    return all(name in os.environ for name in required)


def has_oci_environment(context):
    """Return True when the global variables needed for `context` exist.

    context must be 'OS' or 'MDS' (anything else returns False). The names
    of any missing variables are written to stderr.
    """
    if context not in ['OS', 'MDS']:
        return False
    variables = ['OCI_CONFIG_HOME', 'OCI_COMPARTMENT_ID',
                 'OS_NAMESPACE', 'OS_BUCKET_NAME']
    if context == 'MDS':
        variables = variables + ['MDS_URI']
    g = globals()
    missing = [variable for variable in variables if variable not in g]
    if missing:
        sys.stderr.write("Missing Variables: {}".format(", ".join(missing)))
        return False
    return True


def has_aws_environment():
    """Return True when the global variables needed for AWS tests exist.

    The names of any missing variables are written to stderr.
    """
    variables = ['MYSQLSH_S3_BUCKET_NAME']
    g = globals()
    missing = [variable for variable in variables if variable not in g]
    if missing:
        sys.stderr.write(
            "Missing AWS Variables: {}".format(", ".join(missing)))
        return False
    return True


def is_re_instance(o):
    """Return True if o is a compiled regular expression."""
    return isinstance(o, is_re_instance.__re_type)


# Type of a compiled pattern, computed once and cached on the function.
is_re_instance.__re_type = type(re.compile(''))


class __TextMatcher:
    """Matches text against either a plain substring or a compiled regex."""

    def __init__(self, o):
        """Store the matcher; raises Exception if o is not str or a pattern."""
        self.__is_re = is_re_instance(o)
        if not self.__is_re and not isinstance(o, str):
            raise Exception(
                "Expected str or re.Pattern, but got: " + str(type(o)))
        self.__o = o

    def __str__(self):
        if self.__is_re:
            return "re.Pattern('" + self.__o.pattern + "')"
        return self.__o

    def matches(self, s):
        """Return True if s contains the substring / the pattern is found."""
        if self.__is_re:
            return self.__o.search(s) is not None
        return s.find(self.__o) != -1


def __caller_context():
    """Return a string describing the frame of the test code.

    Must be called directly from an EXPECT_ function, because the frame is
    selected by its fixed position in the stack.
    """
    # 0 is here, 1 is the EXPECT_ call, 2 is the test code calling EXPECT_
    frame = inspect.stack()[2]
    return f"{frame}"


def EXPECT_EQ(expected, actual, note=""):
    """Fail the test if expected != actual."""
    if expected != actual:
        if not note:
            note = __caller_context()
        context = "Tested values don't match as expected: " + note + \
            "\n\tActual: " + str(actual) + "\n\tExpected: " + str(expected)
        testutil.fail(context)


def EXPECT_NE(expected, actual, note=""):
    """Fail the test if expected == actual."""
    if expected == actual:
        if not note:
            note = __caller_context()
        context = "Tested values should not match: " + note + \
            "\n\tActual: " + str(actual) + "\n\tExpected: " + str(expected)
        testutil.fail(context)


def EXPECT_EQ_TEXT(expected, actual, note=""):
    """Fail the test, reporting a context diff, if two line lists differ."""
    assert isinstance(expected, list)
    assert isinstance(actual, list)
    if expected != actual:
        if not note:
            note = __caller_context()
        diff = difflib.context_diff(
            expected, actual, fromfile="Expected", tofile="Actual")
        context = "Tested values don't match as expected: " + note + \
            "\n".join(diff)
        testutil.fail(context)


def EXPECT_IN(actual, expected_values, note=""):
    """Fail the test if actual is not one of expected_values."""
    if actual not in expected_values:
        if not note:
            note = __caller_context()
        context = "Tested value not one of expected: " + note + "\n\tActual: " + \
            str(actual) + "\n\tExpected: " + str(expected_values)
        testutil.fail(context)


def EXPECT_CONTAINS(expected, actual, note=""):
    """Fail the test if any item of expected is not in actual."""
    missing = [e for e in expected if e not in actual]
    if missing:
        if not note:
            note = __caller_context()
        context = "Tested values not as expected: " + note + "\n\tActual: " + \
            ", ".join(actual) + "\n\tExpected: " + ", ".join(expected) + \
            "\n\tMissing: " + ", ".join(missing)
        testutil.fail(context)


def EXPECT_CONTAINS_LIKE(expected, actual, note=""):
    """Fail unless every pattern in expected re.match()es an item of actual.

    Returns True on success and False after reporting a failure.
    """
    missing = [e for e in expected
               if not any(re.match(e, a) for a in actual)]
    if missing:
        if not note:
            note = __caller_context()
        context = "Tested values not as expected: " + note + "\n\tActual: " + \
            ", ".join(actual) + "\n\tExpected: " + ", ".join(expected) + \
            "\n\tMissing: " + ", ".join(missing)
        testutil.fail(context)
        return False
    return True


def EXPECT_NOT_CONTAINS(expected, actual, note=""):
    """Fail the test if any item of expected is present in actual."""
    present = [e for e in expected if e in actual]
    if present:
        if not note:
            note = __caller_context()
        context = "Tested values not as expected: " + note + "\n\tActual: " + \
            ", ".join(actual) + "\n\tNot expected: " + ", ".join(expected) + \
            "\n\tPresent: " + ", ".join(present)
        testutil.fail(context)


def EXPECT_LE(expected, actual, note=""):
    """Fail the test unless expected <= actual."""
    if expected > actual:
        if not note:
            note = __caller_context()
        context = "Tested values not as expected: " + note + "\n\t" + \
            str(expected) + " (expected) <= " + str(actual) + " (actual)"
        testutil.fail(context)


def EXPECT_LT(expected, actual, note=""):
    """Fail the test unless expected < actual."""
    if expected >= actual:
        if not note:
            note = __caller_context()
        context = "Tested values not as expected: " + note + "\n\t" + \
            str(expected) + " (expected) < " + str(actual) + " (actual)"
        testutil.fail(context)


def EXPECT_GE(expected, actual, note=""):
    """Fail the test unless expected >= actual."""
    if expected < actual:
        if not note:
            note = __caller_context()
        context = "Tested values not as expected: " + note + "\n\t" + \
            str(expected) + " (expected) >= " + str(actual) + " (actual)"
        testutil.fail(context)


def EXPECT_GT(expected, actual, note=""):
    """Fail the test unless expected > actual."""
    if expected <= actual:
        if not note:
            note = __caller_context()
        context = "Tested values not as expected: " + note + "\n\t" + \
            str(expected) + " (expected) > " + str(actual) + " (actual)"
        testutil.fail(context)


def EXPECT_BETWEEN(expected_from, expected_to, actual, note=""):
    """Fail the test unless expected_from <= actual <= expected_to."""
    if not (expected_from <= actual <= expected_to):
        if not note:
            note = __caller_context()
        context = "Tested value not as expected: " + note + \
            f"\n\t{expected_from} <= {actual} <= {expected_to}"
        testutil.fail(context)


def EXPECT_NOT_BETWEEN(expected_from, expected_to, actual, note=""):
    """Fail the test if expected_from <= actual <= expected_to."""
    if expected_from <= actual <= expected_to:
        if not note:
            note = __caller_context()
        context = "Tested value not as expected: " + note + \
            f"\n\tNOT ({expected_from} <= {actual} <= {expected_to})"
        testutil.fail(context)


def EXPECT_DELTA(expected, allowed_delta, actual, note=""):
    """Fail the test unless actual is within allowed_delta of expected."""
    # The context is captured here (not in EXPECT_BETWEEN) so that it points
    # at the code calling EXPECT_DELTA.
    if not note:
        note = __caller_context()
    EXPECT_BETWEEN(expected - allowed_delta,
                   expected + allowed_delta, actual, note)


def EXPECT_TRUE(value, note=""):
    """Fail the test if value is falsy; returns whether the check passed."""
    if not value:
        if not note:
            note = __caller_context()
        context = f"Tested value '{value}' expected to be true but is false"
        if note:
            context += ": " + note
        testutil.fail(context)
        return False
    return True


def EXPECT_FALSE(value, note=""):
    """Fail the test if value is truthy."""
    if value:
        if not note:
            note = __caller_context()
        context = f"Tested value '{value}' expected to be false but is true"
        if note:
            context += ": " + note
        testutil.fail(context)


def EXPECT_THROWS(func, etext, note=""):
    """Fail unless func() raises an exception matching etext.

    etext is a substring or a compiled regex, matched against
    "<ExceptionTypeName>: <message>". Returns True when the expected
    exception was raised, False otherwise.
    """
    if note:
        note = note + ": "
    assert callable(func)
    m = __TextMatcher(etext)
    try:
        func()
        testutil.fail(
            note + "<red>Missing expected exception throw like " + str(m) + "</red>")
        return False
    except Exception as e:
        exception_message = type(e).__name__ + ": " + str(e)
        if not m.matches(exception_message):
            testutil.fail(note + "<red>Exception expected:</red> " + str(m) +
                          "\n\t<yellow>Actual:</yellow> " + exception_message)
            return False
    return True


def EXPECT_MAY_THROW(func, etext):
    """Call func(); if it raises, the exception must match etext.

    Returns the result of func(), or None if it raised.
    """
    assert callable(func)
    m = __TextMatcher(etext)
    ret = None
    try:
        ret = func()
    except Exception as e:
        exception_message = type(e).__name__ + ": " + str(e)
        if not m.matches(exception_message):
            testutil.fail("<red>Exception expected:</red> " + str(m) +
                          "\n\t<yellow>Actual:</yellow> " + exception_message)
    return ret


def EXPECT_NO_THROWS(func, context=""):
    """Return func(), failing the test if it raises an exception."""
    assert callable(func)
    try:
        return func()
    except Exception as e:
        testutil.fail("<b>Context:</b> " + __test_context +
                      "\n<red>Unexpected exception thrown (" + context + "): " +
                      str(e) + "</red>")


def EXPECT_STDOUT_CONTAINS(text, note=None):
    """Fail the test if the captured stdout does not contain text.

    note is accepted but is not part of the failure message.
    """
    out = testutil.fetch_captured_stdout(False)
    err = testutil.fetch_captured_stderr(False)
    if out.find(text) == -1:
        context = "<b>Context:</b> " + __test_context + "\n<red>Missing output:</red> " + text + \
            "\n<yellow>Actual stdout:</yellow> " + out + \
            "\n<yellow>Actual stderr:</yellow> " + err
        testutil.fail(context)


def EXPECT_STDERR_CONTAINS(text, note=None):
    """Fail the test if the captured stderr does not contain text.

    note is accepted but is not part of the failure message.
    """
    out = testutil.fetch_captured_stdout(False)
    err = testutil.fetch_captured_stderr(False)
    if err.find(text) == -1:
        context = "<b>Context:</b> " + __test_context + "\n<red>Missing output:</red> " + text + \
            "\n<yellow>Actual stdout:</yellow> " + out + \
            "\n<yellow>Actual stderr:</yellow> " + err
        testutil.fail(context)


def WIPE_STDOUT():
    """Fetch captured stdout repeatedly until an empty string is returned."""
    while testutil.fetch_captured_stdout(True) != "":
        pass


def WIPE_STDERR():
    """Fetch captured stderr repeatedly until an empty string is returned."""
    while testutil.fetch_captured_stderr(True) != "":
        pass


def WIPE_OUTPUT():
    """Wipe both the captured stdout and the captured stderr."""
    WIPE_STDOUT()
    WIPE_STDERR()


def WIPE_SHELL_LOG():
    """Wipe the contents of the shell log file."""
    testutil.wipe_file_contents(testutil.get_shell_log_path())


def EXPECT_SHELL_LOG_CONTAINS(text, note=None):
    """Fail the test if grepping the shell log for text finds nothing.

    note is accepted but is not part of the failure message.
    """
    log_file = testutil.get_shell_log_path()
    match_list = testutil.grep_file(log_file, text)
    if len(match_list) == 0:
        log_out = testutil.cat_file(log_file)
        testutil.fail(
            f"<b>Context:</b> {__test_context}\n<red>Missing log output:</red> {text}\n<yellow>Actual log output:</yellow> {log_out}")


def EXPECT_SHELL_LOG_CONTAINS_COUNT(text, count):
    """Fail unless grepping the shell log for text finds `count` matches.

    Raises TypeError if count is not an int.
    """
    if not isinstance(count, int):
        raise TypeError('"count" argument must be a number.')
    log_file = testutil.get_shell_log_path()
    match_list = testutil.grep_file(log_file, text)
    if len(match_list) != count:
        log_out = testutil.cat_file(log_file)
        testutil.fail(
            f"<b>Context:</b> {__test_context}\n<red>Missing log output:</red> {text}\n<yellow>Actual log output:</yellow> {log_out}")


def EXPECT_SHELL_LOG_NOT_CONTAINS(text, note=None):
    """Fail the test if grepping the shell log for text finds a match.

    note is accepted but is not part of the failure message.
    """
    log_file = testutil.get_shell_log_path()
    match_list = testutil.grep_file(log_file, text)
    if len(match_list) != 0:
        log_out = testutil.cat_file(log_file)
        testutil.fail(
            f"<b>Context:</b> {__test_context}\n<red>Unexpected log output:</red> {text}\n<yellow>Actual log output:</yellow> {log_out}")


def EXPECT_SHELL_LOG_MATCHES(re, note=None):
    """Fail the test if the compiled pattern `re` is not found in the shell log.

    note is accepted but is not part of the failure message.
    """
    log_file = testutil.get_shell_log_path()
    with open(log_file, "r", encoding="utf-8") as f:
        log_out = f.read()
    if re.search(log_out) is None:
        testutil.fail(
            f"<b>Context:</b> {__test_context}\n<red>Missing match for:</red> {re.pattern}\n<yellow>Actual log output:</yellow> {log_out}")


def EXPECT_STDOUT_MATCHES(re, note=None):
    """Fail the test if the compiled pattern `re` is not found in stdout.

    note is accepted but is not part of the failure message.
    """
    out = testutil.fetch_captured_stdout(False)
    err = testutil.fetch_captured_stderr(False)
    if re.search(out) is None:
        context = "<b>Context:</b> " + __test_context + "\n<red>Missing match for:</red> " + re.pattern + \
            "\n<yellow>Actual stdout:</yellow> " + out + \
            "\n<yellow>Actual stderr:</yellow> " + err
        testutil.fail(context)


def EXPECT_STDOUT_NOT_CONTAINS(text, note=None):
    """Fail the test if the captured stdout contains text.

    note is accepted but is not part of the failure message.
    """
    out = testutil.fetch_captured_stdout(False)
    err = testutil.fetch_captured_stderr(False)
    if out.find(text) != -1:
        context = "<b>Context:</b> " + __test_context + "\n<red>Unexpected output:</red> " + text + \
            "\n<yellow>Actual stdout:</yellow> " + out + \
            "\n<yellow>Actual stderr:</yellow> " + err
        testutil.fail(context)


def EXPECT_FILE_CONTAINS(expected, path, note=None):
    """Fail the test if the UTF-8 file at path does not contain expected.

    note is accepted but is not part of the failure message.
    """
    with open(path, encoding='utf-8') as f:
        contents = f.read()
    if contents.find(expected) == -1:
        context = "<b>Context:</b> " + __test_context + "\n<red>Missing contents:</red> " + expected + \
            "\n<yellow>Actual contents:</yellow> " + \
            contents + "\n<yellow>File:</yellow> " + path
        testutil.fail(context)


def EXPECT_FILE_MATCHES(re, path, note=None):
    """Fail the test if the compiled pattern `re` is not found in the file.

    note is accepted but is not part of the failure message.
    """
    with open(path, encoding='utf-8') as f:
        contents = f.read()
    if re.search(contents) is None:
        testutil.fail(
            f"<b>Context:</b> {__test_context}\n<red>Missing match for:</red> {re.pattern}\n<yellow>Actual contents:</yellow> {contents}\n<yellow>File:</yellow> {path}")


def EXPECT_FILE_NOT_CONTAINS(expected, path, note=None):
    """Fail the test if the UTF-8 file at path contains expected.

    note is accepted but is not part of the failure message.
    """
    with open(path, encoding='utf-8') as f:
        contents = f.read()
    if contents.find(expected) != -1:
        context = "<b>Context:</b> " + __test_context + "\n<red>Unexpected contents:</red> " + \
            expected + "\n<yellow>Actual contents:</yellow> " + \
            contents + "\n<yellow>File:</yellow> " + path
        testutil.fail(context)


def EXPECT_FILE_NOT_MATCHES(re, path, note=None):
    """Fail the test if the compiled pattern `re` is found in the file.

    note is accepted but is not part of the failure message.
    """
    with open(path, encoding='utf-8') as f:
        contents = f.read()
    if re.search(contents) is not None:
        testutil.fail(
            f"<b>Context:</b> {__test_context}\n<red>Unexpected match for:</red> {re.pattern}\n<yellow>Actual contents:</yellow> {contents}\n<yellow>File:</yellow> {path}")


def validate_crud_functions(crud, expected):
    """Print which expected names are missing from / extra in dir(crud).

    A single leftover member called 'help' is not reported as extra.
    """
    actual = dir(crud)

    # Ensures expected functions are on the actual list
    missing = []
    for exp_funct in expected:
        if exp_funct in actual:
            actual.remove(exp_funct)
        else:
            missing.append(exp_funct)

    if len(missing) == 0:
        print("All expected functions are available\n")
    else:
        print("Missing Functions:", missing)

    if len(actual) == 0 or (len(actual) == 1 and actual[0] == 'help'):
        print("No additional functions are available\n")
    else:
        print("Extra Functions:", actual)


def validate_members(object, expected_members):
    """Fail the test if the non-dunder members of object differ from expected."""
    # Remove the python built in members
    members = [member for member in dir(object)
               if not member.startswith('__')]

    missing = []
    for expected in expected_members:
        if expected in members:
            members.remove(expected)
        else:
            missing.append(expected)

    errors = []
    if members:
        errors.append("Unexpected Members: %s" % ', '.join(members))
    if missing:
        errors.append("Missing Members: %s" % ', '.join(missing))

    if errors:
        testutil.fail(', '.join(errors))


def print_differences(source, target):
    """Print (through testutil.dprint) a context diff of two files."""
    with open(source) as f:
        src_lines = f.readlines()
    with open(target) as f:
        tgt_lines = f.readlines()
    for line in difflib.context_diff(src_lines, tgt_lines,
                                     fromfile=source, tofile=target):
        testutil.dprint(line)


def ensure_plugin_enabled(plugin_name, session, plugin_soname=None):
    """Install a server plugin, ignoring DBError with code 1125.

    The library name defaults to plugin_name; its extension is "dll" when
    the server reports a Win32/Win64 build and "so" otherwise. Any other
    mysqlsh.DBError is re-raised.
    """
    if plugin_soname is None:
        plugin_soname = plugin_name
    compile_os = session.run_sql(
        'select @@version_compile_os').fetch_one()[0]
    ext = "dll" if compile_os in ("Win32", "Win64") else "so"
    try:
        session.run_sql("INSTALL PLUGIN {0} SONAME '{1}.{2}';".format(
            plugin_name, plugin_soname, ext))
    except mysqlsh.DBError as e:
        # NOTE(review): 1125 is presumably what the server reports when the
        # plugin is already installed - confirm.
        if 1125 != e.code:
            raise


def ensure_plugin_disabled(plugin_name, session):
    """Uninstall a server plugin if INFORMATION_SCHEMA.PLUGINS lists it."""
    is_installed = session.run_sql(
        "SELECT COUNT(1) FROM INFORMATION_SCHEMA.PLUGINS WHERE PLUGIN_NAME LIKE '" +
        plugin_name + "';").fetch_one()[0]
    if is_installed:
        session.run_sql("UNINSTALL PLUGIN " + plugin_name + ";")


# Starting 8.0.24 the client lib started reporting connection error using
# host:port format, previous versions used just the host.
#
# This function is used to format the host description accordingly.
def libmysql_host_description(hostname, port):
    """Return the host description in the format used by the client library.

    Returns "hostname:port" when the libmysql version is newer than 8.0.23
    and just hostname otherwise.
    """
    if testutil.version_check(__libmysql_version_id, ">", "8.0.23"):
        return hostname + ":" + str(port)
    return hostname


def get_socket_path(session, uri=None):
    """Return the socket path of the server the session is connected to.

    @@socket is queried when the URI scheme is 'mysql', @@mysqlx_socket
    otherwise. A value which does not start with '/' is joined to @@datadir,
    except on Windows where it is used as is. The test is failed if the
    resulting path is longer than 100 characters.
    """
    if uri is None:
        uri = session.uri
    variable = 'socket' if 'mysql' == shell.parse_uri(
        uri).scheme else 'mysqlx_socket'
    row = session.run_sql(f"SELECT @@{variable}, @@datadir").fetch_one()
    # startswith() rather than row[0][0]: an empty value must not raise.
    if row[0].startswith('/') or __os_type == "windows":
        p = row[0]
    else:
        p = os.path.join(row[1], row[0])
    if len(p) > 100:
        testutil.fail("socket path is too long (>100): " + p)
    return p


def get_socket_uri(session, uri=None):
    """Return a URI which connects to the session's server through its socket.

    Scheme, user and password (empty if absent) are taken from uri, which
    defaults to session.uri.
    """
    if uri is None:
        uri = session.uri
    parsed = shell.parse_uri(uri)
    if "password" not in parsed:
        parsed["password"] = ""
    new_uri = {key: parsed[key] for key in ("scheme", "user", "password")}
    new_uri["socket"] = get_socket_path(session, uri)
    return shell.unparse_uri(new_uri)


def random_string(lower, upper=None):
    """Return a random string of ASCII letters and digits.

    The length is chosen randomly from [lower, upper]; upper defaults to
    lower, giving a string of exactly that length.
    """
    if upper is None:
        upper = lower
    alphabet = string.ascii_letters + string.digits
    return ''.join(random.choices(alphabet, k=random.randint(lower, upper)))


def random_email():
    """Return a random string shaped like an e-mail address."""
    return random_string(10, 40) + "@" + random_string(10, 40) + "." + random_string(3)


def md5sum(s):
    """Return the hexadecimal MD5 digest of the UTF-8 encoding of s."""
    return hashlib.md5(s.encode("utf-8")).hexdigest()


def reset_instance(session):
    """Stop replication and drop the non-system schemas and user accounts.

    Statements which may be unavailable on the server are executed on a
    best-effort basis. The root@localhost and root@% accounts are kept.
    """
    # These statements are optional: failures are deliberately ignored.
    try:
        session.run_sql("SELECT group_replication_reset_member_actions()")
    except Exception:
        pass
    try:
        session.run_sql("SELECT asynchronous_connection_failover_reset()")
    except Exception:
        pass
    session.run_sql("STOP SLAVE")
    try:
        session.run_sql("STOP group_replication")
        session.run_sql("SET PERSIST group_replication_start_on_boot=0")
    except Exception:
        pass
    session.run_sql("SET GLOBAL super_read_only=0")
    session.run_sql("SET GLOBAL read_only=0")
    session.run_sql("DROP SCHEMA IF EXISTS mysql_innodb_cluster_metadata")

    system_schemas = ("mysql", "performance_schema",
                      "sys", "information_schema")
    for row in session.run_sql("SHOW SCHEMAS").fetch_all():
        if row[0] in system_schemas:
            continue
        # Quote the identifier so that names which need quoting (reserved
        # words, special characters) can be dropped as well.
        session.run_sql("DROP SCHEMA `" + row[0].replace("`", "``") + "`")

    system_users = ("mysql.sys", "mysql.session", "mysql.infoschema")
    for row in session.run_sql("SELECT user,host FROM mysql.user").fetch_all():
        if row[0] in system_users:
            continue
        if row[0] == "root" and row[1] in ("localhost", "%"):
            continue
        session.run_sql("DROP USER ?@?", [row[0], row[1]])

    session.run_sql("RESET MASTER")
    session.run_sql("RESET SLAVE ALL")


def reset_multi(ports):
    """Stop the group on the given ports and reset each instance."""
    testutil.stop_group(ports)
    for p in ports:
        s = mysql.get_session(f"mysql://root:root@localhost:{p}")
        try:
            reset_instance(s)
        finally:
            # Do not leak the session if the reset fails.
            s.close()


class Docker_manipulator:
    """Builds and runs a container from <data_path>/<docker_name>/Dockerfile.

    The image is labelled with the SHA-256 of the Dockerfile, so an existing
    image is only reused while the Dockerfile is unchanged.
    """

    def __init__(self, docker_name, data_path):
        self.data_path = data_path
        self.dockerfile = os.path.join(
            self.data_path, docker_name, "Dockerfile")
        self.docker_name = docker_name
        self.docker_image_name = "mysql-sshd-shell-{}".format(docker_name)
        if not os.path.exists(self.dockerfile):
            testutil.fail(
                "The Dockerfile is missing for {}".format(docker_name))
        self.hash = self.__get_hash()
        # Imported lazily, so the docker package is only needed when this
        # class is actually used.
        import docker
        self.client = docker.from_env()
        self.container = None
        # Image to run; None until an up to date image is found or built.
        self.img = None

    def __get_hash(self):
        """Return the SHA-256 hex digest of the Dockerfile."""
        import hashlib
        with open(self.dockerfile, "rb") as f:
            return hashlib.sha256(f.read()).hexdigest()

    def __pre_run_check(self):
        """Remove a leftover container and an outdated image, if any."""
        # before running it check if there's already such container,
        # if so, kill it and start again
        try:
            print("About to find old container")
            c = self.client.containers.get(self.docker_name)
            if c.status == "running":
                c.kill()
            # NOTE(review): whether prune() belongs inside the `if` above
            # could not be determined from the source layout - confirm.
            self.client.containers.prune()
            print("Old container found and removed")
        except Exception:
            # this should be fine as containers.get will throw when
            # container is not found
            pass

        # Only keep a reference to an image whose hash label is current.
        # Keeping a reference to a removed image would skip the rebuild and
        # then fail to run a container from it.
        self.img = None
        try:
            print("Checking image tag")
            image = self.client.images.get(self.docker_image_name)
            # NOTE(review): falls back to "Config" in case the engine does
            # not report "ContainerConfig" - confirm against the engine used.
            attrs = image.attrs
            config = attrs.get("ContainerConfig") or attrs.get("Config") or {}
            labels = config.get("Labels") or {}
            if labels.get("hash") != self.hash:
                self.client.images.remove(image=image.id, force=True)
                print("Old image found which is outdated, removed")
            else:
                self.img = image
        except Exception:
            # Best effort: without a usable image a new one is built.
            pass

    def __build_and_run(self):
        """Build the image unless one is available, then start the container."""
        if getattr(self, "img", None) is None:
            self.img = self.client.images.build(
                path=os.path.join(self.data_path, self.docker_name),
                tag=self.docker_image_name,
                labels={"hash": self.hash})[0]
            print("Build new docker image")
        self.container = self.client.containers.run(
            image=self.img.id, detach=True, name=self.docker_name,
            environment=["MYSQL_ROOT_PASSWORD=sandbox"],
            ports={"2222/tcp": 2222})
        print("Started new container")

    def log_watcher(self, log_fetcher):
        """Move items from the log_fetcher iterator into self._queue.

        Stops when the iterator is exhausted or fails, or when the queue is
        full.
        """
        while True:
            try:
                # next() works with any iterator; a .next() method is not
                # available on plain Python 3 iterators.
                self._queue.put_nowait(next(log_fetcher))
            except Exception:
                break

    def run(self):
        """Start the container and wait until its log reports the server port.

        Waiting ends when the port line is seen, or when no log line arrives
        within 10 seconds.
        """
        self._queue = queue.Queue(10)
        self.__pre_run_check()
        self.__build_and_run()
        log_iter = self.container.logs(stream=True)
        watcher = threading.Thread(target=self.log_watcher, args=(log_iter,))
        watcher.start()

        while self.client.containers.get(self.container.id).status != "running":
            time.sleep(1)

        while True:
            try:
                line = self._queue.get(block=True, timeout=10)
                match = re.search(
                    br"port: ([0-9]{4})\s MySQL Community Server - GPL", line)
                if match:
                    # The group is bytes: convert it before formatting, as
                    # bytes do not support the 'g' format.
                    print("Found port: {:g}".format(int(match.group(1))))
                    break
            except Exception:
                break
            # NOTE(review): whether this message belongs inside or after the
            # loop could not be determined from the source layout - confirm.
            print("Waiting for container to fully start")

        log_iter.close()
        watcher.join()
        print("Container is ready")

    def cleanup(self, remove_image=False):
        """Kill the container, prune containers and optionally remove the image."""
        if self.container is not None:
            self.container.kill()  # kill the container
        # the python client.images.build leave a mess, clean it
        # NOTE(review): whether prune() belongs inside the `if` above could
        # not be determined from the source layout - confirm.
        self.client.containers.prune()
        if remove_image and getattr(self, "img", None) is not None:
            self.client.images.remove(image=self.img.id, force=True)