aws_advanced_python_wrapper/aurora_connection_tracker_plugin.py (152 lines of code) (raw):

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations from threading import Thread from typing import (TYPE_CHECKING, Any, Callable, Dict, FrozenSet, Optional, Set, Tuple) if TYPE_CHECKING: from aws_advanced_python_wrapper.driver_dialect import DriverDialect from aws_advanced_python_wrapper.plugin_service import PluginService from aws_advanced_python_wrapper.pep249 import Connection from aws_advanced_python_wrapper.utils.rds_url_type import RdsUrlType from aws_advanced_python_wrapper.utils.properties import Properties from _weakrefset import WeakSet from aws_advanced_python_wrapper.errors import FailoverError from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory from aws_advanced_python_wrapper.utils.log import Logger from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils logger = Logger(__name__) class OpenedConnectionTracker: _opened_connections: Dict[str, WeakSet] = {} _rds_utils = RdsUtils() def populate_opened_connection_set(self, host_info: HostInfo, conn: Connection): """ Add the given connection to the set of tracked connections. :param host_info: host information of the given connection. :param conn: currently opened connection. """ aliases: FrozenSet[str] = host_info.as_aliases() if self._rds_utils.is_rds_instance(host_info.host): self._track_connection(host_info.as_alias(), conn) return instance_endpoint: Optional[str] = next((alias for alias in aliases if self._rds_utils.is_rds_instance(self._rds_utils.remove_port(alias))), None) if not instance_endpoint: logger.debug("OpenedConnectionTracker.UnableToPopulateOpenedConnectionSet") return self._track_connection(instance_endpoint, conn) def invalidate_all_connections(self, host_info: Optional[HostInfo] = None, host: Optional[FrozenSet[str]] = None): """ Invalidates all opened connections pointing to the same host in a daemon thread. :param host_info: the :py:class:`HostInfo` object containing the URL of the host. :param host: the set of aliases representing a specific host. """ if host_info: self.invalidate_all_connections(host=frozenset(host_info.as_alias())) self.invalidate_all_connections(host=host_info.as_aliases()) return instance_endpoint: Optional[str] = None if host is None: return for instance in host: if instance is not None and self._rds_utils.is_rds_instance(self._rds_utils.remove_port(instance)): instance_endpoint = instance break if not instance_endpoint: return connection_set: Optional[WeakSet] = self._opened_connections.get(instance_endpoint) if connection_set is not None: self._log_connection_set(instance_endpoint, connection_set) self._invalidate_connections(connection_set) def _track_connection(self, instance_endpoint: str, conn: Connection): connection_set: Optional[WeakSet] = self._opened_connections.get(instance_endpoint) if connection_set is None: connection_set = WeakSet() connection_set.add(conn) self._opened_connections[instance_endpoint] = connection_set else: connection_set.add(conn) self.log_opened_connections() @staticmethod def _task(connection_set: WeakSet): while connection_set is not None and len(connection_set) > 0: conn_reference = connection_set.pop() if conn_reference is None: continue try: conn_reference.close() except Exception: # Swallow this exception, current connection should be useless anyway pass def _invalidate_connections(self, connection_set: WeakSet): invalidate_connection_thread: Thread = Thread(daemon=True, target=self._task, args=[connection_set]) # type: ignore invalidate_connection_thread.start() def log_opened_connections(self): msg = "" for key, conn_set in self._opened_connections.items(): conn = "" for item in list(conn_set): conn += f"\n\t\t{item}" msg += f"\t[{key} : {conn}]" return logger.debug("OpenedConnectionTracker.OpenedConnectionsTracked", msg) def _log_connection_set(self, host: str, conn_set: Optional[WeakSet]): if conn_set is None or len(conn_set) == 0: return conn = "" for item in list(conn_set): conn += f"\n\t\t{item}" msg = host + f"[{conn}\n]" logger.debug("OpenedConnectionTracker.InvalidatingConnections", msg) class AuroraConnectionTrackerPlugin(Plugin): _SUBSCRIBED_METHODS: Set[str] = {"*"} _current_writer: Optional[HostInfo] = None _need_update_current_writer: bool = False @property def subscribed_methods(self) -> Set[str]: return self._SUBSCRIBED_METHODS def __init__(self, plugin_service: PluginService, props: Properties, rds_utils: RdsUtils = RdsUtils(), tracker: OpenedConnectionTracker = OpenedConnectionTracker()): self._plugin_service = plugin_service self._props = props self._rds_utils = rds_utils self._tracker = tracker def connect( self, target_driver_func: Callable, driver_dialect: DriverDialect, host_info: HostInfo, props: Properties, is_initial_connection: bool, connect_func: Callable) -> Connection: return self._connect(host_info, connect_func) def force_connect( self, target_driver_func: Callable, driver_dialect: DriverDialect, host_info: HostInfo, props: Properties, is_initial_connection: bool, force_connect_func: Callable) -> Connection: return self._connect(host_info, force_connect_func) def _connect(self, host_info: HostInfo, connect_func: Callable): conn = connect_func() if conn: url_type: RdsUrlType = self._rds_utils.identify_rds_type(host_info.host) if url_type.is_rds_cluster: host_info.reset_aliases() self._plugin_service.fill_aliases(conn, host_info) self._tracker.populate_opened_connection_set(host_info, conn) self._tracker.log_opened_connections() return conn def execute(self, target: object, method_name: str, execute_func: Callable, *args: Any, **kwargs: Any) -> Any: if self._current_writer is None or self._need_update_current_writer: self._current_writer = self._get_writer(self._plugin_service.all_hosts) self._need_update_current_writer = False try: return execute_func() except Exception as e: # Check that e is a FailoverError and that the writer has changed if isinstance(e, FailoverError) and self._get_writer(self._plugin_service.all_hosts) != self._current_writer: self._tracker.invalidate_all_connections(host_info=self._current_writer) self._tracker.log_opened_connections() self._need_update_current_writer = True raise e def _get_writer(self, hosts: Tuple[HostInfo, ...]) -> Optional[HostInfo]: for host in hosts: if host.role == HostRole.WRITER: return host return None class AuroraConnectionTrackerPluginFactory(PluginFactory): def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin: return AuroraConnectionTrackerPlugin(plugin_service, props)