mysqlx-connector-python/lib/mysqlx/helpers.py (85 lines of code) (raw):
# Copyright (c) 2017, 2024, Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is designed to work with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have either included with
# the program or referenced in the documentation.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""This module contains helper functions."""
import binascii
import decimal
import functools
import inspect
import warnings
from typing import Any, Callable, List, Optional, Union
from .constants import SUPPORTED_TLS_VERSIONS, TLS_CIPHER_SUITES
from .errors import InterfaceError
from .types import EscapeTypes, StrOrBytes
# Tuple of the built-in byte-string types. It is not referenced by the code
# in this module; presumably other modules import it for isinstance() checks.
BYTE_TYPES = (bytearray, bytes)
# Tuple of numeric types; escape() returns values of these types unchanged.
NUMERIC_TYPES = (int, float, decimal.Decimal)
def encode_to_bytes(
    value: Union[str, bytes, bytearray], encoding: str = "utf-8"
) -> bytes:
    """Return ``value`` as a bytes object, encoding it when it is a string.

    Args:
        value (str, bytes or bytearray): The value to convert. A ``bytes``
            value is returned unchanged, a ``bytearray`` is copied into a
            new ``bytes`` object and a ``str`` is encoded.
        encoding (str): The encoding used when ``value`` is a string.

    Returns:
        bytes: The value as a bytes object.
    """
    if isinstance(value, bytes):
        return value
    if isinstance(value, bytearray):
        # bytearray has no encode() method, so it must not reach the str
        # branch below; convert it to honour the ``bytes`` return type.
        return bytes(value)
    return value.encode(encoding)
def decode_from_bytes(
    value: Union[str, bytes, bytearray], encoding: str = "utf-8"
) -> str:
    """Return ``value`` as a string, decoding it when it is a byte string.

    Args:
        value (str, bytes or bytearray): The value to convert. A ``bytes``
            or ``bytearray`` value is decoded; a ``str`` is returned
            unchanged.
        encoding (str): The encoding used when ``value`` is a byte string.

    Returns:
        str: The value decoded from bytes.
    """
    # bytearray is decoded as well; otherwise it would be handed back
    # undecoded, contradicting the ``str`` return type.
    if isinstance(value, (bytes, bytearray)):
        return value.decode(encoding)
    return value
def get_item_or_attr(obj: object, key: str) -> Any:
    """Look up ``key`` on ``obj`` as a dictionary item or as an attribute.

    Args:
        obj (object): A dictionary (or dict subclass) or any other object.
        key (str): The dictionary key or the attribute name.

    Returns:
        object: ``obj[key]`` when ``obj`` is a dictionary, otherwise the
        attribute of ``obj`` named ``key``.
    """
    if isinstance(obj, dict):
        return obj[key]
    return getattr(obj, key)
def escape(*args: EscapeTypes) -> Union[EscapeTypes, List[EscapeTypes]]:
    """Escape special characters the way MySQL expects to receive them.

    As found in MySQL source mysys/charset.c

    A backslash is placed in front of every backslash, single quote, double
    quote and Ctrl-Z (0x1A) character; newline and carriage return are
    replaced by their two-character backslash sequences.

    Args:
        *args (object): One or more values to be escaped.

    Returns:
        object: For a single argument, the escaped value; for several
        arguments, a list with one escaped value per argument. ``None``
        and numeric values are returned unchanged.
    """
    # The backslash rule must stay first so that the backslashes added by
    # the following rules are not doubled again.
    text_rules = (
        ("\\", "\\\\"),
        ("\n", "\\n"),
        ("\r", "\\r"),
        ("\047", "\134\047"),  # single quotes
        ("\042", "\134\042"),  # double quotes
        ("\032", "\134\032"),  # for Win32
    )
    binary_rules = (
        (b"\\", b"\\\\"),
        (b"\n", b"\\n"),
        (b"\r", b"\\r"),
        (b"\047", b"\134\047"),  # single quotes
        (b"\042", b"\134\042"),  # double quotes
        (b"\032", b"\134\032"),  # for Win32
    )

    def _escape(value: EscapeTypes) -> EscapeTypes:
        """Escape the special characters of a single value."""
        if value is None or isinstance(value, NUMERIC_TYPES):
            return value
        if isinstance(value, (bytes, bytearray)):
            rules = binary_rules
        else:
            rules = text_rules
        for target, substitute in rules:
            value = value.replace(target, substitute)
        return value

    if len(args) <= 1:
        return _escape(args[0])
    return [_escape(item) for item in args]
