in ossdbtoolsservice/driver/types/pymysql_driver.py [0:0]
def __init__(self, conn_params: dict, config: Optional[Configuration] = None):
    """
    Creates a new connection wrapper: opens the PyMySQL connection and detects the server version and type.
    :param conn_params: connection parameters dict (tools-service option names). The dict is not modified.
    :param config: optional Configuration object with mysql connection config; when given, its
        my_sql.default_database is used if no database was provided
    :raises ValueError: if a numeric parameter (port/timeouts) or the reported server version
        cannot be parsed as an integer
    Errors raised by pymysql.connect or by the version query propagate to the caller; in the latter
    case the connection opened here is closed before re-raising.
    """
    # Work on a copy so the token-to-password substitution below does not mutate the caller's dict
    params = dict(conn_params)
    if 'azureAccountToken' in params:
        # Use the Azure account token as the password
        params['password'] = params['azureAccountToken']

    # Map the provided connection parameter names to pymysql param names
    mapped_params = {MYSQL_CONNECTION_OPTION_KEY_MAP.get(name, name): value for name, value in params.items()}
    # Filter the parameters to only those accepted by PyMySQL
    self._connection_options = {
        name: value for name, value in mapped_params.items() if name in MYSQL_CONNECTION_PARAM_KEYWORDS
    }

    # Convert the numeric params from strings to integers. An explicit 0 becomes None (unset),
    # so a zero port falls through to the default port below.
    for name in ("port", "connect_timeout", "read_timeout", "write_timeout"):
        value = self._connection_options.get(name)
        if value:
            self._connection_options[name] = int(value) or None

    # Use the default database from config if one was not provided
    if config and not self._connection_options.get('database'):
        self._connection_options['database'] = config.my_sql.default_database

    # Use the default port number if one was not provided
    if not self._connection_options.get('port'):
        self._connection_options['port'] = constants.DEFAULT_PORT[constants.MYSQL_PROVIDER_NAME]

    # The "ssl" option is compared against "disable" (presumably an SSL mode string); the individual
    # SSL settings arrive as separate "ssl.<name>" keys. Read the mode from the same dict whose key we checked.
    if "ssl" in params:
        if params["ssl"] != "disable":
            # Collect the "ssl.<name>" options (key, ca, cipher, ...) into the dict PyMySQL expects.
            # Slice the prefix off: str.strip("ssl.") would remove any of the characters 's', 'l', '.'
            # from both ends of the name rather than the "ssl." prefix.
            prefix = "ssl."
            self._connection_options["ssl"] = {
                name[len(prefix):]: value for name, value in params.items() if name.startswith(prefix)
            }
        else:
            # PyMySQL's `ssl` argument must be a dict or an SSLContext; forwarding the truthy string
            # "disable" would make it try to build an SSL context from a str, so drop the option instead.
            self._connection_options.pop("ssl", None)

    # Setting autocommit to True initially
    self._autocommit_status = True
    # Pass connection parameters as keyword arguments to the connection by unpacking the connection_options dict
    self._conn = pymysql.connect(**self._connection_options)
    self._connection_closed = False
    # Find the class of the database error this driver throws
    self._database_error = pymysql.err.DatabaseError
    self._provider_name = constants.MYSQL_PROVIDER_NAME

    try:
        # Calculate the server version
        # Source: https://stackoverflow.com/questions/8987679/how-to-retrieve-the-current-version-of-a-mysql-database
        version_string = self.execute_query("SELECT VERSION();")[0][0]
        # Split the different components of the version string,
        # e.g. "10.5.8-MariaDB-log" -> ["10", "5", "8", "MariaDB", "log"]
        version_components: List[str] = re.split(r"[.-]", version_string)
        self._version: Tuple[int, int, int] = (
            int(version_components[0]),
            int(version_components[1]),
            int(version_components[2])
        )
    except Exception:
        # The caller never receives this half-built object, so close the connection here instead of leaking it
        if self._conn.open:
            self._conn.close()
        self._connection_closed = True
        raise

    # Find what type of server we have connected to. The "MariaDB" tag may be followed by further
    # suffixes (such as "-log"), so search every component after the numeric ones, not only a 4th one.
    self._server_type = "MariaDB" if "MariaDB" in version_components[3:] else "MySQL"