# aws_advanced_python_wrapper/aurora_initial_connection_strategy_plugin.py (165 lines of code) (raw)
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from time import perf_counter_ns, sleep
from typing import TYPE_CHECKING, Callable, Optional, Set
if TYPE_CHECKING:
from aws_advanced_python_wrapper.driver_dialect import DriverDialect
from aws_advanced_python_wrapper.host_list_provider import HostListProviderService
from aws_advanced_python_wrapper.pep249 import Connection
from aws_advanced_python_wrapper.plugin_service import PluginService
from aws_advanced_python_wrapper.errors import AwsWrapperError
from aws_advanced_python_wrapper.host_availability import HostAvailability
from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole
from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory
from aws_advanced_python_wrapper.utils.messages import Messages
from aws_advanced_python_wrapper.utils.properties import (Properties,
WrapperProperties)
from aws_advanced_python_wrapper.utils.rds_url_type import RdsUrlType
from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils
class AuroraInitialConnectionStrategyPlugin(Plugin):
    """Verify that a connection opened through a cluster endpoint has the expected role.

    When the target host is identified as a writer-cluster URL, the plugin tries
    to return a connection whose role is ``HostRole.WRITER``; for a reader-cluster
    URL it tries to return one whose role is ``HostRole.READER`` (or the writer
    when the known topology contains no readers). In every other case, and when
    verification gives up, the supplied connect function is used unchanged.
    """

    _SUBSCRIBED_METHODS: Set[str] = {"init_host_provider",
                                     "connect",
                                     "force_connect"}

    # Assigned in init_host_provider(); remains None until that hook has run.
    _host_list_provider_service: Optional[HostListProviderService] = None

    @property
    def subscribed_methods(self) -> Set[str]:
        """Return the names of the pipeline methods this plugin subscribes to."""
        return AuroraInitialConnectionStrategyPlugin._SUBSCRIBED_METHODS

    def __init__(self, plugin_service: PluginService):
        """Store the plugin service and create the RDS URL helper.

        :param plugin_service: service used to connect, refresh the host list
            and inspect host roles.
        """
        # A bare ``super()`` only builds a proxy object; the initializer must be called explicitly.
        super().__init__()
        self._plugin_service = plugin_service
        self._rds_utils = RdsUtils()

    def connect(self, target_driver_func: Callable, driver_dialect: DriverDialect, host_info: HostInfo, props: Properties,
                is_initial_connection: bool, connect_func: Callable) -> Connection:
        """Open a connection, verifying the role when ``host_info`` is a cluster endpoint.

        ``target_driver_func`` and ``driver_dialect`` are accepted but not used by this plugin.
        """
        return self._connect(host_info, props, is_initial_connection, connect_func)

    def force_connect(self, target_driver_func: Callable, driver_dialect: DriverDialect, host_info: HostInfo, props: Properties,
                      is_initial_connection: bool, force_connect_func: Callable) -> Connection:
        """Same as :meth:`connect`, but delegates to ``force_connect_func``."""
        return self._connect(host_info, props, is_initial_connection, force_connect_func)

    def _connect(self, host_info: HostInfo, props: Properties, is_initial_connection: bool, connect_func: Callable) -> Connection:
        """Dispatch on the RDS URL type of ``host_info.host``.

        :param host_info: host the caller asked to connect to.
        :param props: connection properties.
        :param is_initial_connection: whether this is the initial connection;
            if so, the verified host is recorded on the host list provider service.
        :param connect_func: zero-argument callable that opens the connection
            in the default way.
        :return: a verified writer/reader connection when one could be
            obtained, otherwise the result of ``connect_func()``.
        """
        url_type: RdsUrlType = self._rds_utils.identify_rds_type(host_info.host)
        if not url_type.is_rds_cluster:
            return connect_func()

        verified_conn: Optional[Connection] = None
        if url_type == RdsUrlType.RDS_WRITER_CLUSTER:
            verified_conn = self._get_verified_writer_connection(props, is_initial_connection, connect_func)
        elif url_type == RdsUrlType.RDS_READER_CLUSTER:
            verified_conn = self._get_verified_reader_connection(props, is_initial_connection, connect_func)

        if verified_conn is not None:
            return verified_conn

        # Reached when verification gave up, and also for any other cluster URL
        # type (neither writer nor reader cluster endpoint). The latter case used
        # to fall off the end of the method and return None to the caller.
        return connect_func()

    def _set_initial_connection_host_info(self, is_initial_connection: bool, host_info: Optional[HostInfo]) -> None:
        """Record ``host_info`` as the initial connection host, when applicable.

        Nothing happens unless this is the initial connection and the host list
        provider service has already been supplied via :meth:`init_host_provider`.
        """
        if is_initial_connection and self._host_list_provider_service is not None:
            self._host_list_provider_service.initial_connection_host_info = host_info

    def _get_verified_writer_connection(self, props: Properties, is_initial_connection: bool, connect_func: Callable) -> Connection | None:
        """Try to obtain a connection whose host role is ``HostRole.WRITER``.

        :return: the verified connection, or ``None`` when the retry deadline
            passed without one.
        :raises Exception: any error raised while connecting or inspecting the
            connection is re-raised after the candidate connection is closed.
        """
        retry_delay_ms: int = WrapperProperties.OPEN_CONNECTION_RETRY_INTERVAL_MS.get_int(props)
        # NOTE(review): the deadline is derived from the retry *interval*, so the
        # retry budget equals a single delay. This looks like it was meant to use
        # a separate retry-timeout property - confirm against WrapperProperties.
        end_time_nano = perf_counter_ns() + retry_delay_ms * 1_000_000

        writer_candidate_conn: Optional[Connection]
        writer_candidate: Optional[HostInfo]

        while perf_counter_ns() < end_time_nano:
            writer_candidate_conn = None
            writer_candidate = None

            try:
                writer_candidate = self._get_writer()

                if writer_candidate is None or self._rds_utils.is_rds_cluster_dns(writer_candidate.host):
                    # Writer is not found. Topology is outdated.
                    # Connect the default way, refresh the host list, then ask which host we landed on.
                    writer_candidate_conn = connect_func()
                    self._plugin_service.force_refresh_host_list(writer_candidate_conn)
                    writer_candidate = self._plugin_service.identify_connection(writer_candidate_conn)

                    if writer_candidate is not None and writer_candidate.role != HostRole.WRITER:
                        # Landed on a non-writer host: discard and retry after a pause.
                        self._close_connection(writer_candidate_conn)
                        self._delay(retry_delay_ms)
                        continue

                    self._set_initial_connection_host_info(is_initial_connection, writer_candidate)
                    return writer_candidate_conn

                writer_candidate_conn = self._plugin_service.connect(writer_candidate, props)

                if self._plugin_service.get_host_role(writer_candidate_conn) != HostRole.WRITER:
                    # The host believed to be the writer is not one: refresh the
                    # host list before discarding the connection and retrying.
                    self._plugin_service.force_refresh_host_list(writer_candidate_conn)
                    self._close_connection(writer_candidate_conn)
                    self._delay(retry_delay_ms)
                    continue

                # Writer connection is valid and verified.
                self._set_initial_connection_host_info(is_initial_connection, writer_candidate)
                return writer_candidate_conn

            except Exception:
                # Do not leak the candidate connection; propagate the original error.
                self._close_connection(writer_candidate_conn)
                raise

        return None

    def _get_verified_reader_connection(self, props: Properties, is_initial_connection: bool, connect_func: Callable) -> Optional[Connection]:
        """Try to obtain a connection whose host role is ``HostRole.READER``.

        If the known host list contains no reader, the connection to the
        non-reader host is returned instead.

        :return: the verified connection, or ``None`` when the retry deadline
            passed without one.
        :raises Exception: any error raised while selecting a host, connecting
            or inspecting the connection is re-raised after the candidate
            connection is closed. For non-login errors the candidate host (if
            known) is first marked ``HostAvailability.UNAVAILABLE``.
        """
        retry_delay_ms: int = WrapperProperties.OPEN_CONNECTION_RETRY_INTERVAL_MS.get_int(props)
        # NOTE(review): the deadline is derived from the retry *interval*, so the
        # retry budget equals a single delay. This looks like it was meant to use
        # a separate retry-timeout property - confirm against WrapperProperties.
        end_time_nano = perf_counter_ns() + retry_delay_ms * 1_000_000

        reader_candidate_conn: Optional[Connection]
        reader_candidate: Optional[HostInfo]

        while perf_counter_ns() < end_time_nano:
            reader_candidate_conn = None
            reader_candidate = None

            try:
                reader_candidate = self._get_reader(props)

                if reader_candidate is None or self._rds_utils.is_rds_cluster_dns(reader_candidate.host):
                    # READER is not found. Topology is outdated.
                    # Connect the default way, refresh the host list, then ask which host we landed on.
                    reader_candidate_conn = connect_func()
                    self._plugin_service.force_refresh_host_list(reader_candidate_conn)
                    reader_candidate = self._plugin_service.identify_connection(reader_candidate_conn)

                    if reader_candidate is not None and reader_candidate.role != HostRole.READER:
                        if self._has_no_readers():
                            # Cluster has no readers. Simulate Aurora reader cluster endpoint logic
                            # and return the current writer connection.
                            self._set_initial_connection_host_info(is_initial_connection, reader_candidate)
                            return reader_candidate_conn

                        self._close_connection(reader_candidate_conn)
                        self._delay(retry_delay_ms)
                        continue

                    self._set_initial_connection_host_info(is_initial_connection, reader_candidate)
                    return reader_candidate_conn

                reader_candidate_conn = self._plugin_service.connect(reader_candidate, props)

                if self._plugin_service.get_host_role(reader_candidate_conn) != HostRole.READER:
                    # If the new connection resolves to a writer instance, the topology is outdated.
                    # Force refresh to update the topology.
                    self._plugin_service.force_refresh_host_list(reader_candidate_conn)

                    if self._has_no_readers():
                        # Cluster has no readers. Simulate Aurora reader cluster endpoint logic
                        # and return the current writer connection.
                        self._set_initial_connection_host_info(is_initial_connection, reader_candidate)
                        return reader_candidate_conn

                    self._close_connection(reader_candidate_conn)
                    self._delay(retry_delay_ms)
                    continue

                # Reader connection is valid and verified.
                self._set_initial_connection_host_info(is_initial_connection, reader_candidate)
                return reader_candidate_conn

            except Exception as e:
                self._close_connection(reader_candidate_conn)
                # Anything other than a login failure counts against the candidate host.
                if not self._plugin_service.is_login_exception(e) and reader_candidate is not None:
                    self._plugin_service.set_availability(reader_candidate.as_aliases(), HostAvailability.UNAVAILABLE)
                raise

        return None

    def _close_connection(self, connection: Optional[Connection]):
        """Close ``connection`` if it is not ``None``, ignoring any error from ``close()``."""
        if connection is None:
            return
        try:
            connection.close()
        except Exception:
            # Deliberately best-effort: a failure to close must not mask the
            # retry logic or the error being handled by the caller.
            pass

    def _delay(self, delay_ms: int):
        """Sleep for ``delay_ms`` milliseconds."""
        sleep(delay_ms / 1000)

    def _get_writer(self) -> Optional[HostInfo]:
        """Return the first known host whose role is ``HostRole.WRITER``, or ``None``."""
        return next((host for host in self._plugin_service.all_hosts if host.role == HostRole.WRITER), None)

    def _get_reader(self, props: Properties) -> Optional[HostInfo]:
        """Pick a reader host using the configured initial-host selector strategy.

        :return: the selected host, or ``None`` when host selection raised.
        :raises AwsWrapperError: when the configured strategy is missing or is
            not accepted by the plugin service for readers.
        """
        strategy = WrapperProperties.READER_INITIAL_HOST_SELECTOR_STRATEGY.get(props)
        if (self._plugin_service is not None
                and strategy is not None
                and self._plugin_service.accepts_strategy(HostRole.READER, strategy)):
            try:
                return self._plugin_service.get_host_info_by_strategy(HostRole.READER, strategy)
            except Exception:
                # Host isn't found.
                return None

        raise AwsWrapperError(Messages.get_formatted("AuroraInitialConnectionStrategyPlugin.UnsupportedStrategy", strategy))

    def init_host_provider(self, props: Properties, host_list_provider_service: HostListProviderService, init_host_provider_func: Callable):
        """Remember the host list provider service and continue the init chain.

        :raises AwsWrapperError: when the provider service reports a static
            host list provider.
        """
        self._host_list_provider_service = host_list_provider_service
        if host_list_provider_service.is_static_host_list_provider():
            raise AwsWrapperError(Messages.get("AuroraInitialConnectionStrategyPlugin.RequireDynamicProvider"))
        init_host_provider_func(props)

    def _has_no_readers(self) -> bool:
        """Return ``True`` only when hosts are known and none of them is a reader.

        An empty host list yields ``False``: with no topology information the
        absence of readers cannot be concluded.
        """
        # Take one snapshot so the emptiness check and the scan see the same list.
        hosts = self._plugin_service.all_hosts
        if len(hosts) == 0:
            return False
        return all(host.role != HostRole.READER for host in hosts)
class AuroraInitialConnectionStrategyPluginFactory(PluginFactory):
    """Factory producing :class:`AuroraInitialConnectionStrategyPlugin` instances."""

    def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin:
        """Build a new plugin bound to ``plugin_service``.

        ``props`` is part of the factory signature but is not consulted here.
        """
        plugin = AuroraInitialConnectionStrategyPlugin(plugin_service)
        return plugin