# aws_advanced_python_wrapper/failover_plugin.py (383 lines of code) (raw):
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import TYPE_CHECKING, Tuple
if TYPE_CHECKING:
from aws_advanced_python_wrapper.driver_dialect import DriverDialect
from aws_advanced_python_wrapper.failover_result import ReaderFailoverResult, WriterFailoverResult
from aws_advanced_python_wrapper.host_list_provider import HostListProviderService
from aws_advanced_python_wrapper.pep249 import Connection
from aws_advanced_python_wrapper.plugin_service import PluginService
from typing import Any, Callable, Dict, Optional, Set
from aws_advanced_python_wrapper import LogUtils
from aws_advanced_python_wrapper.errors import (
AwsWrapperError, FailoverFailedError, FailoverSuccessError,
TransactionResolutionUnknownError)
from aws_advanced_python_wrapper.host_availability import HostAvailability
from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole
from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory
from aws_advanced_python_wrapper.reader_failover_handler import (
ReaderFailoverHandler, ReaderFailoverHandlerImpl)
from aws_advanced_python_wrapper.stale_dns_plugin import StaleDnsHelper
from aws_advanced_python_wrapper.utils.failover_mode import (FailoverMode,
get_failover_mode)
from aws_advanced_python_wrapper.utils.log import Logger
from aws_advanced_python_wrapper.utils.messages import Messages
from aws_advanced_python_wrapper.utils.notifications import HostEvent
from aws_advanced_python_wrapper.utils.properties import (Properties,
WrapperProperties)
from aws_advanced_python_wrapper.utils.rds_url_type import RdsUrlType
from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils
from aws_advanced_python_wrapper.utils.telemetry.telemetry import \
TelemetryTraceLevel
from aws_advanced_python_wrapper.writer_failover_handler import (
WriterFailoverHandler, WriterFailoverHandlerImpl)
# Module-level logger, created with this module's name; used by every method below.
logger = Logger(__name__)
class FailoverPlugin(Plugin):
    """
    This plugin provides cluster-aware failover features.

    The plugin switches connections upon detecting communication related exceptions and/or cluster topology changes.
    """

    # NOTE(review): this set lives on the class and __init__ adds the plugin service's
    # network-bound method names to it, so it is shared by every FailoverPlugin instance.
    _SUBSCRIBED_METHODS: Set[str] = {"init_host_provider",
                                     "connect",
                                     "force_connect",
                                     "notify_host_list_changed"}

    # Methods for which execute() refreshes the topology before running the wrapped call.
    _METHODS_REQUIRE_UPDATED_TOPOLOGY: Set[str] = {
        "Connection.commit",
        "Connection.autocommit",
        "Connection.autocommit_setter",
        "Connection.rollback",
        "Connection.cursor",
        "Cursor.callproc",
        "Cursor.execute"
    }

    def __init__(self, plugin_service: PluginService, props: Properties):
        """
        Read the failover settings from the given properties and create the telemetry counters.

        :param plugin_service: The plugin service used to reach the current connection, hosts and telemetry.
        :param props: The connection properties the failover settings are read from.
        """
        self._plugin_service = plugin_service
        self._properties = props
        # The three attributes below are only declared here; they are assigned in init_host_provider().
        self._host_list_provider_service: HostListProviderService
        self._reader_failover_handler: ReaderFailoverHandler
        self._writer_failover_handler: WriterFailoverHandler
        self._enable_failover_setting = WrapperProperties.ENABLE_FAILOVER.get_bool(self._properties)
        self._failover_timeout_sec = WrapperProperties.FAILOVER_TIMEOUT_SEC.get_float(self._properties)
        self._failover_cluster_topology_refresh_rate_sec = WrapperProperties.FAILOVER_CLUSTER_TOPOLOGY_REFRESH_RATE_SEC.get_float(
            self._properties)
        self._failover_writer_reconnect_interval_sec = WrapperProperties.FAILOVER_WRITER_RECONNECT_INTERVAL_SEC.get_float(
            self._properties)
        self._failover_reader_connect_timeout_sec = WrapperProperties.FAILOVER_READER_CONNECT_TIMEOUT_SEC.get_float(
            self._properties)
        self._telemetry_failover_additional_top_trace_setting = (
            WrapperProperties.TELEMETRY_FAILOVER_ADDITIONAL_TOP_TRACE.get_bool(self._properties))
        # Declared only; assigned in init_host_provider() when failover is enabled.
        self._failover_mode: FailoverMode
        self._is_in_transaction: bool = False
        self._is_closed: bool = False
        self._closed_explicitly: bool = False
        # Last exception that went through a completed connection switch in execute().
        self._last_exception: Optional[Exception] = None
        self._rds_utils = RdsUtils()
        self._rds_url_type: RdsUrlType = self._rds_utils.identify_rds_type(self._properties.get("host"))
        self._stale_dns_helper: StaleDnsHelper = StaleDnsHelper(plugin_service)
        self._saved_read_only_status: bool = False
        self._saved_auto_commit_status: bool = False

        telemetry_factory = self._plugin_service.get_telemetry_factory()
        self._failover_writer_triggered_counter = telemetry_factory.create_counter("writer_failover.triggered.count")
        self._failover_writer_success_counter = telemetry_factory.create_counter(
            "writer_failover.completed.success.count")
        self._failover_writer_failed_counter = telemetry_factory.create_counter(
            "writer_failover.completed.failed.count")
        self._failover_reader_triggered_counter = telemetry_factory.create_counter("reader_failover.triggered.count")
        self._failover_reader_success_counter = telemetry_factory.create_counter(
            "reader_failover.completed.success.count")
        self._failover_reader_failed_counter = telemetry_factory.create_counter(
            "reader_failover.completed.failed.count")

        # Mutates the class-level set (see the note on _SUBSCRIBED_METHODS).
        FailoverPlugin._SUBSCRIBED_METHODS.update(self._plugin_service.network_bound_methods)

    def init_host_provider(
            self,
            properties: Properties,
            host_list_provider_service: HostListProviderService,
            init_host_provider_func: Callable):
        """
        Store the host list provider service and, when failover is enabled, create the failover handlers
        and resolve the failover mode.

        :param properties: The connection properties (not read here; the properties given to ``__init__`` are used).
        :param host_list_provider_service: The service stored for later use by ``_connect``.
        :param init_host_provider_func: The callable invoked once the failover handlers have been created.
        """
        self._host_list_provider_service = host_list_provider_service
        if not self._enable_failover_setting:
            # NOTE(review): init_host_provider_func is not invoked on this path - confirm this is intended.
            return

        self._reader_failover_handler = ReaderFailoverHandlerImpl(self._plugin_service, self._properties,
                                                                  self._failover_timeout_sec,
                                                                  self._failover_reader_connect_timeout_sec)
        self._writer_failover_handler = WriterFailoverHandlerImpl(self._plugin_service, self._reader_failover_handler,
                                                                  self._properties,
                                                                  self._failover_timeout_sec,
                                                                  self._failover_cluster_topology_refresh_rate_sec,
                                                                  self._failover_writer_reconnect_interval_sec)
        init_host_provider_func()

        failover_mode = get_failover_mode(self._properties)
        if failover_mode is None:
            # No explicit mode configured: only a reader cluster URL defaults to READER_OR_WRITER,
            # every other URL type defaults to STRICT_WRITER.
            if self._rds_url_type.is_rds_cluster and self._rds_url_type == RdsUrlType.RDS_READER_CLUSTER:
                failover_mode = FailoverMode.READER_OR_WRITER
            else:
                failover_mode = FailoverMode.STRICT_WRITER
        self._failover_mode = failover_mode
        logger.debug("FailoverPlugin.ParameterValue", "FAILOVER_MODE", self._failover_mode)

    @property
    def subscribed_methods(self) -> Set[str]:
        """Return the class-level set of method names this plugin subscribes to."""
        return self._SUBSCRIBED_METHODS

    def execute(self, target: type, method_name: str, execute_func: Callable, *args: Any, **kwargs: Any) -> Any:
        """
        Run the wrapped call and start the failover process when it raises a network exception.

        :param target: The type the method is invoked on (not used here).
        :param method_name: The name of the method being executed, e.g. ``"Cursor.execute"``.
        :param execute_func: The callable that performs the actual call.
        :return: The result of ``execute_func``.
        :raises AwsWrapperError: when ``execute_func`` raises and no failover error is raised in its place.
        """
        self._is_in_transaction = self._plugin_service.is_in_transaction
        if not self._enable_failover_setting or self._can_direct_execute(method_name):
            return execute_func()

        if self._is_closed and not self._allowed_on_closed_connection(method_name):
            # Always raises: either a failover error or AwsWrapperError.
            self._invalid_invocation_on_closed_connection()

        try:
            if self._requires_update_topology(method_name):
                self._update_topology(False)
            return execute_func()
        except Exception as ex:
            logger.debug("FailoverPlugin.DetectedException", str(ex))
            if self._last_exception != ex and self._should_exception_trigger_connection_switch(ex):
                self._invalidate_current_connection()
                if self._plugin_service.current_host_info is not None:
                    self._plugin_service.set_availability(
                        self._plugin_service.current_host_info.aliases, HostAvailability.UNAVAILABLE)
                # Raises whenever it goes through _failover(); it only returns after a direct reconnect
                # or when the connection was closed explicitly.
                self._pick_new_connection()
                self._last_exception = ex
            raise AwsWrapperError(Messages.get_formatted("FailoverPlugin.DetectedException", str(ex))) from ex

    def notify_host_list_changed(self, changes: Dict[str, Set[HostEvent]]):
        """
        Log the reported host changes and log when the current host is no longer valid.

        :param changes: The reported host events, keyed by host string.
        """
        if not self._enable_failover_setting:
            return

        msg = "".join(f"\n\tHost '{key}': {host_events}" for key, host_events in changes.items())
        logger.debug("FailoverPlugin.Changes", msg)

        current_host = self._plugin_service.current_host_info
        if current_host is not None:
            if self._is_host_still_valid(current_host.url, changes):
                return
            for alias in current_host.aliases:
                # Aliases are looked up with a trailing '/' appended.
                if self._is_host_still_valid(alias + '/', changes):
                    return

        logger.debug("FailoverPlugin.InvalidHost", current_host)

    def connect(
            self,
            target_driver_func: Callable,
            driver_dialect: DriverDialect,
            host_info: HostInfo,
            props: Properties,
            is_initial_connection: bool,
            connect_func: Callable) -> Connection:
        """
        Open a connection through ``_connect`` using ``connect_func``.

        :param target_driver_func: Not used here.
        :param driver_dialect: Not used here.
        :param host_info: The host to connect to.
        :param props: The connection properties.
        :param is_initial_connection: Whether this is the initial connection.
        :param connect_func: The callable that opens the connection.
        :return: The connection returned by ``_connect``.
        """
        return self._connect(host_info, props, is_initial_connection, connect_func)

    def force_connect(
            self,
            target_driver_func: Callable,
            driver_dialect: DriverDialect,
            host_info: HostInfo,
            props: Properties,
            is_initial_connection: bool,
            force_connect_func: Callable) -> Connection:
        """
        Open a connection through ``_connect`` using ``force_connect_func``.

        :param target_driver_func: Not used here.
        :param driver_dialect: Not used here.
        :param host_info: The host to connect to.
        :param props: The connection properties.
        :param is_initial_connection: Whether this is the initial connection.
        :param force_connect_func: The callable that opens the connection.
        :return: The connection returned by ``_connect``.
        """
        return self._connect(host_info, props, is_initial_connection, force_connect_func)

    def _connect(
            self,
            host: HostInfo,
            properties: Properties,
            is_initial_connection: bool,
            connect_func: Callable) -> Connection:
        """
        Obtain a connection from the stale DNS helper and refresh the host list on the initial connection.

        :param host: The host to connect to.
        :param properties: The connection properties.
        :param is_initial_connection: Whether this is the initial connection.
        :param connect_func: The callable that opens the connection.
        :return: The connection returned by the stale DNS helper.
        """
        conn: Connection = self._stale_dns_helper.get_verified_connection(is_initial_connection,
                                                                          self._host_list_provider_service, host,
                                                                          properties,
                                                                          connect_func)
        if is_initial_connection:
            self._plugin_service.refresh_host_list(conn)
        return conn

    def _update_topology(self, force_update: bool):
        """
        Refresh the host list, unless failover is not enabled or there is no open current connection.

        :param force_update: `True` to call ``force_refresh_host_list``; `False` to call ``refresh_host_list``.
        """
        if not self._is_failover_enabled():
            return

        conn = self._plugin_service.current_connection
        driver_dialect = self._plugin_service.driver_dialect
        if conn is None or (driver_dialect is not None and driver_dialect.is_closed(conn)):
            return

        if force_update:
            self._plugin_service.force_refresh_host_list()
        else:
            self._plugin_service.refresh_host_list()

    def _failover(self, failed_host: Optional[HostInfo]):
        """
        Initiates the failover procedure. This process tries to establish a new connection to an instance in the topology.

        This method never returns normally: when the failover itself succeeds, one of the two errors below is raised.

        :param failed_host: The host with network errors.
        :raises TransactionResolutionUnknownError: if the failover succeeded while a transaction was in progress.
        :raises FailoverSuccessError: if the failover succeeded outside of a transaction.
        """
        if failed_host is not None:
            self._plugin_service.set_availability(failed_host.as_aliases(), HostAvailability.UNAVAILABLE)

        if self._failover_mode == FailoverMode.STRICT_WRITER:
            self._failover_writer()
        else:
            self._failover_reader(failed_host)

        if self._is_in_transaction or self._plugin_service.is_in_transaction:
            self._plugin_service.update_in_transaction(False)
            error_msg = "FailoverPlugin.TransactionResolutionUnknownError"
            logger.warning(error_msg)
            raise TransactionResolutionUnknownError(Messages.get(error_msg))
        else:
            error_msg = "FailoverPlugin.ConnectionChangedError"
            logger.error(error_msg)
            raise FailoverSuccessError(Messages.get(error_msg))

    def _failover_reader(self, failed_host: Optional[HostInfo]):
        """
        Run the reader failover handler and switch the current connection to the host it returns.

        :param failed_host: The host with network errors. It is only forwarded to the handler when its raw
            availability is ``AVAILABLE``; otherwise ``None`` is forwarded.
        :raises FailoverFailedError: if the handler returns no result or an unconnected one.
        """
        telemetry_factory = self._plugin_service.get_telemetry_factory()
        context = telemetry_factory.open_telemetry_context("failover to replica", TelemetryTraceLevel.NESTED)
        self._failover_reader_triggered_counter.inc()

        try:
            logger.info("FailoverPlugin.StartReaderFailover")

            # Remember the aliases of the host being left so they can be removed from the new host below.
            old_aliases = None
            if self._plugin_service.current_host_info is not None:
                old_aliases = self._plugin_service.current_host_info.aliases

            if failed_host is not None and failed_host.get_raw_availability() != HostAvailability.AVAILABLE:
                failed_host = None

            result: ReaderFailoverResult = self._reader_failover_handler.failover(self._plugin_service.hosts,
                                                                                  failed_host)
            if result is None or not result.is_connected:
                failure = FailoverFailedError(Messages.get("FailoverPlugin.UnableToConnectToReader"))
                cause = None if result is None else result.exception
                if cause is not None:
                    # Keep the handler's own exception as the cause instead of dropping it.
                    raise failure from cause
                raise failure

            if result.exception is not None:
                raise result.exception
            if result.connection is not None and result.new_host is not None:
                self._plugin_service.set_current_connection(result.connection, result.new_host)

            if self._plugin_service.current_host_info is not None and old_aliases is not None and len(old_aliases) > 0:
                self._plugin_service.current_host_info.remove_alias(old_aliases)

            self._update_topology(True)
            logger.info("FailoverPlugin.EstablishedConnection", self._plugin_service.current_host_info)
            self._failover_reader_success_counter.inc()
        except FailoverSuccessError as fse:
            context.set_success(True)
            context.set_exception(fse)
            self._failover_reader_success_counter.inc()
            raise
        except Exception as ex:
            context.set_success(False)
            context.set_exception(ex)
            self._failover_reader_failed_counter.inc()
            raise
        finally:
            context.close_context()
            if self._telemetry_failover_additional_top_trace_setting:
                telemetry_factory.post_copy(context, TelemetryTraceLevel.FORCE_TOP_LEVEL)

    def _failover_writer(self):
        """
        Run the writer failover handler and switch the current connection to the new writer.

        :raises FailoverFailedError: if the handler returns no result or an unconnected one, if the returned
            topology has no writer, or if the new writer is not among the plugin service's hosts.
        """
        telemetry_factory = self._plugin_service.get_telemetry_factory()
        context = telemetry_factory.open_telemetry_context("failover to writer host", TelemetryTraceLevel.NESTED)
        self._failover_writer_triggered_counter.inc()

        try:
            logger.info("FailoverPlugin.StartWriterFailover")

            result: WriterFailoverResult = self._writer_failover_handler.failover(self._plugin_service.all_hosts)
            if result is not None and result.exception is not None:
                raise result.exception
            elif result is None or not result.is_connected:
                raise FailoverFailedError(Messages.get("FailoverPlugin.UnableToConnectToWriter"))

            writer_host = self._get_writer(result.topology)
            allowed_hosts = self._plugin_service.hosts
            allowed_hostnames = {host.host for host in allowed_hosts}
            # _get_writer returns None when the topology has no writer, so check for None before
            # reading writer_host.host.
            if writer_host is None or writer_host.host not in allowed_hostnames:
                raise FailoverFailedError(
                    Messages.get_formatted(
                        "FailoverPlugin.NewWriterNotAllowed",
                        "<null>" if writer_host is None else writer_host.host,
                        LogUtils.log_topology(allowed_hosts)))

            self._plugin_service.set_current_connection(result.new_connection, writer_host)
            logger.info("FailoverPlugin.EstablishedConnection", self._plugin_service.current_host_info)
            self._plugin_service.refresh_host_list()
            self._failover_writer_success_counter.inc()
        except FailoverSuccessError as fse:
            context.set_success(True)
            context.set_exception(fse)
            self._failover_writer_success_counter.inc()
            raise
        except Exception as ex:
            context.set_success(False)
            context.set_exception(ex)
            self._failover_writer_failed_counter.inc()
            raise
        finally:
            context.close_context()
            if self._telemetry_failover_additional_top_trace_setting:
                telemetry_factory.post_copy(context, TelemetryTraceLevel.FORCE_TOP_LEVEL)

    def _invalidate_current_connection(self):
        """
        Invalidate the current connection before switching to a new connection.
        """
        conn = self._plugin_service.current_connection
        if conn is None:
            return

        if self._plugin_service.is_in_transaction:
            self._plugin_service.update_in_transaction(True)
            try:
                conn.rollback()
            except Exception:
                # Best effort: the connection is being discarded, so a failed rollback is ignored.
                pass

        driver_dialect = self._plugin_service.driver_dialect
        if driver_dialect is not None and not driver_dialect.is_closed(conn):
            try:
                conn.close()
            except Exception:
                # Best effort: a failed close is ignored for the same reason.
                pass

    def _invalid_invocation_on_closed_connection(self):
        """
        Handle a method call made on a closed connection. This method always raises.

        :raises FailoverSuccessError: if the connection was not closed explicitly and a new connection was picked
            without another error being raised first.
        :raises AwsWrapperError: if the connection was closed explicitly.
        """
        if not self._closed_explicitly:
            self._is_closed = False
            self._pick_new_connection()

            error_msg = "FailoverPlugin.ConnectionChangedError"
            logger.debug(error_msg)
            raise FailoverSuccessError(Messages.get(error_msg))
        else:
            raise AwsWrapperError(Messages.get("FailoverPlugin.NoOperationsAfterConnectionClosed"))

    def _pick_new_connection(self):
        """
        Reconnect directly to the current writer when there is no current connection and no reader connection
        should be attempted; otherwise, or when that reconnect fails, start the failover procedure.
        """
        if self._is_closed and self._closed_explicitly:
            logger.debug("FailoverPlugin.NoOperationsAfterConnectionClosed")
            return

        if self._plugin_service.current_connection is None and not self._should_attempt_reader_connection():
            writer = self._get_current_writer()
            if writer is None:
                # No writer in the known topology: there is nothing to connect to directly.
                self._failover(None)
            else:
                try:
                    self._connect_to(writer)
                except Exception:
                    self._failover(writer)
        else:
            self._failover(self._plugin_service.current_host_info)

    def _connect_to(self, host: HostInfo):
        """
        Connects this dynamic failover connection proxy to the specified host.

        Any exception raised while connecting is logged at debug level and re-raised.

        :param host: The host to connect to.
        """
        try:
            connection_for_host = self._plugin_service.connect(host, self._properties)
            self._plugin_service.set_current_connection(connection_for_host, host)
            self._plugin_service.update_in_transaction(False)
            logger.info("FailoverPlugin.EstablishedConnection", host)
        except Exception:
            logger.debug("FailoverPlugin.ConnectionToHostFailed",
                         'writer' if host.role == HostRole.WRITER else 'reader', host.url)
            raise

    def _should_attempt_reader_connection(self) -> bool:
        """
        Check whether a reader connection may be attempted.

        :return: `True` if the failover mode is not ``STRICT_WRITER`` and the plugin service's hosts contain
            at least one reader. `False` otherwise.
        """
        topology = self._plugin_service.hosts
        if topology is None or self._failover_mode == FailoverMode.STRICT_WRITER:
            return False
        return any(host.role == HostRole.READER for host in topology)

    def _is_failover_enabled(self) -> bool:
        """
        Check whether failover can be performed.

        :return: `True` if failover is enabled in the settings, the URL is not an RDS proxy URL and the plugin
            service knows at least one host. `False` otherwise.
        """
        if not self._enable_failover_setting or self._rds_url_type == RdsUrlType.RDS_PROXY:
            return False
        all_hosts = self._plugin_service.all_hosts
        return all_hosts is not None and len(all_hosts) > 0

    def _get_current_writer(self) -> Optional[HostInfo]:
        """
        Find the writer among the plugin service's hosts.

        :return: The first host whose role is ``WRITER``, or `None` if there is no topology or no writer in it.
        """
        topology = self._plugin_service.all_hosts
        if topology is None:
            return None
        return self._get_writer(topology)

    def _should_exception_trigger_connection_switch(self, ex: Exception) -> bool:
        """
        Checks whether the given exception is a network exception and should trigger the failover process.

        :param ex: The exception raised during the method call.
        :return: `True` if the exception should trigger failover. `False` otherwise.
        """
        if not self._is_failover_enabled():
            logger.debug("FailoverPlugin.FailoverDisabled")
            return False
        return self._plugin_service.is_network_exception(ex)

    @staticmethod
    def _get_writer(hosts: Tuple[HostInfo, ...]) -> Optional[HostInfo]:
        """
        Return the first host whose role is ``WRITER``.

        :param hosts: The hosts to search.
        :return: The first writer host, or `None` if none of the hosts is a writer.
        """
        return next((host for host in hosts if host.role == HostRole.WRITER), None)

    @staticmethod
    def _is_host_still_valid(host: str, changes: Dict[str, Set[HostEvent]]) -> bool:
        """
        Check whether the reported changes leave the given host usable.

        :param host: The key to look up in ``changes``.
        :param changes: The reported host events, keyed by host string.
        :return: `True` if the host is not mentioned in ``changes``, or if its events contain neither
            ``HOST_DELETED`` nor ``WENT_DOWN``. `False` otherwise, including when the entry is `None`.
        """
        if host not in changes:
            return True
        options = changes[host]
        return options is not None and \
            HostEvent.HOST_DELETED not in options and HostEvent.WENT_DOWN not in options

    @staticmethod
    def _can_direct_execute(method_name):
        """
        Check whether the method provided can be executed directly without the failover functionality.

        :param method_name: The name of the method that is being called.
        :return: `True` if the method can be executed directly; `False` otherwise.
        """
        return method_name in ("Connection.close", "Connection.is_closed", "Cursor.close")

    @staticmethod
    def _allowed_on_closed_connection(method_name: str):
        """
        Checks if the given method is allowed on closed connections.

        :param method_name: The method being executed at the moment.
        :return: `True` if the given method is allowed on closed connections.
        """
        return method_name == "Connection.autocommit"

    def _requires_update_topology(self, method_name: str):
        """
        Not all method calls require an updated topology, especially ones that don't require network connection.
        Updating the topology may execute the topology query in the middle of another query execution,
        this introduces overhead and may not be supported by all drivers.

        :param method_name: The method being executed at the moment.
        :return: `True` if the given method requires an updated topology before executing. `False` otherwise.
        """
        return method_name in FailoverPlugin._METHODS_REQUIRE_UPDATED_TOPOLOGY
class FailoverPluginFactory(PluginFactory):
    """Factory that creates :class:`FailoverPlugin` instances."""

    def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin:
        """
        Build a new failover plugin.

        :param plugin_service: The plugin service handed to the new plugin.
        :param props: The connection properties handed to the new plugin.
        :return: A new ``FailoverPlugin`` constructed from the two arguments.
        """
        failover_plugin: Plugin = FailoverPlugin(plugin_service, props)
        return failover_plugin