aws_advanced_python_wrapper/stale_dns_plugin.py (146 lines of code) (raw):

#  Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License").
#  You may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#  http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

from __future__ import annotations

import socket
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Set

if TYPE_CHECKING:
    from aws_advanced_python_wrapper.driver_dialect import DriverDialect
    from aws_advanced_python_wrapper.host_list_provider import HostListProviderService
    from aws_advanced_python_wrapper.hostinfo import HostInfo
    from aws_advanced_python_wrapper.pep249 import Connection
    from aws_advanced_python_wrapper.plugin_service import PluginService
    from aws_advanced_python_wrapper.utils.properties import Properties

from aws_advanced_python_wrapper.errors import AwsWrapperError
from aws_advanced_python_wrapper.hostinfo import HostRole
from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory
from aws_advanced_python_wrapper.utils.log import Logger
from aws_advanced_python_wrapper.utils.messages import Messages
from aws_advanced_python_wrapper.utils.notifications import HostEvent
from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils
from aws_advanced_python_wrapper.utils.utils import LogUtils

logger = Logger(__name__)


class StaleDnsHelper:
    """
    Check connections made through a writer cluster endpoint against the known topology.

    When the cluster endpoint resolves to a different address than the writer host found
    in the topology, the helper opens a connection to that writer host and returns it
    instead of the connection obtained through the cluster endpoint.
    """

    RETRIES: int = 3

    def __init__(self, plugin_service: PluginService) -> None:
        self._plugin_service = plugin_service
        self._rds_helper = RdsUtils()
        # The writer host and its resolved address are cached across calls and cleared by
        # notify_host_list_changed() when that host is reported as converted to a reader.
        self._writer_host_info: Optional[HostInfo] = None
        self._writer_host_address: Optional[str] = None

    def get_verified_connection(
            self,
            is_initial_connection: bool,
            host_list_provider_service: HostListProviderService,
            host_info: HostInfo,
            props: Properties,
            connect_func: Callable) -> Connection:
        """
        Open a connection and make sure a writer cluster endpoint did not resolve to a stale address.

        If ``host_info.host`` is not a writer cluster DNS name, the result of ``connect_func()``
        is returned as is. Otherwise the resolved address of the cluster endpoint is compared
        with the resolved address of the writer host from the topology; on a mismatch a new
        connection to the writer host is opened, the first connection is closed, and the new
        connection is returned.

        :param is_initial_connection: when ``True`` and the connection is replaced, the writer host is
            stored as ``host_list_provider_service.initial_connection_host_info``.
        :param host_list_provider_service: service updated with the writer host on an initial connection.
        :param host_info: the host the caller asked to connect to.
        :param props: connection properties passed on when connecting to the writer host.
        :param connect_func: zero-argument callable that opens the requested connection.
        :return: the connection from ``connect_func`` or, if stale DNS was detected, a connection
            to the writer host.
        :raises AwsWrapperError: if stale DNS was detected but the writer host is not among
            ``plugin_service.hosts``.
        """
        if not self._rds_helper.is_writer_cluster_dns(host_info.host):
            return connect_func()

        conn: Connection = connect_func()
        try:
            return self._replace_if_stale(conn, is_initial_connection, host_list_provider_service, host_info, props)
        except Exception:
            # The caller never receives `conn` when verification fails, so it must be
            # released here or it would stay open with no owner.
            self._close_quietly(conn)
            raise

    def _replace_if_stale(
            self,
            conn: Connection,
            is_initial_connection: bool,
            host_list_provider_service: HostListProviderService,
            host_info: HostInfo,
            props: Properties) -> Connection:
        """Return ``conn``, or a fresh writer connection when the cluster endpoint address is stale."""
        cluster_inet_address = self._resolve_address(host_info.host)
        logger.debug("StaleDnsHelper.ClusterEndpointDns", host_info.host, cluster_inet_address)
        if cluster_inet_address is None:
            return conn

        if self._plugin_service.get_host_role(conn) == HostRole.READER:
            # This branch is only reached if the connection url is a writer cluster endpoint.
            # If the new connection resolves to a reader instance, this means the topology is outdated.
            # Force refresh to update the topology.
            self._plugin_service.force_refresh_host_list(conn)
        else:
            self._plugin_service.refresh_host_list(conn)

        logger.debug("LogUtils.Topology", LogUtils.log_topology(self._plugin_service.all_hosts))

        if self._writer_host_info is None:
            writer_candidate: Optional[HostInfo] = self._get_writer()
            if writer_candidate is not None and self._rds_helper.is_rds_cluster_dns(writer_candidate.host):
                # The topology's writer entry is itself a cluster DNS name; nothing to compare against.
                return conn
            self._writer_host_info = writer_candidate

        logger.debug("StaleDnsHelper.WriterHostSpec", self._writer_host_info)

        writer_host_info = self._writer_host_info
        if writer_host_info is None:
            return conn

        if self._writer_host_address is None:
            self._writer_host_address = self._resolve_address(writer_host_info.host)

        logger.debug("StaleDnsHelper.WriterInetAddress", self._writer_host_address)

        if self._writer_host_address is None or self._writer_host_address == cluster_inet_address:
            return conn

        logger.debug("StaleDnsHelper.StaleDnsDetected", writer_host_info)

        allowed_hosts = self._plugin_service.hosts
        if all(host.host != writer_host_info.host for host in allowed_hosts):
            raise AwsWrapperError(
                Messages.get_formatted(
                    "StaleDnsHelper.CurrentWriterNotAllowed",
                    writer_host_info.host,
                    LogUtils.log_topology(allowed_hosts)))

        writer_conn: Connection = self._plugin_service.connect(writer_host_info, props)
        if is_initial_connection:
            host_list_provider_service.initial_connection_host_info = writer_host_info

        # The connection made through the stale endpoint is no longer needed.
        self._close_quietly(conn)
        return writer_conn

    @staticmethod
    def _resolve_address(host: str) -> Optional[str]:
        """Resolve ``host`` with ``socket.gethostbyname``; return ``None`` on ``socket.gaierror``."""
        try:
            return socket.gethostbyname(host)
        except socket.gaierror:
            return None

    @staticmethod
    def _close_quietly(conn: Optional[Connection]) -> None:
        """Close ``conn`` if it is not ``None``, ignoring any error raised by ``close()``."""
        if conn is None:
            return
        try:
            conn.close()
        except Exception:
            # Best-effort cleanup: a failure to close must not mask the real outcome.
            pass

    def notify_host_list_changed(self, changes: Dict[str, Set[HostEvent]]) -> None:
        """
        Drop the cached writer host and address when that host was converted to a reader.

        :param changes: host events keyed by host url.
        """
        if self._writer_host_info is None:
            return

        writer_changes = changes.get(self._writer_host_info.url, None)
        if writer_changes is not None and HostEvent.CONVERTED_TO_READER in writer_changes:
            logger.debug("StaleDnsHelper.Reset")
            self._writer_host_info = None
            self._writer_host_address = None

    def _get_writer(self) -> Optional[HostInfo]:
        """Return the first host in ``plugin_service.all_hosts`` with the writer role, or ``None``."""
        return next(
            (host for host in self._plugin_service.all_hosts if host.role == HostRole.WRITER),
            None)


