unittest/scripts/setup_py/setup.py (617 lines of code) (raw):
import queue
import time
import threading
from __future__ import print_function
import difflib
import hashlib
import mysqlsh
import os.path
import random
import re
import string
import sys
import inspect
# Full path of the "mysqlshrec" executable inside __bin_dir; the file name
# carries the ".exe" suffix when __os_type is "windows".
mysqlshrec = os.path.join(
    __bin_dir,
    "mysqlshrec.exe" if __os_type == "windows" else "mysqlshrec")

# Warning text about passing a password on the command line
# (presumably matched against captured output by tests - confirm with callers).
k_cmdline_password_insecure_msg = "Using a password on the command line interface can be insecure."
def get_members(object):
    """Return the names reported by dir(object) that do not start with '__'."""
    return [name for name in dir(object) if not name.startswith('__')]
##
# Verifies if a variable is defined, returning true or false accordingly
# @param cb An anonymous function that simply executes the variable to be
# verified, example:
#
# defined(lambda:myVar)
#
# Will return True if myVar is defined or False if not.
##
def defined(cb):
    """Call cb() and report whether it completed without raising.

    Returns True when cb() returns normally and False when it raises any
    Exception (for example the NameError of an undefined variable).
    KeyboardInterrupt and SystemExit are deliberately not swallowed.
    """
    try:
        cb()
    except Exception:
        return False
    return True
def create_root_from_anywhere(session):
    """Create the root@'%' account (password 'root') with all privileges.

    The statements are executed with SQL_LOG_BIN=0; SQL_LOG_BIN is set back
    to 1 even when creating the account fails, so that the session is not
    left with binary logging disabled. Any error is propagated to the caller.

    session: object exposing run_sql(sql).
    """
    session.run_sql("SET SQL_LOG_BIN=0")
    try:
        session.run_sql("CREATE USER root@'%' IDENTIFIED BY 'root'")
        session.run_sql("GRANT ALL ON *.* to root@'%' WITH GRANT OPTION")
    finally:
        session.run_sql("SET SQL_LOG_BIN=1")
def has_ssh_environment():
    """Return True only when SSH_URI, MYSQL_OVER_SSH_URI and SSH_USER_URI
    are all present in the process environment."""
    import os
    required = ("SSH_URI", "MYSQL_OVER_SSH_URI", "SSH_USER_URI")
    return all(name in os.environ for name in required)
def has_oci_environment(context):
    """Check that the global variables needed for an OCI test context exist.

    context: 'OS' or 'MDS'; any other value yields False. 'MDS' additionally
    requires the MDS_URI global.

    Returns True when every required name is in globals(); otherwise the
    missing names are written to stderr and False is returned.
    """
    if context not in ['OS', 'MDS']:
        return False

    required = ['OCI_CONFIG_HOME',
                'OCI_COMPARTMENT_ID',
                'OS_NAMESPACE',
                'OS_BUCKET_NAME']
    if context == 'MDS':
        required = required + ['MDS_URI']

    known = globals()
    missing = [name for name in required if name not in known]
    if not missing:
        return True

    sys.stderr.write("Missing Variables: {}".format(", ".join(missing)))
    return False
def has_aws_environment():
    """Return True when the MYSQLSH_S3_BUCKET_NAME global is defined.

    When it is not, the missing names are written to stderr and False is
    returned.
    """
    known = globals()
    missing = [name for name in ['MYSQLSH_S3_BUCKET_NAME']
               if name not in known]
    if not missing:
        return True

    sys.stderr.write("Missing AWS Variables: {}".format(", ".join(missing)))
    return False
def is_re_instance(o):
    """Return True when o is a compiled regular expression object."""
    pattern_type = is_re_instance.__re_type
    return isinstance(o, pattern_type)


# Type of compiled patterns, obtained once by compiling an empty expression
# and stored as an attribute of the function.
is_re_instance.__re_type = type(re.compile(''))
class __TextMatcher:
    """Matches text either against a plain substring or a compiled regex."""

    def __init__(self, o):
        """Store o, which must be a str or a compiled pattern.

        Raises Exception for any other type.
        """
        is_regex = is_re_instance(o)
        if not (is_regex or isinstance(o, str)):
            raise Exception(
                "Expected str or re.Pattern, but got: " + str(type(o)))
        self.__is_re = is_regex
        self.__o = o

    def __str__(self):
        """Return the plain string, or "re.Pattern('<pattern>')" for a regex."""
        if not self.__is_re:
            return self.__o
        return "re.Pattern('" + self.__o.pattern + "')"

    def matches(self, s):
        """Return True when s contains the substring / the regex is found in s."""
        if not self.__is_re:
            return s.find(self.__o) != -1
        return self.__o.search(s) is not None
def __caller_context():
    """Return the string form of the stack entry two levels above this call."""
    # Index 0 is this function, 1 is the EXPECT_ call and 2 is the test code
    # that invoked the EXPECT_ function.
    caller = inspect.stack()[2]
    return str(caller)
def EXPECT_EQ(expected, actual, note=""):
    """Report a failure through testutil.fail() when expected != actual.

    note: text included in the failure message; when empty the caller
    context is used instead.
    """
    if not expected != actual:
        return
    if not note:
        note = __caller_context()
    testutil.fail(
        f"Tested values don't match as expected: {note}"
        f"\n\tActual: {actual!s}\n\tExpected: {expected!s}")