def quote_identifier(identifier: str, sql_mode: str = "") -> str:
    """Quote the given identifier for use in an SQL statement.

    By default the identifier is wrapped in backticks and every backtick (`)
    inside it is doubled (``). When ANSI_QUOTES is one of the modes listed
    in ``sql_mode``, the identifier is wrapped in double quotes instead and
    every double quote (") inside it is doubled.

    Args:
        identifier (str): Identifier to quote.
        sql_mode (str): The SQL mode, either a single mode name or a
            comma-separated list of mode names.

    Returns:
        str: The quoted identifier.
    """
    # sql_mode may list several modes separated by commas, so look for
    # ANSI_QUOTES among them instead of comparing the whole string. A falsy
    # value (empty string or None) means no mode is set.
    active_modes = {mode.strip() for mode in (sql_mode or "").split(",")}
    if "ANSI_QUOTES" in active_modes:
        quoted = identifier.replace('"', '""')
        return f'"{quoted}"'
    quoted = identifier.replace("`", "``")
    return f"`{quoted}`"
def deprecated(
    version: Optional[str] = None, reason: Optional[str] = None
) -> Callable:
    """Build a decorator that marks a function as deprecated.

    Every call to the decorated function issues a ``DeprecationWarning``
    reported at the caller's file and line, and then runs the function.

    Args:
        version (Optional[string]): Version in which it was deprecated.
        reason (Optional[string]): Reason or extra information to be shown.

    Returns:
        Callable: A decorator used to mark functions as deprecated.

    Usage:

    .. code-block:: python

       from mysqlx.helpers import deprecated

       @deprecated('8.0.12', 'Please use other_function() instead')
       def deprecated_function(x, y):
           return x + y
    """

    def decorate(func: Callable) -> Callable:
        """Wrap ``func`` so that calling it emits the deprecation warning."""

        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            """Warn about the deprecation, then call the wrapped function.

            Args:
                *args: Variable length argument list.
                **kwargs: Arbitrary keyword arguments.
            """
            text = f"'{func.__name__}' is deprecated"
            if version:
                text += f" since version {version}"
            if reason:
                text += f". {reason}"
            # Attribute the warning to the code that called the deprecated
            # function rather than to this wrapper.
            caller = inspect.currentframe().f_back
            warnings.warn_explicit(
                text,
                category=DeprecationWarning,
                filename=inspect.getfile(caller.f_code),
                lineno=caller.f_lineno,
            )
            return func(*args, **kwargs)

        return wrapper

    return decorate
def iani_to_openssl_cs_name(
    tls_version: str, cipher_suites_names: List[str]
) -> List[str]:
    """Translate a list of cipher suite names from IANA to OpenSSL names.

    Args:
        tls_version (str): The TLS version to look at for a translation.
        cipher_suites_names (list): A list of cipher suites names.

    Returns:
        List[str]: List of translated names, in the order they were given.

    Raises:
        InterfaceError: If a name contains no "-" and has no translation
            for the given TLS version or the versions listed before it.
    """
    # Merge the translation tables of every supported TLS version up to and
    # including the requested one; tables of later versions take precedence.
    last_position = SUPPORTED_TLS_VERSIONS.index(tls_version)
    known_suites = {}
    for version in SUPPORTED_TLS_VERSIONS[: last_position + 1]:
        known_suites.update(TLS_CIPHER_SUITES[version])

    openssl_names = []
    for name in cipher_suites_names:
        if "-" in name:
            # Names containing a dash are passed through untranslated
            # (presumably they are already OpenSSL names).
            openssl_names.append(name)
            continue
        if name not in known_suites:
            raise InterfaceError(
                f"The '{name}' in cipher suites is not a valid cipher suite"
            )
        openssl_names.append(known_suites[name])
    return openssl_names
def hexlify(data: bytes) -> str:
    """Return the hexadecimal representation of the binary data as text.

    Args:
        data (bytes): The binary data.

    Returns:
        str: The hexadecimal representation of data, two hexadecimal digits
        per input byte.
    """
    hex_digits = binascii.hexlify(data)
    return str(hex_digits, "utf-8")