aws_advanced_python_wrapper/default_plugin.py (103 lines of code) (raw):
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from aws_advanced_python_wrapper.connection_provider import (ConnectionProvider,
ConnectionProviderManager)
from aws_advanced_python_wrapper.driver_dialect import DriverDialect
from aws_advanced_python_wrapper.host_list_provider import HostListProviderService
from aws_advanced_python_wrapper.plugin_service import PluginService
from aws_advanced_python_wrapper.pep249 import Connection
from aws_advanced_python_wrapper.utils.properties import Properties
import copy
from typing import Any, Callable, Set
from aws_advanced_python_wrapper.errors import AwsWrapperError
from aws_advanced_python_wrapper.host_availability import HostAvailability
from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole
from aws_advanced_python_wrapper.plugin import Plugin
from aws_advanced_python_wrapper.utils.messages import Messages
from aws_advanced_python_wrapper.utils.telemetry.telemetry import \
TelemetryTraceLevel
class DefaultPlugin(Plugin):
    """Plugin that hands connect and execute requests to the driver layer.

    Connections are opened through a ``ConnectionProvider`` and method calls
    are run through the plugin service's driver dialect. ``connect``,
    ``force_connect`` and ``init_host_provider`` never invoke the
    continuation callable they receive.
    """

    # The plugin subscribes with the single entry "*".
    _SUBSCRIBED_METHODS: Set[str] = {"*"}
    # Method name for which the post-execute transaction refresh is skipped.
    _CLOSE_METHOD = "Connection.close"

    def __init__(self, plugin_service: PluginService, connection_provider_manager: ConnectionProviderManager):
        """Store the plugin service and connection provider manager."""
        self._plugin_service: PluginService = plugin_service
        self._connection_provider_manager = connection_provider_manager

    def connect(
            self,
            target_driver_func: Callable,
            driver_dialect: DriverDialect,
            host_info: HostInfo,
            props: Properties,
            is_initial_connection: bool,
            connect_func: Callable) -> Connection:
        """Connect to ``host_info`` via the provider chosen by the manager.

        ``props`` is duplicated with ``copy.copy`` and the copy is what gets
        passed on. ``is_initial_connection`` and ``connect_func`` are not
        used by this implementation.
        """
        driver_props = copy.copy(props)
        provider: ConnectionProvider = self._connection_provider_manager.get_connection_provider(
            host_info, driver_props)
        return self._connect(target_driver_func, driver_dialect, host_info, driver_props, provider)

    def _connect(
            self,
            target_func: Callable,
            driver_dialect: DriverDialect,
            host_info: HostInfo,
            props: Properties,
            conn_provider: ConnectionProvider) -> Connection:
        """Open a connection through ``conn_provider`` inside a telemetry context.

        The telemetry context is closed whether or not the provider raises.
        After a successful connect the host's aliases are marked available
        and the plugin service is asked to update its driver dialect and
        dialect before the connection is returned.
        """
        trace = self._plugin_service.get_telemetry_factory().open_telemetry_context(
            driver_dialect.driver_name, TelemetryTraceLevel.NESTED)

        try:
            db_dialect = self._plugin_service.database_dialect
            connection: Connection = conn_provider.connect(
                target_func, driver_dialect, db_dialect, host_info, props)
        finally:
            trace.close_context()

        # Reached only when the provider returned a connection; a failure
        # propagates out of the try block above.
        service = self._plugin_service
        service.set_availability(host_info.all_aliases, HostAvailability.AVAILABLE)
        service.update_driver_dialect(conn_provider)
        service.update_dialect(connection)
        return connection

    def force_connect(
            self,
            target_driver_func: Callable,
            driver_dialect: DriverDialect,
            host_info: HostInfo,
            props: Properties,
            is_initial_connection: bool,
            force_connect_func: Callable) -> Connection:
        """Connect to ``host_info`` using the manager's default provider.

        Unlike ``connect``, no provider lookup is performed. ``props`` is
        duplicated with ``copy.copy`` first. ``is_initial_connection`` and
        ``force_connect_func`` are not used by this implementation.
        """
        driver_props = copy.copy(props)
        default_provider = self._connection_provider_manager.default_provider
        return self._connect(target_driver_func, driver_dialect, host_info, driver_props, default_provider)

    def execute(self, target: object, method_name: str, execute_func: Callable, *args: Any, **kwargs: Any) -> Any:
        """Run ``execute_func`` through the driver dialect inside a telemetry context.

        The telemetry context is closed whether or not the call raises.
        After a successful call, the plugin service's transaction state is
        refreshed unless the method was ``Connection.close`` or there is no
        current connection. ``target`` is not used by this implementation.
        """
        trace = self._plugin_service.get_telemetry_factory().open_telemetry_context(
            self._plugin_service.driver_dialect.driver_name, TelemetryTraceLevel.NESTED)

        try:
            outcome = self._plugin_service.driver_dialect.execute(
                method_name, execute_func, *args, **kwargs)
        finally:
            trace.close_context()

        needs_transaction_refresh = (
            method_name != DefaultPlugin._CLOSE_METHOD
            and self._plugin_service.current_connection is not None
        )
        if needs_transaction_refresh:
            self._plugin_service.update_in_transaction()

        return outcome

    def accepts_strategy(self, role: HostRole, strategy: str) -> bool:
        """Return ``False`` for ``HostRole.UNKNOWN``; otherwise ask the provider manager."""
        return (
            False
            if HostRole.UNKNOWN == role
            else self._connection_provider_manager.accepts_strategy(role, strategy)
        )

    def get_host_info_by_strategy(self, role: HostRole, strategy: str) -> HostInfo:
        """Pick a host from the plugin service's hosts via the provider manager.

        Raises ``AwsWrapperError`` when ``role`` is ``HostRole.UNKNOWN`` or
        when the plugin service currently has no hosts.
        """
        if HostRole.UNKNOWN == role:
            raise AwsWrapperError(Messages.get("DefaultPlugin.UnknownHosts"))

        candidates = self._plugin_service.hosts
        if len(candidates) == 0:
            raise AwsWrapperError(Messages.get("DefaultPlugin.EmptyHosts"))

        return self._connection_provider_manager.get_host_info_by_strategy(
            candidates, role, strategy, self._plugin_service.props)

    @property
    def subscribed_methods(self) -> Set[str]:
        """Return the class-level set of subscribed method names."""
        return DefaultPlugin._SUBSCRIBED_METHODS

    def init_host_provider(
            self,
            props: Properties,
            host_list_provider_service: HostListProviderService,
            init_host_provider_func: Callable):
        """Intentionally do nothing.

        This is the last plugin in the plugin chain, so
        ``init_host_provider_func`` will be a no-op and does not need to be
        called.
        """