def EXPECT_NE(expected, actual, note=""):
    """Report a failure through testutil.fail() when expected == actual.

    note: text included in the failure message; when empty the caller
    context is used instead.
    """
    if not expected == actual:
        return
    if not note:
        note = __caller_context()
    testutil.fail(
        f"Tested values should not match: {note}"
        f"\n\tActual: {actual!s}\n\tExpected: {expected!s}")
def EXPECT_EQ_TEXT(expected, actual, note=""):
    """Compare two lists of lines and fail with a context diff when they differ.

    Both arguments must be lists (checked with assert). The failure message
    passed to testutil.fail() holds the note (or the caller context) followed
    by the difflib context diff labelled "Expected" / "Actual".
    """
    assert isinstance(expected, list)
    assert isinstance(actual, list)
    if not expected != actual:
        return
    if not note:
        note = __caller_context()
    import difflib
    diff_lines = difflib.context_diff(
        expected, actual, fromfile="Expected", tofile="Actual")
    testutil.fail(
        "Tested values don't match as expected: " + note +
        "\n".join(diff_lines))
def EXPECT_IN(actual, expected_values, note=""):
    """Report a failure through testutil.fail() when actual is not in
    expected_values.

    note: text included in the failure message; when empty the caller
    context is used instead.
    """
    if not actual not in expected_values:
        return
    if not note:
        note = __caller_context()
    testutil.fail(
        f"Tested value not one of expected: {note}"
        f"\n\tActual: {actual!s}\n\tExpected: {expected_values!s}")
def EXPECT_CONTAINS(expected, actual, note=""):
    """Report a failure when any item of expected is not in actual.

    expected: iterable of items that must all be present.
    actual: container tested with the `in` operator.
    note: text included in the failure message; when empty the caller
    context is used instead.

    The failure is reported through testutil.fail().
    """
    missing = [e for e in expected if e not in actual]
    if not missing:
        return
    if not note:
        note = __caller_context()
    # Items are converted with str() so that non-string values (numbers,
    # tuples, ...) produce a readable failure instead of a TypeError from join.
    context = "Tested values not as expected: " + note + \
        "\n\tActual: " + ", ".join(map(str, actual)) + \
        "\n\tExpected: " + ", ".join(map(str, expected)) + \
        "\n\tMissing: " + ", ".join(map(str, missing))
    testutil.fail(context)
def EXPECT_CONTAINS_LIKE(expected, actual, note=""):
    """Check that every pattern in expected matches some item of actual.

    expected: iterable of regular expressions, applied with re.match() (so
    they are anchored at the start of each item).
    actual: iterable of strings.
    note: text included in the failure message; when empty the caller
    context is used instead.

    Returns True when all patterns matched; otherwise reports through
    testutil.fail() and returns False.
    """
    missing = []
    for pattern in expected:
        if not any(re.match(pattern, item) for item in actual):
            missing.append(pattern)
    if not missing:
        return True
    if not note:
        note = __caller_context()
    # str() keeps the message buildable when compiled patterns (or other
    # non-string values) are given.
    context = "Tested values not as expected: " + note + \
        "\n\tActual: " + ", ".join(map(str, actual)) + \
        "\n\tExpected: " + ", ".join(map(str, expected)) + \
        "\n\tMissing: " + ", ".join(map(str, missing))
    testutil.fail(context)
    return False
def EXPECT_NOT_CONTAINS(expected, actual, note=""):
    """Report a failure when any item of expected is found in actual.

    expected: iterable of items that must all be absent.
    actual: container tested with the `in` operator.
    note: text included in the failure message; when empty the caller
    context is used instead.

    The failure is reported through testutil.fail().
    """
    present = [e for e in expected if e in actual]
    if not present:
        return
    if not note:
        note = __caller_context()
    # Items are converted with str() so that non-string values do not make
    # the message construction itself raise a TypeError.
    context = "Tested values not as expected: " + note + \
        "\n\tActual: " + ", ".join(map(str, actual)) + \
        "\n\tNot expected: " + ", ".join(map(str, expected)) + \
        "\n\tPresent: " + ", ".join(map(str, present))
    testutil.fail(context)
def EXPECT_LE(expected, actual, note=""):
    """Report a failure through testutil.fail() when expected > actual.

    note: text included in the failure message; when empty the caller
    context is used instead.
    """
    if not expected > actual:
        return
    if not note:
        note = __caller_context()
    testutil.fail(
        f"Tested values not as expected: {note}\n\t"
        f"{expected!s} (expected) <= {actual!s} (actual)")
def EXPECT_LT(expected, actual, note=""):
    """Report a failure through testutil.fail() when expected >= actual.

    note: text included in the failure message; when empty the caller
    context is used instead.
    """
    if not expected >= actual:
        return
    if not note:
        note = __caller_context()
    testutil.fail(
        f"Tested values not as expected: {note}\n\t"
        f"{expected!s} (expected) < {actual!s} (actual)")
