aws_advanced_python_wrapper/driver_dialect_manager.py (88 lines of code) (raw):
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import TYPE_CHECKING, Callable, Dict, Optional, Protocol, Union
if TYPE_CHECKING:
from aws_advanced_python_wrapper.connection_provider import ConnectionProvider
from aws_advanced_python_wrapper.driver_dialect import DriverDialect
from aws_advanced_python_wrapper.driver_dialect_codes import DriverDialectCodes
from aws_advanced_python_wrapper.errors import AwsWrapperError
from aws_advanced_python_wrapper.utils.log import Logger
from aws_advanced_python_wrapper.utils.messages import Messages
from aws_advanced_python_wrapper.utils.properties import (Properties,
WrapperProperties)
from aws_advanced_python_wrapper.utils.utils import Utils
# Module logger; DriverDialectManager uses it to report which driver dialect
# was selected (debug) and to warn when a registered custom dialect does not
# apply to the connect function.
logger = Logger(__name__)
class DriverDialectProvider(Protocol):
    """Structural interface for objects that pick the :class:`DriverDialect` to use."""

    def get_dialect(self, conn_func: Callable, props: Properties) -> DriverDialect:
        """Return the driver dialect for the given connect function and properties."""

    def get_pool_connection_driver_dialect(
            self,
            connection_provider: ConnectionProvider,
            underlying_driver_dialect: DriverDialect,
            props: Properties) -> DriverDialect:
        """Return the dialect for connections obtained from ``connection_provider``,
        given the dialect of the underlying driver."""
