aws_advanced_python_wrapper/wrapper.py (243 lines of code) (raw):
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import (TYPE_CHECKING, Any, Callable, Iterator, List, Optional,
Union)
if TYPE_CHECKING:
from aws_advanced_python_wrapper.host_list_provider import HostListProviderService
from aws_advanced_python_wrapper.driver_dialect_manager import \
DriverDialectManager
from aws_advanced_python_wrapper.errors import (AwsWrapperError,
FailoverSuccessError)
from aws_advanced_python_wrapper.pep249 import Connection, Cursor, Error
from aws_advanced_python_wrapper.plugin import CanReleaseResources
from aws_advanced_python_wrapper.plugin_service import (
PluginManager, PluginService, PluginServiceImpl,
PluginServiceManagerContainer)
from aws_advanced_python_wrapper.utils.log import Logger
from aws_advanced_python_wrapper.utils.messages import Messages
from aws_advanced_python_wrapper.utils.properties import (Properties,
PropertiesUtils)
from aws_advanced_python_wrapper.utils.telemetry.default_telemetry_factory import \
DefaultTelemetryFactory
from aws_advanced_python_wrapper.utils.telemetry.telemetry import \
TelemetryTraceLevel
# Module-level logger named after this module; used for the debug output in AwsWrapperConnection.connect.
logger = Logger(__name__)
class AwsWrapperConnection(Connection, CanReleaseResources):
    """Connection wrapper that routes every operation through the plugin manager.

    Each public operation is dispatched via ``PluginManager.execute`` together with a
    method-name string (for example ``"Connection.commit"``) and a zero-argument callable
    that performs the operation on the plugin service's *current* connection. The callables
    look up ``self.target_connection`` lazily, at the moment they are invoked.
    """
    __module__ = "aws_advanced_python_wrapper"

    def __init__(
            self,
            target_func: Callable,
            host_list_provider_service: HostListProviderService,
            plugin_service: PluginService,
            plugin_manager: PluginManager):
        """Initialise the host list provider and open the initial connection if none exists.

        :param target_func: the target driver's connect callable, forwarded to ``plugin_manager.connect``.
        :param host_list_provider_service: service handed to the host list provider supplier
            and to ``plugin_manager.init_host_provider``.
        :param plugin_service: plugin service that owns the current connection and properties.
        :param plugin_manager: plugin manager used to connect and to execute wrapped calls.
        :raises AwsWrapperError: if the plugin service has no initial connection host info,
            or if ``plugin_manager.connect`` returns a falsy value.
        """
        # Stored first so that release_resources (called from __del__) can always find them.
        self._plugin_service = plugin_service
        self._plugin_manager = plugin_manager

        host_list_provider_init = plugin_service.database_dialect.get_host_list_provider_supplier()
        plugin_service.host_list_provider = host_list_provider_init(host_list_provider_service, plugin_service.props)

        plugin_manager.init_host_provider(plugin_service.props, host_list_provider_service)

        plugin_service.refresh_host_list()

        # A connection is already available on the plugin service; nothing more to open.
        if plugin_service.current_connection is not None:
            return

        if plugin_service.initial_connection_host_info is None:
            raise AwsWrapperError(Messages.get("AwsWrapperConnection.InitialHostInfoNone"))

        # NOTE(review): the trailing True is presumably the "initial connection" flag -
        # confirm against PluginManager.connect.
        conn = plugin_manager.connect(
            target_func,
            plugin_service.driver_dialect,
            plugin_service.initial_connection_host_info,
            plugin_service.props,
            True)

        if not conn:
            raise AwsWrapperError(Messages.get("AwsWrapperConnection.ConnectionNotOpen"))

        plugin_service.set_current_connection(conn, plugin_service.initial_connection_host_info)

    @property
    def target_connection(self):
        """The plugin service's current underlying connection."""
        return self._plugin_service.current_connection

    @property
    def is_closed(self):
        """Whether the driver dialect reports the current target connection as closed."""
        return self._plugin_service.driver_dialect.is_closed(self.target_connection)

    @property
    def read_only(self) -> bool:
        """Read-only state of the target connection, fetched through the plugin manager."""
        return self._plugin_manager.execute(
            self.target_connection,
            "Connection.is_read_only",
            lambda: self._is_read_only())

    @read_only.setter
    def read_only(self, val: bool):
        self._plugin_manager.execute(
            self.target_connection,
            "Connection.set_read_only",
            lambda: self._set_read_only(val),
            val)

    def _is_read_only(self) -> bool:
        """Read the read-only flag from the driver dialect and pass it to the session state service."""
        is_read_only = self._plugin_service.driver_dialect.is_read_only(self.target_connection)
        self._plugin_service.session_state_service.setup_pristine_readonly(is_read_only)
        return is_read_only

    def _set_read_only(self, val: bool):
        """Apply ``val`` as the read-only flag on the target connection and in the session state service."""
        self._plugin_service.session_state_service.setup_pristine_readonly(val)
        self._plugin_service.driver_dialect.set_read_only(self.target_connection, val)
        self._plugin_service.session_state_service.set_read_only(val)

    @property
    def autocommit(self):
        """Autocommit state of the target connection, fetched through the plugin manager."""
        # Go through _get_autocommit (mirroring read_only/_is_read_only) so the value that was
        # read is also handed to session_state_service.setup_pristine_autocommit.
        return self._plugin_manager.execute(
            self.target_connection,
            "Connection.autocommit",
            lambda: self._get_autocommit())

    @autocommit.setter
    def autocommit(self, val: bool):
        self._plugin_manager.execute(
            self.target_connection,
            "Connection.autocommit_setter",
            lambda: self._set_autocommit(val),
            val)

    def _get_autocommit(self) -> bool:
        """Read the autocommit flag from the driver dialect and pass it to the session state service."""
        autocommit = self._plugin_service.driver_dialect.get_autocommit(self.target_connection)
        self._plugin_service.session_state_service.setup_pristine_autocommit(autocommit)
        return autocommit

    def _set_autocommit(self, val: bool):
        """Apply ``val`` as the autocommit flag on the target connection and in the session state service."""
        self._plugin_service.session_state_service.setup_pristine_autocommit(val)
        self._plugin_service.driver_dialect.set_autocommit(self.target_connection, val)
        self._plugin_service.session_state_service.set_autocommit(val)

    @staticmethod
    def connect(
            target: Union[None, str, Callable] = None,
            conninfo: str = "",
            *args: Any,
            **kwargs: Any) -> AwsWrapperConnection:
        """Create a wrapper connection around the given target driver connect callable.

        :param target: the target driver's connect callable. A falsy or non-callable value is rejected.
        :param conninfo: connection string, parsed together with ``kwargs`` by
            ``PropertiesUtils.parse_properties``.
        :param args: accepted for signature compatibility; not used by this method.
        :param kwargs: additional connection properties.
        :return: a new :class:`AwsWrapperConnection`.
        :raises Error: if ``target`` is missing or is not callable.
        """
        if not target:
            raise Error(Messages.get("Wrapper.RequiredTargetDriver"))

        if not callable(target):
            raise Error(Messages.get("Wrapper.ConnectMethod"))
        target_func: Callable = target

        props: Properties = PropertiesUtils.parse_properties(conn_info=conninfo, **kwargs)
        # Properties are masked before being logged.
        logger.debug("Wrapper.Properties", PropertiesUtils.log_properties(PropertiesUtils.mask_properties(props)))

        telemetry_factory = DefaultTelemetryFactory(props)
        context = telemetry_factory.open_telemetry_context(__name__, TelemetryTraceLevel.TOP_LEVEL)

        try:
            driver_dialect_manager: DriverDialectManager = DriverDialectManager()
            driver_dialect = driver_dialect_manager.get_dialect(target_func, props)
            container: PluginServiceManagerContainer = PluginServiceManagerContainer()
            plugin_service = PluginServiceImpl(
                container, props, target_func, driver_dialect_manager, driver_dialect)
            plugin_manager: PluginManager = PluginManager(container, props, telemetry_factory)

            # The same PluginServiceImpl instance fills both the host-list-provider-service
            # and the plugin-service parameters of the constructor.
            return AwsWrapperConnection(target_func, plugin_service, plugin_service, plugin_manager)
        except Exception as ex:
            # Record the failure on the telemetry context, then propagate it unchanged.
            context.set_exception(ex)
            context.set_success(False)
            raise
        finally:
            context.close_context()

    def close(self) -> None:
        """Close the target connection through the plugin manager."""
        self._plugin_manager.execute(self.target_connection, "Connection.close",
                                     lambda: self.target_connection.close())

    def cursor(self, *args: Any, **kwargs: Any) -> AwsWrapperCursor:
        """Create a cursor on the target connection and wrap it in an :class:`AwsWrapperCursor`."""
        _cursor = self._plugin_manager.execute(self.target_connection, "Connection.cursor",
                                               lambda: self.target_connection.cursor(*args, **kwargs),
                                               *args, **kwargs)
        return AwsWrapperCursor(self, self._plugin_service, self._plugin_manager, _cursor)

    def commit(self) -> None:
        """Commit on the target connection through the plugin manager."""
        self._plugin_manager.execute(self.target_connection, "Connection.commit",
                                     lambda: self.target_connection.commit())

    def rollback(self) -> None:
        """Roll back on the target connection through the plugin manager."""
        self._plugin_manager.execute(self.target_connection, "Connection.rollback",
                                     lambda: self.target_connection.rollback())

    def tpc_begin(self, xid: Any) -> None:
        """Call ``tpc_begin(xid)`` on the target connection through the plugin manager."""
        self._plugin_manager.execute(self.target_connection, "Connection.tpc_begin",
                                     lambda: self.target_connection.tpc_begin(xid), xid)

    def tpc_prepare(self) -> None:
        """Call ``tpc_prepare()`` on the target connection through the plugin manager."""
        self._plugin_manager.execute(self.target_connection, "Connection.tpc_prepare",
                                     lambda: self.target_connection.tpc_prepare())

    def tpc_commit(self, xid: Any = None) -> None:
        """Call ``tpc_commit(xid)`` on the target connection through the plugin manager."""
        self._plugin_manager.execute(self.target_connection, "Connection.tpc_commit",
                                     lambda: self.target_connection.tpc_commit(xid), xid)

    def tpc_rollback(self, xid: Any = None) -> None:
        """Call ``tpc_rollback(xid)`` on the target connection through the plugin manager."""
        self._plugin_manager.execute(self.target_connection, "Connection.tpc_rollback",
                                     lambda: self.target_connection.tpc_rollback(xid), xid)

    def tpc_recover(self) -> Any:
        """Return the result of ``tpc_recover()`` on the target connection, via the plugin manager."""
        return self._plugin_manager.execute(self.target_connection, "Connection.tpc_recover",
                                            lambda: self.target_connection.tpc_recover())

    def release_resources(self):
        """Release plugin manager resources and, if it supports it, plugin service resources."""
        self._plugin_manager.release_resources()
        if isinstance(self._plugin_service, CanReleaseResources):
            self._plugin_service.release_resources()

    def __del__(self):
        # Resources are released when the wrapper object is finalised.
        self.release_resources()

    def __enter__(self: AwsWrapperConnection) -> AwsWrapperConnection:
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        # Leaving the ``with`` block closes the target connection; the exception details
        # are forwarded to the plugin manager as extra arguments.
        self._plugin_manager.execute(self.target_connection, "Connection.close",
                                     lambda: self.target_connection.close(), exc_type, exc_val, exc_tb)