def EXPECT_GE(expected, actual, note=""):
    """Report a failure through testutil.fail() when expected < actual.

    note: text included in the failure message; when empty the caller
    context is used instead.
    """
    if not expected < actual:
        return
    if not note:
        note = __caller_context()
    testutil.fail(
        f"Tested values not as expected: {note}\n\t"
        f"{expected!s} (expected) >= {actual!s} (actual)")
def EXPECT_GT(expected, actual, note=""):
    """Report a failure through testutil.fail() when expected <= actual.

    note: text included in the failure message; when empty the caller
    context is used instead.
    """
    if not expected <= actual:
        return
    if not note:
        note = __caller_context()
    testutil.fail(
        f"Tested values not as expected: {note}\n\t"
        f"{expected!s} (expected) > {actual!s} (actual)")
def EXPECT_BETWEEN(expected_from, expected_to, actual, note=""):
    """Report a failure through testutil.fail() unless
    expected_from <= actual <= expected_to.

    note: text included in the failure message; when empty the caller
    context is used instead.
    """
    if expected_from <= actual <= expected_to:
        return
    if not note:
        note = __caller_context()
    details = f"\n\t{expected_from} <= {actual} <= {expected_to}"
    testutil.fail("Tested value not as expected: " + note + details)
def EXPECT_NOT_BETWEEN(expected_from, expected_to, actual, note=""):
    """Report a failure through testutil.fail() when
    expected_from <= actual <= expected_to holds.

    note: text included in the failure message; when empty the caller
    context is used instead.
    """
    if not expected_from <= actual <= expected_to:
        return
    if not note:
        note = __caller_context()
    details = f"\n\tNOT ({expected_from} <= {actual} <= {expected_to})"
    testutil.fail("Tested value not as expected: " + note + details)
def EXPECT_DELTA(expected, allowed_delta, actual, note=""):
    """Report a failure unless actual is within allowed_delta of expected.

    The accepted range is [expected - allowed_delta, expected + allowed_delta],
    both ends included.

    note: text included in the failure message; when empty the caller
    context is used instead.

    The failure is reported through testutil.fail() using the same message
    layout as EXPECT_BETWEEN.
    """
    lower = expected - allowed_delta
    upper = expected + allowed_delta
    if lower <= actual <= upper:
        return
    # The caller context is resolved here, and only on failure, so that it
    # points at the test code and no stack inspection happens for passing
    # checks.
    if not note:
        note = __caller_context()
    context = "Tested value not as expected: " + note + \
        f"\n\t{lower} <= {actual} <= {upper}"
    testutil.fail(context)
def EXPECT_TRUE(value, note=""):
    """Return True when value is truthy; otherwise report through
    testutil.fail() and return False.

    note: text appended to the failure message; when empty the caller
    context is used instead.
    """
    if value:
        return True
    if not note:
        note = __caller_context()
    message = f"Tested value '{value}' expected to be true but is false"
    if note:
        message = message + ": " + note
    testutil.fail(message)
    return False
def EXPECT_FALSE(value, note=""):
    """Report a failure through testutil.fail() when value is truthy.

    note: text appended to the failure message; when empty the caller
    context is used instead.
    """
    if not value:
        return
    if not note:
        note = __caller_context()
    message = f"Tested value '{value}' expected to be false but is true"
    if note:
        message = message + ": " + note
    testutil.fail(message)
def EXPECT_THROWS(func, etext, note=""):
    """Check that func() raises an exception whose description matches etext.

    func: callable invoked without arguments.
    etext: str (substring) or compiled pattern, matched against
    "<ExceptionTypeName>: <message>".
    note: optional prefix for the failure messages.

    Returns True on a matching exception; otherwise reports through
    testutil.fail() and returns False.
    """
    prefix = note + ": " if note else note
    assert callable(func)
    matcher = __TextMatcher(etext)
    try:
        func()
        testutil.fail(
            prefix + "<red>Missing expected exception throw like " + str(matcher) + "</red>")
        return False
    except Exception as e:
        description = f"{type(e).__name__}: {e!s}"
        if matcher.matches(description):
            return True
        testutil.fail(prefix + "<red>Exception expected:</red> " + str(matcher) +
                      "\n\t<yellow>Actual:</yellow> " + description)
        return False
def EXPECT_MAY_THROW(func, etext):
    """Call func(), tolerating only exceptions whose description matches etext.

    etext: str (substring) or compiled pattern, matched against
    "<ExceptionTypeName>: <message>". A non-matching exception is reported
    through testutil.fail().

    Returns the result of func(), or None when it raised.
    """
    assert callable(func)
    matcher = __TextMatcher(etext)
    try:
        return func()
    except Exception as e:
        description = f"{type(e).__name__}: {e!s}"
        if not matcher.matches(description):
            testutil.fail("<red>Exception expected:</red> " + str(matcher) +
                          "\n\t<yellow>Actual:</yellow> " + description)
    return None
def EXPECT_NO_THROWS(func, context=""):
    """Call func() and return its result, reporting any exception it raises.

    context: text placed in parentheses in the failure message.

    An exception is reported through testutil.fail() and None is returned.
    """
    assert callable(func)
    try:
        result = func()
    except Exception as e:
        testutil.fail(
            f"<b>Context:</b> {__test_context}"
            f"\n<red>Unexpected exception thrown ({context}): {e!s}</red>")
        return None
    return result
