aws_advanced_python_wrapper/read_write_splitting_plugin.py (284 lines of code) (raw):

#  Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License").
#  You may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#  http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

from __future__ import annotations

from copy import deepcopy
from typing import TYPE_CHECKING, Any, Callable, Optional, Set, Tuple

if TYPE_CHECKING:
    from aws_advanced_python_wrapper.driver_dialect import DriverDialect
    from aws_advanced_python_wrapper.host_list_provider import HostListProviderService
    from aws_advanced_python_wrapper.pep249 import Connection
    from aws_advanced_python_wrapper.plugin_service import PluginService
    from aws_advanced_python_wrapper.utils.properties import Properties

from aws_advanced_python_wrapper.connection_provider import ConnectionProviderManager
from aws_advanced_python_wrapper.errors import (AwsWrapperError, FailoverError,
                                                ReadWriteSplittingError)
from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole
from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory
from aws_advanced_python_wrapper.utils.log import Logger
from aws_advanced_python_wrapper.utils.messages import Messages
from aws_advanced_python_wrapper.utils.notifications import (
    ConnectionEvent, OldConnectionSuggestedAction)
from aws_advanced_python_wrapper.utils.properties import WrapperProperties

logger = Logger(__name__)


class ReadWriteSplittingPlugin(Plugin):
    """Plugin that swaps the current connection between a writer and a reader.

    The swap is triggered from :meth:`execute` when ``Connection.set_read_only``
    is invoked: ``True`` requests a reader connection, ``False`` a writer
    connection. The plugin caches one writer and one reader connection so they
    can be reused on later switches.
    """

    _SUBSCRIBED_METHODS: Set[str] = {"init_host_provider",
                                     "connect",
                                     "notify_connection_changed",
                                     "Connection.set_read_only"}
    # Matched by substring against ``str(type(provider))`` to detect pooled connections.
    _POOL_PROVIDER_CLASS_NAME = (
        "aws_advanced_python_wrapper.sql_alchemy_connection_provider.SqlAlchemyPooledConnectionProvider")

    def __init__(self, plugin_service: PluginService, props: Properties):
        """Store collaborators and resolve the reader host selection strategy.

        The strategy is read from ``WrapperProperties.READER_HOST_SELECTOR_STRATEGY``;
        when the property is unset its default value is used, and when that is
        also ``None`` the strategy stays an empty string.
        """
        self._plugin_service = plugin_service
        self._properties = props
        # Assigned in init_host_provider(); only declared here.
        self._host_list_provider_service: HostListProviderService
        self._writer_connection: Optional[Connection] = None
        self._reader_connection: Optional[Connection] = None
        self._reader_host_info: Optional[HostInfo] = None
        self._conn_provider_manager: ConnectionProviderManager = self._plugin_service.get_connection_provider_manager()
        self._is_reader_conn_from_internal_pool: bool = False
        self._is_writer_conn_from_internal_pool: bool = False
        self._in_read_write_split: bool = False
        self._reader_selector_strategy: str = ""

        strategy = WrapperProperties.READER_HOST_SELECTOR_STRATEGY.get(self._properties)
        if strategy is not None:
            self._reader_selector_strategy = strategy
        else:
            default_strategy = WrapperProperties.READER_HOST_SELECTOR_STRATEGY.default_value
            if default_strategy is not None:
                self._reader_selector_strategy = default_strategy

    @property
    def subscribed_methods(self) -> Set[str]:
        """Return the set of method names this plugin subscribes to."""
        return self._SUBSCRIBED_METHODS

    def init_host_provider(
            self,
            props: Properties,
            host_list_provider_service: HostListProviderService,
            init_host_provider_func: Callable):
        """Remember the host list provider service, then run the next initializer."""
        self._host_list_provider_service = host_list_provider_service
        init_host_provider_func()

    def connect(
            self,
            target_driver_func: Callable,
            driver_dialect: DriverDialect,
            host_info: HostInfo,
            props: Properties,
            is_initial_connection: bool,
            connect_func: Callable) -> Connection:
        """Validate the reader selection strategy and open the connection.

        Raises:
            AwsWrapperError: if the plugin service does not accept the configured
                strategy for the role of ``host_info``.
        """
        if not self._plugin_service.accepts_strategy(host_info.role, self._reader_selector_strategy):
            raise AwsWrapperError(
                Messages.get_formatted("ReadWriteSplittingPlugin.UnsupportedHostInfoSelectorStrategy",
                                       self._reader_selector_strategy))

        return self.connect_internal(is_initial_connection, connect_func)

    def force_connect(
            self,
            target_driver_func: Callable,
            driver_dialect: DriverDialect,
            host_info: HostInfo,
            props: Properties,
            is_initial_connection: bool,
            force_connect_func: Callable) -> Connection:
        """Open the connection through ``force_connect_func`` without the strategy check."""
        return self.connect_internal(is_initial_connection, force_connect_func)

    def connect_internal(self, is_initial_connection: bool, connect_func: Callable) -> Connection:
        """Open a connection and, for an initial connection, reconcile the host role.

        For an initial connection with a non-static host list provider, the role
        reported for the new connection is compared with the role recorded on the
        initial connection host info; on mismatch a copy of the host info carrying
        the reported role is stored on the host list provider service.

        Raises:
            ReadWriteSplittingError: if the role of the new connection is ``None``
                or ``HostRole.UNKNOWN``.
        """
        current_conn = connect_func()

        if not is_initial_connection or self._host_list_provider_service.is_static_host_list_provider():
            return current_conn

        current_role = self._plugin_service.get_host_role(current_conn)
        if current_role is None or current_role == HostRole.UNKNOWN:
            self._log_and_raise_exception("ReadWriteSplittingPlugin.ErrorVerifyingInitialHostSpecRole")

        current_host = self._plugin_service.initial_connection_host_info
        if current_host is not None:
            if current_role == current_host.role:
                return current_conn

            # Copy before mutating so the original host info object is left untouched.
            updated_host = deepcopy(current_host)
            updated_host.role = current_role
            self._host_list_provider_service.initial_connection_host_info = updated_host

        return current_conn

    def notify_connection_changed(self, changes: Set[ConnectionEvent]) -> OldConnectionSuggestedAction:
        """Refresh the cached connection and say what to do with the old one.

        Returns ``PRESERVE`` once a read/write switch has been started, otherwise
        ``NO_OPINION``.
        """
        self._update_internal_connection_info()

        if self._in_read_write_split:
            return OldConnectionSuggestedAction.PRESERVE

        return OldConnectionSuggestedAction.NO_OPINION

    def execute(self, target: type, method_name: str, execute_func: Callable, *args: Any, **kwargs: Any) -> Any:
        """Run ``execute_func``, switching connections first for ``Connection.set_read_only``.

        Raises:
            AwsWrapperError: if ``target`` resolves to a connection that is not the
                current connection.
            Exception: anything raised by ``execute_func`` is re-raised unchanged.
                On ``FailoverError`` the cached idle connections are closed first.
        """
        driver_dialect = self._plugin_service.driver_dialect
        conn: Optional[Connection] = driver_dialect.get_connection_from_obj(target)
        current_conn: Optional[Connection] = driver_dialect.unwrap_connection(self._plugin_service.current_connection)

        if conn is not None and conn != current_conn:
            msg = Messages.get_formatted("PluginManager.MethodInvokedAgainstOldConnection", target)
            raise AwsWrapperError(msg)

        # ``args`` is always a tuple, so truthiness covers the "has at least one argument" check.
        if method_name == "Connection.set_read_only" and args:
            self._switch_connection_if_required(args[0])

        try:
            return execute_func()
        except FailoverError:
            logger.debug("ReadWriteSplittingPlugin.FailoverExceptionWhileExecutingCommand", method_name)
            self._close_idle_connections()
            raise
        except Exception:
            logger.debug("ReadWriteSplittingPlugin.ExceptionWhileExecutingCommand", method_name)
            raise

    def _update_internal_connection_info(self):
        """Cache the current connection as writer or reader according to the current host role."""
        current_conn = self._plugin_service.current_connection
        current_host = self._plugin_service.current_host_info
        if current_conn is None or current_host is None:
            return

        if current_host.role == HostRole.WRITER:
            self._set_writer_connection(current_conn, current_host)
        else:
            self._set_reader_connection(current_conn, current_host)

    def _set_writer_connection(self, writer_conn: Connection, writer_host_info: HostInfo):
        """Cache ``writer_conn`` as the writer connection."""
        self._writer_connection = writer_conn
        logger.debug("ReadWriteSplittingPlugin.SetWriterConnection", writer_host_info.url)

    def _set_reader_connection(self, reader_conn: Connection, reader_host_info: HostInfo):
        """Cache ``reader_conn`` and its host info as the reader connection."""
        self._reader_connection = reader_conn
        self._reader_host_info = reader_host_info
        logger.debug("ReadWriteSplittingPlugin.SetReaderConnection", reader_host_info.url)

    def _get_new_writer_connection(self, writer_host: HostInfo):
        """Connect to ``writer_host``, cache the connection and make it current."""
        conn = self._plugin_service.connect(writer_host, self._properties)
        provider = self._conn_provider_manager.get_connection_provider(writer_host, self._properties)
        self._is_writer_conn_from_internal_pool = (ReadWriteSplittingPlugin._POOL_PROVIDER_CLASS_NAME in str(type(provider)))
        self._set_writer_connection(conn, writer_host)
        self._switch_current_connection_to(conn, writer_host)

    def _switch_connection_if_required(self, read_only: bool):
        """Switch to a reader (``read_only=True``) or a writer (``read_only=False``) if needed.

        When switching to a reader fails but the current connection is still
        usable, a warning is logged and the current connection is kept.

        Raises:
            ReadWriteSplittingError: if the current connection is closed, the host
                list is empty, the current host info is unavailable, a switch to
                the writer is requested inside a transaction, or the switch fails
                without a usable connection to fall back to.
        """
        current_conn = self._plugin_service.current_connection
        driver_dialect = self._plugin_service.driver_dialect

        if (current_conn is not None and
                driver_dialect is not None and driver_dialect.is_closed(current_conn)):
            self._log_and_raise_exception("ReadWriteSplittingPlugin.SetReadOnlyOnClosedConnection")

        if current_conn is not None and driver_dialect.can_execute_query(current_conn):
            try:
                self._plugin_service.refresh_host_list()
            except Exception:
                pass  # Swallow exception: the refresh is best effort, the existing host list is used below.

        hosts = self._plugin_service.hosts
        if hosts is None or len(hosts) == 0:
            self._log_and_raise_exception("ReadWriteSplittingPlugin.EmptyHostList")

        current_host = self._plugin_service.current_host_info
        if current_host is None:
            self._log_and_raise_exception("ReadWriteSplittingPlugin.UnavailableHostInfo")
            return  # Unreachable at runtime (the call above raises); kept so the name narrows to non-None.

        if read_only:
            if not self._plugin_service.is_in_transaction and current_host.role != HostRole.READER:
                try:
                    self._switch_to_reader_connection(hosts)
                except Exception:
                    if not self._is_connection_usable(current_conn, driver_dialect):
                        self._log_and_raise_exception("ReadWriteSplittingPlugin.ErrorSwitchingToReader")
                        return

                    # The connection held before the switch is still usable, so stay on it.
                    logger.warning("ReadWriteSplittingPlugin.FallbackToWriter", current_host.url)
        elif current_host.role != HostRole.WRITER:
            if self._plugin_service.is_in_transaction:
                self._log_and_raise_exception("ReadWriteSplittingPlugin.SetReadOnlyFalseInTransaction")

            try:
                self._switch_to_writer_connection(hosts)
            except Exception:
                self._log_and_raise_exception("ReadWriteSplittingPlugin.ErrorSwitchingToWriter")

    def _switch_current_connection_to(self, new_conn: Connection, new_conn_host: HostInfo):
        """Make ``new_conn`` the current connection unless it already is."""
        current_conn = self._plugin_service.current_connection
        if current_conn == new_conn:
            return

        self._plugin_service.set_current_connection(new_conn, new_conn_host)
        logger.debug("ReadWriteSplittingPlugin.SettingCurrentConnection", new_conn_host.url)

    def _switch_to_writer_connection(self, hosts: Tuple[HostInfo, ...]):
        """Make a writer connection current, reusing the cached one when it is usable.

        Does nothing when the current host is already a writer with a usable
        connection. Afterwards, a reader connection that came from the pooled
        provider is closed if it is idle.
        """
        current_host = self._plugin_service.current_host_info
        current_conn = self._plugin_service.current_connection
        driver_dialect = self._plugin_service.driver_dialect
        if (current_host is not None and current_host.role == HostRole.WRITER and
                self._is_connection_usable(current_conn, driver_dialect)):
            return

        writer_host = self._get_writer(hosts)
        if writer_host is None:
            return

        self._in_read_write_split = True
        if not self._is_connection_usable(self._writer_connection, driver_dialect):
            self._get_new_writer_connection(writer_host)
        elif self._writer_connection is not None:
            self._switch_current_connection_to(self._writer_connection, writer_host)

        if self._is_reader_conn_from_internal_pool:
            self._close_connection_if_idle(self._reader_connection)

        logger.debug("ReadWriteSplittingPlugin.SwitchedFromReaderToWriter", writer_host.url)

    def _switch_to_reader_connection(self, hosts: Tuple[HostInfo, ...]):
        """Make a reader connection current, reusing the cached one when it is usable.

        Does nothing when the current host is already a reader with a usable
        connection. If switching to the cached reader fails, the cached reader is
        discarded and a new reader connection is initialized. Afterwards, a writer
        connection that came from the pooled provider is closed if it is idle.
        """
        current_host = self._plugin_service.current_host_info
        current_conn = self._plugin_service.current_connection
        driver_dialect = self._plugin_service.driver_dialect
        if (current_host is not None and current_host.role == HostRole.READER and
                self._is_connection_usable(current_conn, driver_dialect)):
            return

        hostnames = [host_info.host for host_info in hosts]
        if self._reader_host_info is not None and self._reader_host_info.host not in hostnames:
            # The old reader cannot be used anymore because it is no longer in the list of allowed hosts.
            self._close_connection_if_idle(self._reader_connection)

        self._in_read_write_split = True
        if not self._is_connection_usable(self._reader_connection, driver_dialect):
            self._initialize_reader_connection(hosts)
        elif self._reader_connection is not None and self._reader_host_info is not None:
            try:
                self._switch_current_connection_to(self._reader_connection, self._reader_host_info)
                logger.debug("ReadWriteSplittingPlugin.SwitchedFromWriterToReader", self._reader_host_info.url)
            except Exception:
                logger.debug("ReadWriteSplittingPlugin.ErrorSwitchingToCachedReader", self._reader_host_info.url)

                # Drop the cached reader BEFORE closing it: if close() itself fails we must not
                # keep a reference to a broken connection, otherwise every later switch would
                # retry the same connection and a replacement reader would never be opened.
                stale_reader = self._reader_connection
                self._reader_connection = None
                self._reader_host_info = None
                try:
                    if stale_reader is not None:
                        stale_reader.close()
                except Exception:
                    pass  # Best effort: the stale reader is discarded either way.

                self._initialize_reader_connection(hosts)

        if self._is_writer_conn_from_internal_pool:
            self._close_connection_if_idle(self._writer_connection)

    def _initialize_reader_connection(self, hosts: Tuple[HostInfo, ...]):
        """Open a new reader connection, cache it and make it current.

        With a single-host list the writer is used instead and a warning is
        logged. Otherwise up to ``2 * len(self._plugin_service.hosts)`` connection
        attempts are made against hosts chosen by the reader selection strategy.

        Raises:
            ReadWriteSplittingError: if no reader connection could be established.
        """
        if len(hosts) == 1:
            writer_host = self._get_writer(hosts)
            if writer_host is not None:
                if not self._is_connection_usable(self._writer_connection, self._plugin_service.driver_dialect):
                    self._get_new_writer_connection(writer_host)
                logger.warning("ReadWriteSplittingPlugin.NoReadersFound", writer_host.url)
                return

        conn: Optional[Connection] = None
        reader_host: Optional[HostInfo] = None

        conn_attempts = len(self._plugin_service.hosts) * 2
        for _ in range(conn_attempts):
            host = self._plugin_service.get_host_info_by_strategy(HostRole.READER, self._reader_selector_strategy)
            if host is not None:
                try:
                    conn = self._plugin_service.connect(host, self._properties)
                    provider = self._conn_provider_manager.get_connection_provider(host, self._properties)
                    self._is_reader_conn_from_internal_pool = (ReadWriteSplittingPlugin._POOL_PROVIDER_CLASS_NAME in str(type(provider)))
                    reader_host = host
                    break
                except Exception:
                    logger.warning("ReadWriteSplittingPlugin.FailedToConnectToReader", host.url)

        if conn is None or reader_host is None:
            self._log_and_raise_exception("ReadWriteSplittingPlugin.NoReadersAvailable")
            return  # Unreachable at runtime (the call above raises); kept so the names narrow to non-None.

        logger.debug("ReadWriteSplittingPlugin.SuccessfullyConnectedToReader", reader_host.url)

        self._set_reader_connection(conn, reader_host)
        self._switch_current_connection_to(conn, reader_host)

        logger.debug("ReadWriteSplittingPlugin.SwitchedFromWriterToReader", reader_host.url)

    def _close_connection_if_idle(self, internal_conn: Optional[Connection]):
        """Close ``internal_conn`` if it is open and is not the current connection.

        After a successful close the matching cached writer/reader references are
        cleared. Any exception raised along the way is swallowed.
        """
        current_conn = self._plugin_service.current_connection
        driver_dialect = self._plugin_service.driver_dialect
        try:
            if (internal_conn is not None and internal_conn != current_conn and
                    self._is_connection_usable(internal_conn, driver_dialect)):
                internal_conn.close()
                if internal_conn == self._writer_connection:
                    self._writer_connection = None
                if internal_conn == self._reader_connection:
                    self._reader_connection = None
                    self._reader_host_info = None
        except Exception:
            pass  # Swallow exception

    def _close_idle_connections(self):
        """Close the cached reader and writer connections that are not currently in use."""
        logger.debug("ReadWriteSplittingPlugin.ClosingInternalConnections")
        self._close_connection_if_idle(self._reader_connection)
        self._close_connection_if_idle(self._writer_connection)

    @staticmethod
    def _log_and_raise_exception(log_msg: str):
        """Log ``log_msg`` at error level and raise ``ReadWriteSplittingError`` built from the same key."""
        logger.error(log_msg)
        raise ReadWriteSplittingError(Messages.get(log_msg))

    @staticmethod
    def _is_connection_usable(conn: Optional[Connection], driver_dialect: Optional[DriverDialect]) -> bool:
        """Return ``True`` when both arguments are set and the dialect reports ``conn`` as not closed."""
        return conn is not None and driver_dialect is not None and not driver_dialect.is_closed(conn)

    @staticmethod
    def _get_writer(hosts: Tuple[HostInfo, ...]) -> Optional[HostInfo]:
        """Return the first host whose role is ``HostRole.WRITER``.

        Raises:
            ReadWriteSplittingError: if ``hosts`` contains no writer.
        """
        for host in hosts:
            if host.role == HostRole.WRITER:
                return host

        ReadWriteSplittingPlugin._log_and_raise_exception("ReadWriteSplittingPlugin.NoWriterFound")

        return None  # Unreachable at runtime (the call above raises); satisfies the declared return type.


class ReadWriteSplittingPluginFactory(PluginFactory):
    """Factory that creates :class:`ReadWriteSplittingPlugin` instances."""

    def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin:
        """Return a new ``ReadWriteSplittingPlugin`` bound to ``plugin_service`` and ``props``."""
        return ReadWriteSplittingPlugin(plugin_service, props)