class StaleDnsPlugin(Plugin):
    """Plugin that routes ``connect`` and ``force_connect`` through :class:`StaleDnsHelper`."""

    _SUBSCRIBED_METHODS: Set[str] = {"init_host_provider",
                                     "connect",
                                     "force_connect",
                                     "notify_host_list_changed"}

    def __init__(self, plugin_service: PluginService) -> None:
        self._plugin_service = plugin_service
        self._stale_dns_helper = StaleDnsHelper(self._plugin_service)
        # NOTE: this extends the class-level set, so the added names are shared by all instances.
        StaleDnsPlugin._SUBSCRIBED_METHODS.update(self._plugin_service.network_bound_methods)

    @property
    def subscribed_methods(self) -> Set[str]:
        """Names of the methods this plugin subscribes to."""
        return self._SUBSCRIBED_METHODS

    def connect(
            self,
            target_driver_func: Callable,
            driver_dialect: DriverDialect,
            host_info: HostInfo,
            props: Properties,
            is_initial_connection: bool,
            connect_func: Callable) -> Connection:
        """Connect via ``connect_func`` and return the connection verified by the stale DNS helper."""
        return self._stale_dns_helper.get_verified_connection(
            is_initial_connection, self._host_list_provider_service, host_info, props, connect_func)

    def force_connect(
            self,
            target_driver_func: Callable,
            driver_dialect: DriverDialect,
            host_info: HostInfo,
            props: Properties,
            is_initial_connection: bool,
            force_connect_func: Callable) -> Connection:
        """Connect via ``force_connect_func`` and return the connection verified by the stale DNS helper."""
        return self._stale_dns_helper.get_verified_connection(
            is_initial_connection, self._host_list_provider_service, host_info, props, force_connect_func)

    def execute(self, target: type, method_name: str, execute_func: Callable, *args: Any, **kwargs: Any) -> Any:
        """Refresh the host list, ignoring any failure, then return the result of ``execute_func()``."""
        try:
            self._plugin_service.refresh_host_list()
        except Exception:
            # Deliberately best-effort: a failed refresh must not prevent the wrapped call.
            pass

        return execute_func()

    def init_host_provider(
            self,
            properties: Properties,
            host_list_provider_service: HostListProviderService,
            init_host_provider_func: Callable):
        """
        Store the host list provider service, run ``init_host_provider_func`` and reject static providers.

        :raises Exception: if the host list provider service reports a static host list provider.
        """
        self._host_list_provider_service = host_list_provider_service
        init_host_provider_func()

        if self._host_list_provider_service.is_static_host_list_provider():
            raise Exception(Messages.get_formatted("StaleDnsPlugin.RequireDynamicProvider"))

    def notify_host_list_changed(self, changes: Dict[str, Set[HostEvent]]):
        """Forward host list changes to the stale DNS helper."""
        self._stale_dns_helper.notify_host_list_changed(changes)


class StaleDnsPluginFactory(PluginFactory):
    """Factory that creates :class:`StaleDnsPlugin` instances."""

    def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin:
        """Return a new ``StaleDnsPlugin`` bound to ``plugin_service``; ``props`` is not used."""
        return StaleDnsPlugin(plugin_service)