def EXPECT_STDOUT_CONTAINS(text, note=None):
    """Report a failure when the captured stdout does not contain text.

    text: substring searched in testutil.fetch_captured_stdout(False).
    note: text included in the failure message; when not given the caller
    context is used instead.

    The failure is reported through testutil.fail() and includes the test
    context, the note and the captured stdout and stderr.
    """
    out = testutil.fetch_captured_stdout(False)
    err = testutil.fetch_captured_stderr(False)
    if out.find(text) != -1:
        return
    if not note:
        note = __caller_context()
    context = "<b>Context:</b> " + __test_context + \
        "\n<b>Note:</b> " + note + \
        "\n<red>Missing output:</red> " + text + \
        "\n<yellow>Actual stdout:</yellow> " + out + \
        "\n<yellow>Actual stderr:</yellow> " + err
    testutil.fail(context)
def EXPECT_STDERR_CONTAINS(text, note=None):
    """Report a failure when the captured stderr does not contain text.

    text: substring searched in testutil.fetch_captured_stderr(False).
    note: text included in the failure message; when not given the caller
    context is used instead.

    The failure is reported through testutil.fail() and includes the test
    context, the note and the captured stdout and stderr.
    """
    out = testutil.fetch_captured_stdout(False)
    err = testutil.fetch_captured_stderr(False)
    if err.find(text) != -1:
        return
    if not note:
        note = __caller_context()
    context = "<b>Context:</b> " + __test_context + \
        "\n<b>Note:</b> " + note + \
        "\n<red>Missing output:</red> " + text + \
        "\n<yellow>Actual stdout:</yellow> " + out + \
        "\n<yellow>Actual stderr:</yellow> " + err
    testutil.fail(context)
def WIPE_STDOUT():
    """Call testutil.fetch_captured_stdout(True) repeatedly until it returns
    an empty string."""
    while testutil.fetch_captured_stdout(True) != "":
        pass
def WIPE_STDERR():
    """Call testutil.fetch_captured_stderr(True) repeatedly until it returns
    an empty string."""
    while testutil.fetch_captured_stderr(True) != "":
        pass
def WIPE_OUTPUT():
    """Run WIPE_STDOUT() and then WIPE_STDERR()."""
    for wipe in (WIPE_STDOUT, WIPE_STDERR):
        wipe()
def WIPE_SHELL_LOG():
    """Pass the path from testutil.get_shell_log_path() to
    testutil.wipe_file_contents()."""
    log_path = testutil.get_shell_log_path()
    testutil.wipe_file_contents(log_path)
def EXPECT_SHELL_LOG_CONTAINS(text, note=None):
    """Report a failure when the shell log has no match for text.

    text: value handed to testutil.grep_file() for the shell log file.
    note: text included in the failure message; when not given the caller
    context is used instead.

    The failure is reported through testutil.fail() and includes the test
    context, the note and the whole log contents.
    """
    log_file = testutil.get_shell_log_path()
    match_list = testutil.grep_file(log_file, text)
    if len(match_list) != 0:
        return
    if not note:
        note = __caller_context()
    log_out = testutil.cat_file(log_file)
    testutil.fail(
        f"<b>Context:</b> {__test_context}\n<b>Note:</b> {note}\n<red>Missing log output:</red> {text}\n<yellow>Actual log output:</yellow> {log_out}")
def EXPECT_SHELL_LOG_CONTAINS_COUNT(text, count):
    """Report a failure unless the shell log has exactly count matches of text.

    text: value handed to testutil.grep_file() for the shell log file.
    count: expected number of matches; must be an int.

    Raises TypeError when count is not an int. A mismatch (too few or too
    many matches) is reported through testutil.fail() together with the
    expected and actual counts and the whole log contents.
    """
    if not isinstance(count, int):
        raise TypeError('"count" argument must be a number.')
    log_file = testutil.get_shell_log_path()
    match_list = testutil.grep_file(log_file, text)
    found = len(match_list)
    if found == count:
        return
    log_out = testutil.cat_file(log_file)
    testutil.fail(
        f"<b>Context:</b> {__test_context}\n<red>Unexpected number of log matches for:</red> {text}\n<yellow>Expected count:</yellow> {count}\n<yellow>Actual count:</yellow> {found}\n<yellow>Actual log output:</yellow> {log_out}")
def EXPECT_SHELL_LOG_NOT_CONTAINS(text, note=None):
    """Report a failure when the shell log has any match for text.

    text: value handed to testutil.grep_file() for the shell log file.
    note: text included in the failure message; when not given the caller
    context is used instead.

    The failure is reported through testutil.fail() and includes the test
    context, the note and the whole log contents.
    """
    log_file = testutil.get_shell_log_path()
    match_list = testutil.grep_file(log_file, text)
    if len(match_list) == 0:
        return
    if not note:
        note = __caller_context()
    log_out = testutil.cat_file(log_file)
    testutil.fail(
        f"<b>Context:</b> {__test_context}\n<b>Note:</b> {note}\n<red>Unexpected log output:</red> {text}\n<yellow>Actual log output:</yellow> {log_out}")