class DriverDialectManager(DriverDialectProvider):
    """Choose the :class:`DriverDialect` used to talk to the target driver.

    :meth:`get_dialect` resolves a dialect in this order:

    1. the custom dialect registered with :meth:`set_custom_dialect`, if its
       ``is_dialect`` accepts the connect function;
    2. the dialect whose code is given by the ``DRIVER_DIALECT`` wrapper property;
    3. the first entry of :attr:`known_dialects_by_code` for which
       ``Utils.initialize_class`` yields an instance whose ``is_dialect``
       accepts the connect function;
    4. a plain :class:`DriverDialect` built from ``props``.

    :meth:`get_pool_connection_driver_dialect` optionally replaces a dialect for
    connections obtained through a registered pooled connection provider.
    """

    # User-registered dialect. The static setters below assign it on the class,
    # so it is shared by every DriverDialectManager instance.
    _custom_dialect: Optional[DriverDialect] = None

    # Dialect code -> fully qualified class name. Dict order is also the probe
    # order used for auto-detection in get_dialect.
    known_dialects_by_code: Dict[str, str] = {
        DriverDialectCodes.PSYCOPG: "aws_advanced_python_wrapper.pg_driver_dialect.PgDriverDialect",
        DriverDialectCodes.MYSQL_CONNECTOR_PYTHON: "aws_advanced_python_wrapper.mysql_driver_dialect.MySQLDriverDialect",
        DriverDialectCodes.GENERIC: "aws_advanced_python_wrapper.generic_driver_dialect.GenericDriverDialect",
    }

    # Connection provider class name -> fully qualified class name of the dialect
    # used for connections handed out by that provider (or by its subclasses).
    pool_connection_driver_dialect: Dict[str, str] = {
        "SqlAlchemyPooledConnectionProvider": "aws_advanced_python_wrapper.sqlalchemy_driver_dialect.SqlAlchemyDriverDialect"
    }

    @staticmethod
    def get_custom_dialect() -> Optional[DriverDialect]:
        """Return the registered custom dialect, or ``None`` if none is set."""
        return DriverDialectManager._custom_dialect

    @staticmethod
    def set_custom_dialect(dialect: DriverDialect) -> None:
        """Register ``dialect`` so it is tried before any built-in dialect."""
        DriverDialectManager._custom_dialect = dialect

    @staticmethod
    def reset_custom_dialect() -> None:
        """Clear any previously registered custom dialect."""
        DriverDialectManager._custom_dialect = None

    def get_dialect(self, conn_func: Callable, props: Properties) -> DriverDialect:
        """Return the driver dialect to use for ``conn_func``.

        :param conn_func: the target driver's connect function.
        :param props: connection properties; ``WrapperProperties.DRIVER_DIALECT``
            is read from them, and they are passed on when a dialect is created.
        :return: the selected driver dialect.
        :raises AwsWrapperError: if ``DRIVER_DIALECT`` holds a code that is not in
            :attr:`known_dialects_by_code`, or if the class it maps to could not
            be initialized.
        """
        custom_dialect = self._custom_dialect
        if custom_dialect is not None:
            if custom_dialect.is_dialect(conn_func):
                self._log_dialect("custom", custom_dialect)
                return custom_dialect
            # The registered dialect does not apply to this driver; fall back to
            # the built-in resolution below.
            logger.warning("DriverDialectManager.CustomDialectNotSupported")

        dialect_code: Optional[str] = WrapperProperties.DRIVER_DIALECT.get(props)
        if dialect_code:
            # An explicitly requested dialect must resolve; no auto-detection.
            class_name = DriverDialectManager.known_dialects_by_code.get(dialect_code)
            if class_name is None:
                raise AwsWrapperError(Messages.get_formatted(
                    "DriverDialectManager.UnknownDialectCode",
                    dialect_code))
            self._log_dialect(dialect_code, class_name)
            dialect = Utils.initialize_class(class_name, props)
            if dialect is None:
                raise AwsWrapperError(Messages.get_formatted(
                    "DriverDialectManager.InitializationError",
                    dialect_code))
            return dialect

        # Auto-detect: probe each known dialect in registration order.
        for code, class_name in DriverDialectManager.known_dialects_by_code.items():
            candidate = Utils.initialize_class(class_name, props)
            if candidate is not None and candidate.is_dialect(conn_func):
                self._log_dialect(code, class_name)
                return candidate

        # NOTE(review): this logs the GENERIC code but returns the base
        # DriverDialect. An explicit DRIVER_DIALECT of GENERIC resolves to
        # GenericDriverDialect via known_dialects_by_code instead — confirm this
        # difference is intended.
        self._log_dialect(DriverDialectCodes.GENERIC, "generic")
        return DriverDialect(props)

    @staticmethod
    def _log_dialect(dialect_code: str, driver_dialect: Union[DriverDialect, str]) -> None:
        """Log at debug level the selected dialect code and the dialect (instance or class name)."""
        logger.debug(
            "DriverDialectManager.UseDialect",
            dialect_code,
            driver_dialect)

    def get_pool_connection_driver_dialect(
            self,
            connection_provider: ConnectionProvider,
            underlying_driver_dialect: DriverDialect,
            props: Properties) -> DriverDialect:
        """Return the dialect for connections handed out by ``connection_provider``.

        The provider's class, then its base classes in MRO order, are looked up
        by name in :attr:`pool_connection_driver_dialect`. Subclasses of a
        registered provider therefore get the same pooled-connection dialect.
        Only the first (most specific) match is used.

        :param connection_provider: the provider the connection comes from.
        :param underlying_driver_dialect: the dialect of the target driver; it is
            passed to the pooled-connection dialect's constructor.
        :param props: connection properties passed to that constructor.
        :return: the pooled-connection dialect, or ``underlying_driver_dialect``
            when no entry matches or the mapped class could not be initialized.
        """
        # The exact class is checked first, so direct instances behave as before.
        # __class__ (not type()) matches the original lookup and honors objects
        # that override __class__.
        dialect_class_name: Optional[str] = next(
            (self.pool_connection_driver_dialect[provider_class.__name__]
             for provider_class in connection_provider.__class__.__mro__
             if provider_class.__name__ in self.pool_connection_driver_dialect),
            None)
        if dialect_class_name is not None:
            dialect = Utils.initialize_class(dialect_class_name, underlying_driver_dialect, props)
            if dialect is not None:
                return dialect
        return underlying_driver_dialect