class AwsWrapperCursor(Cursor):
    """Cursor wrapper that routes operations on a target driver cursor through the plugin manager.

    Operations are dispatched via ``PluginManager.execute`` together with a method-name string
    (for example ``"Cursor.fetchone"``) and a zero-argument callable that performs the call
    on the current target cursor.
    """
    __module__ = "aws_advanced_python_wrapper"

    def __init__(
            self,
            conn: AwsWrapperConnection,
            plugin_service: PluginService,
            plugin_manager: PluginManager,
            target_cursor: Cursor):
        """Store the owning wrapper connection, plugin collaborators and the underlying cursor.

        :param conn: the :class:`AwsWrapperConnection` this cursor belongs to.
        :param plugin_service: plugin service associated with the connection.
        :param plugin_manager: plugin manager used to execute wrapped calls.
        :param target_cursor: the underlying driver cursor.
        """
        self._conn: AwsWrapperConnection = conn
        self._plugin_service: PluginService = plugin_service
        self._plugin_manager: PluginManager = plugin_manager
        self._target_cursor: Cursor = target_cursor

    # It's not part of PEP249
    @property
    def connection(self) -> AwsWrapperConnection:
        """The wrapper connection that created this cursor."""
        return self._conn

    @property
    def target_cursor(self) -> Cursor:
        """The underlying driver cursor currently in use."""
        return self._target_cursor

    @property
    def description(self):
        """``description`` of the target cursor (read directly, not via the plugin manager)."""
        return self.target_cursor.description

    @property
    def rowcount(self) -> int:
        """``rowcount`` of the target cursor (read directly, not via the plugin manager)."""
        return self.target_cursor.rowcount

    @property
    def arraysize(self) -> int:
        """``arraysize`` of the target cursor (read directly, not via the plugin manager)."""
        return self.target_cursor.arraysize

    def close(self) -> None:
        """Close the target cursor through the plugin manager."""
        self._plugin_manager.execute(self.target_cursor, "Cursor.close",
                                     lambda: self.target_cursor.close())

    def callproc(self, *args: Any, **kwargs: Any):
        """Call a stored procedure on the target cursor through the plugin manager.

        Both positional and keyword arguments are forwarded to the target cursor's
        ``callproc``, so the procedure name and parameters may be passed positionally.

        :return: whatever the plugin manager returns for the call.
        """
        # Forward *args as well as **kwargs: dropping the positional arguments would lose
        # the procedure name whenever it is passed positionally.
        return self._plugin_manager.execute(self.target_cursor, "Cursor.callproc",
                                            lambda: self.target_cursor.callproc(*args, **kwargs), *args, **kwargs)

    def execute(
            self,
            *args: Any,
            **kwargs: Any
    ) -> AwsWrapperCursor:
        """Execute a statement on the target cursor through the plugin manager.

        :return: whatever the plugin manager returns for the call.
        :raises FailoverSuccessError: re-raised after the target cursor has been replaced
            with a new cursor from the connection's current target connection.
        """
        try:
            return self._plugin_manager.execute(self.target_cursor, "Cursor.execute",
                                                lambda: self.target_cursor.execute(*args, **kwargs), *args, **kwargs)
        except FailoverSuccessError:
            # Swap in a fresh cursor from the connection's current target connection before
            # propagating the error, so this wrapper can keep being used afterwards.
            self._target_cursor = self.connection.target_connection.cursor()
            raise

    def executemany(
            self,
            *args: Any,
            **kwargs: Any
    ) -> None:
        """Call ``executemany`` on the target cursor through the plugin manager."""
        self._plugin_manager.execute(self.target_cursor, "Cursor.executemany",
                                     lambda: self.target_cursor.executemany(*args, **kwargs),
                                     *args, **kwargs)

    def nextset(self) -> bool:
        """Return the result of ``nextset()`` on the target cursor, via the plugin manager."""
        return self._plugin_manager.execute(self.target_cursor, "Cursor.nextset",
                                            lambda: self.target_cursor.nextset())

    def fetchone(self) -> Any:
        """Return the result of ``fetchone()`` on the target cursor, via the plugin manager."""
        return self._plugin_manager.execute(self.target_cursor, "Cursor.fetchone",
                                            lambda: self.target_cursor.fetchone())

    def fetchmany(self, size: int = 0) -> List[Any]:
        """Return the result of ``fetchmany(size)`` on the target cursor, via the plugin manager."""
        return self._plugin_manager.execute(self.target_cursor, "Cursor.fetchmany",
                                            lambda: self.target_cursor.fetchmany(size), size)

    def fetchall(self) -> List[Any]:
        """Return the result of ``fetchall()`` on the target cursor, via the plugin manager."""
        return self._plugin_manager.execute(self.target_cursor, "Cursor.fetchall",
                                            lambda: self.target_cursor.fetchall())

    def __iter__(self) -> Iterator[Any]:
        # Iteration is delegated straight to the target cursor, bypassing the plugin manager.
        return self.target_cursor.__iter__()

    def setinputsizes(self, sizes: Any) -> None:
        """Call ``setinputsizes(sizes)`` on the target cursor through the plugin manager."""
        return self._plugin_manager.execute(self.target_cursor, "Cursor.setinputsizes",
                                            lambda: self.target_cursor.setinputsizes(sizes), sizes)

    def setoutputsize(self, size: Any, column: Optional[int] = None) -> None:
        """Call ``setoutputsize(size, column)`` on the target cursor through the plugin manager."""
        return self._plugin_manager.execute(self.target_cursor, "Cursor.setoutputsize",
                                            lambda: self.target_cursor.setoutputsize(size, column), size, column)

    def __enter__(self: AwsWrapperCursor) -> AwsWrapperCursor:
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        # Leaving the ``with`` block closes the cursor.
        self.close()