def EXPECT_SHELL_LOG_MATCHES(re, note=None):
    """Report a failure when the pattern is not found in the shell log.

    re: compiled pattern (its search() and pattern members are used); note
    that the parameter name shadows the re module inside this function.
    note: text included in the failure message; when not given the caller
    context is used instead.

    The log file is read as UTF-8. The failure is reported through
    testutil.fail() and includes the test context, the note and the whole
    log contents.
    """
    log_file = testutil.get_shell_log_path()
    with open(log_file, "r", encoding="utf-8") as f:
        log_out = f.read()
    if re.search(log_out) is not None:
        return
    if not note:
        note = __caller_context()
    testutil.fail(
        f"<b>Context:</b> {__test_context}\n<b>Note:</b> {note}\n<red>Missing match for:</red> {re.pattern}\n<yellow>Actual log output:</yellow> {log_out}")
def EXPECT_STDOUT_MATCHES(re, note=None):
    """Report a failure when the pattern is not found in the captured stdout.

    re: compiled pattern (its search() and pattern members are used); note
    that the parameter name shadows the re module inside this function.
    note: text included in the failure message; when not given the caller
    context is used instead.

    The failure is reported through testutil.fail() and includes the test
    context, the note and the captured stdout and stderr.
    """
    out = testutil.fetch_captured_stdout(False)
    err = testutil.fetch_captured_stderr(False)
    if re.search(out) is not None:
        return
    if not note:
        note = __caller_context()
    context = "<b>Context:</b> " + __test_context + \
        "\n<b>Note:</b> " + note + \
        "\n<red>Missing match for:</red> " + re.pattern + \
        "\n<yellow>Actual stdout:</yellow> " + out + \
        "\n<yellow>Actual stderr:</yellow> " + err
    testutil.fail(context)
def EXPECT_STDOUT_NOT_CONTAINS(text, note=None):
    """Report a failure when the captured stdout contains text.

    text: substring searched in testutil.fetch_captured_stdout(False).
    note: text included in the failure message; when not given the caller
    context is used instead.

    The failure is reported through testutil.fail() and includes the test
    context, the note and the captured stdout and stderr.
    """
    out = testutil.fetch_captured_stdout(False)
    err = testutil.fetch_captured_stderr(False)
    if out.find(text) == -1:
        return
    if not note:
        note = __caller_context()
    context = "<b>Context:</b> " + __test_context + \
        "\n<b>Note:</b> " + note + \
        "\n<red>Unexpected output:</red> " + text + \
        "\n<yellow>Actual stdout:</yellow> " + out + \
        "\n<yellow>Actual stderr:</yellow> " + err
    testutil.fail(context)
def EXPECT_FILE_CONTAINS(expected, path, note=None):
    """Report a failure when the file at path does not contain expected.

    expected: substring searched in the file contents (read as UTF-8).
    path: file to read.
    note: text included in the failure message; when not given the caller
    context is used instead.

    The failure is reported through testutil.fail() and includes the test
    context, the note, the file contents and the path.
    """
    with open(path, encoding='utf-8') as f:
        contents = f.read()
    if contents.find(expected) != -1:
        return
    if not note:
        note = __caller_context()
    context = "<b>Context:</b> " + __test_context + \
        "\n<b>Note:</b> " + note + \
        "\n<red>Missing contents:</red> " + expected + \
        "\n<yellow>Actual contents:</yellow> " + contents + \
        "\n<yellow>File:</yellow> " + path
    testutil.fail(context)
def EXPECT_FILE_MATCHES(re, path, note=None):
    """Report a failure when the pattern is not found in the file at path.

    re: compiled pattern (its search() and pattern members are used); note
    that the parameter name shadows the re module inside this function.
    path: file to read (as UTF-8).
    note: text included in the failure message; when not given the caller
    context is used instead.

    The failure is reported through testutil.fail() and includes the test
    context, the note, the file contents and the path.
    """
    with open(path, encoding='utf-8') as f:
        contents = f.read()
    if re.search(contents) is not None:
        return
    if not note:
        note = __caller_context()
    testutil.fail(
        f"<b>Context:</b> {__test_context}\n<b>Note:</b> {note}\n<red>Missing match for:</red> {re.pattern}\n<yellow>Actual contents:</yellow> {contents}\n<yellow>File:</yellow> {path}")
def EXPECT_FILE_NOT_CONTAINS(expected, path, note=None):
    """Report a failure when the file at path contains expected.

    expected: substring searched in the file contents (read as UTF-8).
    path: file to read.
    note: text included in the failure message; when not given the caller
    context is used instead.

    The failure is reported through testutil.fail() and includes the test
    context, the note, the file contents and the path.
    """
    with open(path, encoding='utf-8') as f:
        contents = f.read()
    if contents.find(expected) == -1:
        return
    if not note:
        note = __caller_context()
    context = "<b>Context:</b> " + __test_context + \
        "\n<b>Note:</b> " + note + \
        "\n<red>Unexpected contents:</red> " + expected + \
        "\n<yellow>Actual contents:</yellow> " + contents + \
        "\n<yellow>File:</yellow> " + path
    testutil.fail(context)
def EXPECT_FILE_NOT_MATCHES(re, path, note=None):
    """Report a failure when the pattern is found in the file at path.

    re: compiled pattern (its search() and pattern members are used); note
    that the parameter name shadows the re module inside this function.
    path: file to read (as UTF-8).
    note: text included in the failure message; when not given the caller
    context is used instead.

    The failure is reported through testutil.fail() and includes the test
    context, the note, the file contents and the path.
    """
    with open(path, encoding='utf-8') as f:
        contents = f.read()
    if re.search(contents) is None:
        return
    if not note:
        note = __caller_context()
    testutil.fail(
        f"<b>Context:</b> {__test_context}\n<b>Note:</b> {note}\n<red>Unexpected match for:</red> {re.pattern}\n<yellow>Actual contents:</yellow> {contents}\n<yellow>File:</yellow> {path}")
def validate_crud_functions(crud, expected):
    """Print how the names in dir(crud) compare with the expected names.

    Prints either "All expected functions are available" or the list of
    missing names, then either "No additional functions are available" or the
    remaining names. A single leftover 'help' entry is not reported as extra.
    """
    remaining = dir(crud)

    # Expected names are consumed from the remaining list; whatever cannot
    # be found is recorded as missing.
    missing = []
    for name in expected:
        if name in remaining:
            remaining.remove(name)
        else:
            missing.append(name)

    if missing:
        print("Missing Functions:", missing)
    else:
        print("All expected functions are available\n")

    only_help = len(remaining) == 1 and remaining[0] == 'help'
    if not remaining or only_help:
        print("No additional functions are available\n")
    else:
        print("Extra Functions:", remaining)
def validate_members(object, expected_members):
    """Report a failure when the non-dunder members of object differ from
    expected_members.

    Names from dir(object) starting with '__' are ignored. Unexpected and
    missing names are joined into one message passed to testutil.fail().
    """
    # Python built-in (dunder) members are not part of the comparison.
    remaining = [name for name in dir(object) if not name.startswith('__')]

    missing = []
    for expected in expected_members:
        if expected in remaining:
            remaining.remove(expected)
        else:
            missing.append(expected)

    errors = []
    if remaining:
        errors.append("Unexpected Members: %s" % ', '.join(remaining))
    if missing:
        errors.append("Missing Members: %s" % ', '.join(missing))

    if errors:
        testutil.fail(', '.join(errors))
def print_differences(source, target):
    """Send a context diff of two files, line by line, to testutil.dprint().

    source, target: paths of the files to compare; they are also used as the
    from/to labels of the diff.
    """
    with open(source) as src_file:
        src_lines = src_file.readlines()
    with open(target) as tgt_file:
        tgt_lines = tgt_file.readlines()

    diff = difflib.context_diff(
        src_lines, tgt_lines, fromfile=source, tofile=target)
    for line in diff:
        testutil.dprint(line)
def ensure_plugin_enabled(plugin_name, session, plugin_soname=None):
    """Run INSTALL PLUGIN for plugin_name on the server behind session.

    plugin_soname: library name without extension; defaults to plugin_name.
    The extension is "dll" when @@version_compile_os is Win32 or Win64 and
    "so" otherwise.

    A mysqlsh.DBError with code 1125 is ignored; any other one is re-raised.
    """
    if plugin_soname is None:
        plugin_soname = plugin_name

    compile_os = session.run_sql(
        'select @@version_compile_os').fetch_one()[0]
    ext = "dll" if compile_os in ("Win32", "Win64") else "so"

    install_sql = "INSTALL PLUGIN {0} SONAME '{1}.{2}';".format(
        plugin_name, plugin_soname, ext)
    try:
        session.run_sql(install_sql)
    except mysqlsh.DBError as e:
        if 1125 != e.code:
            raise
def ensure_plugin_disabled(plugin_name, session):
    """Run UNINSTALL PLUGIN for plugin_name when INFORMATION_SCHEMA.PLUGINS
    reports a non-zero count of plugins with a name LIKE plugin_name."""
    count_sql = "SELECT COUNT(1) FROM INFORMATION_SCHEMA.PLUGINS WHERE PLUGIN_NAME LIKE '" + \
        plugin_name + "';"
    installed_count = session.run_sql(count_sql).fetch_one()[0]
    if not installed_count:
        return
    session.run_sql("UNINSTALL PLUGIN " + plugin_name + ";")
# Starting with 8.0.24 the client library reports connection errors using the
# host:port format, while previous versions used just the host.
#
# This function formats the host description accordingly.
def libmysql_host_description(hostname, port):
    """Return "hostname:port" when __libmysql_version_id is newer than 8.0.23
    (as decided by testutil.version_check), otherwise just hostname."""
    reports_port = testutil.version_check(
        __libmysql_version_id, ">", "8.0.23")
    return hostname + ":" + str(port) if reports_port else hostname
def get_socket_path(session, uri=None):
    """Return the socket path of the server behind session.

    uri: connection URI used to pick the variable to query; defaults to
    session.uri. @@socket is read when the URI scheme is 'mysql' and
    @@mysqlx_socket otherwise.

    A value that starts with '/' (or any value when __os_type is "windows")
    is returned as is; otherwise it is joined onto @@datadir. A path longer
    than 100 characters is reported through testutil.fail().
    """
    if uri is None:
        uri = session.uri

    variable = 'socket' if 'mysql' == shell.parse_uri(uri).scheme else 'mysqlx_socket'
    row = session.run_sql(f"SELECT @@{variable}, @@datadir").fetch_one()

    use_as_is = row[0][0] == '/' or __os_type == "windows"
    p = row[0] if use_as_is else os.path.join(row[1], row[0])

    if len(p) > 100:
        testutil.fail("socket path is too long (>100): " + p)
    return p
def get_socket_uri(session, uri=None):
    """Build a URI that connects through the socket of the server behind
    session.

    uri: connection URI to take scheme, user and password from; defaults to
    session.uri. A missing password is replaced by an empty string.

    Returns the result of shell.unparse_uri() for those fields plus the
    socket path obtained from get_socket_path().
    """
    if uri is None:
        uri = session.uri

    parsed = shell.parse_uri(uri)
    if "password" not in parsed:
        parsed["password"] = ""

    new_uri = {key: parsed[key] for key in ("scheme", "user", "password")}
    new_uri["socket"] = get_socket_path(session, uri)
    return shell.unparse_uri(new_uri)
def random_string(lower, upper=None):
    """Return a random string of ASCII letters and digits.

    The length is picked with random.randint(lower, upper); when upper is
    omitted the length is exactly lower.
    """
    if upper is None:
        upper = lower
    alphabet = string.ascii_letters + string.digits
    length = random.randint(lower, upper)
    return ''.join(random.choices(alphabet, k=length))
def random_email():
    """Return "<local>@<domain>.<tld>" built from random_string(): local and
    domain are 10 to 40 characters long and the last part has 3 characters."""
    local_part = random_string(10, 40)
    domain = random_string(10, 40)
    suffix = random_string(3)
    return local_part + "@" + domain + "." + suffix
def md5sum(s):
    """Return the hexadecimal MD5 digest of the UTF-8 encoding of s."""
    digest = hashlib.md5(s.encode("utf-8"))
    return digest.hexdigest()
def reset_instance(session):
    """Bring the server behind session back to a clean state.

    Steps, in order:
      - best-effort reset of group replication member actions and of the
        asynchronous connection failover configuration;
      - STOP SLAVE, then best-effort stop of group replication with
        group_replication_start_on_boot persisted to 0;
      - super_read_only and read_only are cleared;
      - the InnoDB Cluster metadata schema and every schema other than
        mysql, performance_schema, sys and information_schema are dropped;
      - every account is dropped except mysql.sys, mysql.session,
        mysql.infoschema and root at localhost or '%';
      - RESET MASTER and RESET SLAVE ALL.

    session: object exposing run_sql(sql[, args]).
    """
    # These statements are attempted but their failure is ignored
    # (presumably because the functions are not available on every server).
    for optional_sql in ("SELECT group_replication_reset_member_actions()",
                         "SELECT asynchronous_connection_failover_reset()"):
        try:
            session.run_sql(optional_sql)
        except Exception:
            pass

    session.run_sql("STOP SLAVE")

    # Failure to stop group replication is ignored as well; note that the
    # SET PERSIST is skipped when STOP group_replication fails.
    try:
        session.run_sql("STOP group_replication")
        session.run_sql("SET PERSIST group_replication_start_on_boot=0")
    except Exception:
        pass

    session.run_sql("SET GLOBAL super_read_only=0")
    session.run_sql("SET GLOBAL read_only=0")
    session.run_sql("DROP SCHEMA IF EXISTS mysql_innodb_cluster_metadata")

    kept_schemas = ("mysql", "performance_schema", "sys", "information_schema")
    for row in session.run_sql("SHOW SCHEMAS").fetch_all():
        schema = row[0]
        if schema in kept_schemas:
            continue
        # Quote the identifier (doubling embedded backticks) so that names
        # with special characters or reserved words can be dropped too.
        session.run_sql("DROP SCHEMA `" + schema.replace("`", "``") + "`")

    kept_users = ("mysql.sys", "mysql.session", "mysql.infoschema")
    for row in session.run_sql("SELECT user,host FROM mysql.user").fetch_all():
        user, host = row[0], row[1]
        if user in kept_users:
            continue
        if user == "root" and host in ("localhost", "%"):
            continue
        session.run_sql("DROP USER ?@?", [user, host])

    session.run_sql("RESET MASTER")
    session.run_sql("RESET SLAVE ALL")
def reset_multi(ports):
    """Stop the group on the given ports and reset each instance.

    testutil.stop_group(ports) is called first; then, for every port, a
    session is opened to root:root@localhost:<port>, passed to
    reset_instance() and closed. The session is closed even when the reset
    raises, and the error is then propagated.
    """
    testutil.stop_group(ports)
    for p in ports:
        s = mysql.get_session(f"mysql://root:root@localhost:{p}")
        try:
            reset_instance(s)
        finally:
            s.close()
class Docker_manipulator:
    """Builds the docker image described by <data_path>/<docker_name>/Dockerfile
    and runs a container from it.

    The image is tagged "mysql-sshd-shell-<docker_name>" and labelled with the
    SHA-256 of the Dockerfile, so that an image built from a different
    Dockerfile is detected and rebuilt.
    """

    def __init__(self, docker_name, data_path):
        """Locate the Dockerfile, hash it and connect to docker.

        A missing Dockerfile is reported through testutil.fail().
        """
        self.data_path = data_path
        self.dockerfile = os.path.join(
            self.data_path, docker_name, "Dockerfile")
        self.docker_name = docker_name
        self.docker_image_name = "mysql-sshd-shell-{}".format(docker_name)
        if not os.path.exists(self.dockerfile):
            testutil.fail(
                "The Dockerfile is missing for {}".format(docker_name))
        self.hash = self.__get_hash()
        # Imported here rather than at module level (presumably so that the
        # docker package is only needed when this class is used).
        import docker
        self.client = docker.from_env()
        self.container = None
        # Image to run; None until an up-to-date one is found or built.
        self.img = None

    def __get_hash(self):
        """Return the SHA-256 hex digest of the Dockerfile contents."""
        import hashlib
        with open(self.dockerfile, "rb") as f:
            return hashlib.sha256(f.read()).hexdigest()

    @staticmethod
    def __image_hash(image):
        """Return the "hash" label of image, or None when it has none.

        Both the "Config" and the "ContainerConfig" sections of the image
        attributes are checked, in that order.
        """
        attrs = image.attrs or {}
        for section in ("Config", "ContainerConfig"):
            labels = (attrs.get(section) or {}).get("Labels") or {}
            if "hash" in labels:
                return labels["hash"]
        return None

    def __pre_run_check(self):
        """Remove a left-over container and decide whether the image is reusable.

        Afterwards self.img is the existing image when its "hash" label
        equals self.hash, and None otherwise (no image, or an outdated one
        that was removed).
        """
        # Before running, check whether such a container already exists;
        # if so, kill it so that a fresh one can be started.
        try:
            print("About to find old container")
            c = self.client.containers.get(self.docker_name)
            if c.status == "running":
                c.kill()
            self.client.containers.prune()
            print("Old container found and removed")
        except Exception:
            # This is fine: containers.get throws when the container is
            # not found.
            pass

        self.img = None
        try:
            print("Checking image tag")
            image = self.client.images.get(self.docker_image_name)
        except Exception:
            # No usable image: a new one is going to be built.
            return

        if self.__image_hash(image) == self.hash:
            self.img = image
            return

        # The image was built from a different Dockerfile. self.img stays
        # None so that the removed image is never used to start a container.
        try:
            self.client.images.remove(image=image.id, force=True)
            print("Old image found which is outdated, removed")
        except Exception:
            pass

    def __build_and_run(self):
        """Build the image when none is available and start the container.

        The container is started detached, named self.docker_name, with
        MYSQL_ROOT_PASSWORD=sandbox and port 2222/tcp published on 2222.
        """
        if getattr(self, "img", None) is None:
            self.img = self.client.images.build(
                path=os.path.join(self.data_path, self.docker_name),
                tag=self.docker_image_name,
                labels={"hash": self.hash})[0]
            print("Build new docker image")
        self.container = self.client.containers.run(
            image=self.img.id, detach=True, name=self.docker_name,
            environment=["MYSQL_ROOT_PASSWORD=sandbox"],
            ports={"2222/tcp": 2222})
        print("Started new container")

    def log_watcher(self, log_fetcher):
        """Move log lines from log_fetcher into self._queue.

        Stops on the first error: the end of the stream, a closed stream or
        a full queue (the queue is bounded and put_nowait() does not wait).
        """
        while True:
            try:
                self._queue.put_nowait(next(log_fetcher))
            except Exception:
                break

    def run(self):
        """Start the container and wait until its log announces the server.

        The container log is streamed by a helper thread; this method waits
        for the container to be "running" and then reads log lines until one
        reports the server port, or until no line arrives for 10 seconds.
        """
        self._queue = queue.Queue(10)
        self.__pre_run_check()
        self.__build_and_run()
        log_iter = self.container.logs(stream=True)
        watcher = threading.Thread(target=self.log_watcher, args=(log_iter,))
        watcher.start()

        while self.client.containers.get(self.container.id).status != "running":
            time.sleep(1)

        while True:
            try:
                line = self._queue.get(block=True, timeout=10)
                match = re.search(
                    br"port: ([0-9]{4})\s MySQL Community Server - GPL", line)
            except Exception:
                # No log line within the timeout (or an unusable one):
                # stop waiting.
                break
            if match:
                # The log lines are bytes, so the captured group has to be
                # converted before it can be formatted.
                print("Found port: {}".format(int(match.group(1))))
                break

        print("Waiting for container to fully start")
        # Closing the stream makes the watcher thread finish.
        log_iter.close()
        watcher.join()
        print("Container is ready")

    def cleanup(self, remove_image=False):
        """Kill the container (when one was started), prune stopped containers
        and, when remove_image is set and an image is known, remove it."""
        if self.container is not None:
            self.container.kill()  # kill the container
        # the python client.images.build leaves a mess, clean it
        self.client.containers.prune()
        if remove_image and getattr(self, "img", None) is not None:
            self.client.images.remove(image=self.img.id, force=True)