aws_advanced_python_wrapper/custom_endpoint_plugin.py (258 lines of code) (raw):
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from threading import Event, Thread
from time import perf_counter_ns, sleep
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Dict, List,
Optional, Set, Union, cast)
from aws_advanced_python_wrapper.allowed_and_blocked_hosts import \
AllowedAndBlockedHosts
from aws_advanced_python_wrapper.errors import AwsWrapperError
from aws_advanced_python_wrapper.utils.cache_map import CacheMap
from aws_advanced_python_wrapper.utils.messages import Messages
from aws_advanced_python_wrapper.utils.region_utils import RegionUtils
if TYPE_CHECKING:
from aws_advanced_python_wrapper.driver_dialect import DriverDialect
from aws_advanced_python_wrapper.hostinfo import HostInfo
from aws_advanced_python_wrapper.pep249 import Connection
from aws_advanced_python_wrapper.plugin_service import PluginService
from aws_advanced_python_wrapper.utils.properties import Properties
from enum import Enum
from boto3 import Session
from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory
from aws_advanced_python_wrapper.utils.log import Logger
from aws_advanced_python_wrapper.utils.properties import WrapperProperties
from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils
from aws_advanced_python_wrapper.utils.sliding_expiration_cache import \
SlidingExpirationCacheWithCleanupThread
from aws_advanced_python_wrapper.utils.telemetry.telemetry import (
TelemetryCounter, TelemetryFactory)
# Module-level logger named after this module; used by the monitor and plugin classes in this file.
logger = Logger(__name__)
class CustomEndpointRoleType(Enum):
    """
    The roles that the instances behind a custom endpoint may have. A WRITER custom endpoint cannot currently be
    created, so only ANY and READER are represented here.
    """
    ANY = "ANY"
    READER = "READER"

    @classmethod
    def from_string(cls, value):
        """
        Look up the enum member whose value equals the given string.

        :param value: the role type string, e.g. "ANY" or "READER".
        :return: the matching :py:class:`CustomEndpointRoleType` member.
        :raises ValueError: if no member has the given value.
        """
        role_type = cls(value)
        return role_type
class CustomEndpointInfo:
    """
    Holds the details of a single custom endpoint: its identifier, the cluster it belongs to, its URL, the role type
    of its instances, and its static/excluded member lists.
    """

    def __init__(self,
                 endpoint_id: str,
                 cluster_id: str,
                 endpoint: str,
                 role_type: CustomEndpointRoleType,
                 static_members: Optional[Set[str]],
                 excluded_members: Optional[Set[str]]):
        """
        :param endpoint_id: the custom endpoint identifier.
        :param cluster_id: the identifier of the cluster the endpoint belongs to.
        :param endpoint: the custom endpoint URL.
        :param role_type: the role type of the instances behind the endpoint.
        :param static_members: the static member list, or None/empty if there is none.
        :param excluded_members: the excluded member list, or None/empty if there is none.
        """
        self.endpoint_id = endpoint_id
        self.cluster_id = cluster_id
        self.endpoint = endpoint
        self.role_type = role_type
        # Empty collections are normalized to None so that "no members" has a single representation.
        self.static_members = static_members if static_members else None
        self.excluded_members = excluded_members if excluded_members else None

    @classmethod
    def from_db_cluster_endpoint(cls, endpoint_response_info: Dict[str, Union[str, List[str]]]):
        """
        Build a :py:class:`CustomEndpointInfo` from one entry of a describe-DB-cluster-endpoints response.

        :param endpoint_response_info: the response dictionary for a single custom endpoint.
        :return: the corresponding :py:class:`CustomEndpointInfo`.
        :raises ValueError: if the "CustomEndpointType" value is not a known :py:class:`CustomEndpointRoleType`.
        """
        # A missing member list is treated like an empty one instead of failing on set(None).
        static_members = cast('List[str]', endpoint_response_info.get("StaticMembers") or [])
        excluded_members = cast('List[str]', endpoint_response_info.get("ExcludedMembers") or [])
        return cls(
            str(endpoint_response_info.get("DBClusterEndpointIdentifier")),
            str(endpoint_response_info.get("DBClusterIdentifier")),
            str(endpoint_response_info.get("Endpoint")),
            CustomEndpointRoleType.from_string(str(endpoint_response_info.get("CustomEndpointType"))),
            set(static_members),
            set(excluded_members)
        )

    def __eq__(self, other: object):
        # Identity shortcut: an object is always equal to itself.
        if self is other:
            return True
        if not isinstance(other, CustomEndpointInfo):
            return False

        return self.endpoint_id == other.endpoint_id \
            and self.cluster_id == other.cluster_id \
            and self.endpoint == other.endpoint \
            and self.role_type == other.role_type \
            and self.static_members == other.static_members \
            and self.excluded_members == other.excluded_members

    def __hash__(self):
        # Only the scalar fields are hashed; the member sets are mutable and unhashable. Objects that compare equal
        # still share a hash because these fields are a subset of those compared in __eq__.
        return hash((self.endpoint_id, self.cluster_id, self.endpoint, self.role_type))

    def __str__(self):
        return (f"CustomEndpointInfo[endpoint={self.endpoint}, cluster_id={self.cluster_id}, "
                f"role_type={self.role_type}, endpoint_id={self.endpoint_id}, static_members={self.static_members}, "
                f"excluded_members={self.excluded_members}]")
class CustomEndpointMonitor:
    """
    A custom endpoint monitor. This class uses a background thread to monitor a given custom endpoint for custom
    endpoint information and future changes to the custom endpoint.
    """
    _CUSTOM_ENDPOINT_INFO_EXPIRATION_NS: ClassVar[int] = 5 * 60_000_000_000  # 5 minutes
    # Keys are custom endpoint URLs, values are information objects for the associated custom endpoint.
    _custom_endpoint_info_cache: ClassVar[CacheMap[str, CustomEndpointInfo]] = CacheMap()

    def __init__(self,
                 plugin_service: PluginService,
                 custom_endpoint_host_info: HostInfo,
                 endpoint_id: str,
                 region: str,
                 refresh_rate_ns: int,
                 session: Optional[Session] = None):
        """
        Create the monitor and immediately start its daemon monitoring thread.

        :param plugin_service: the plugin service whose allowed/blocked hosts are updated when the info changes.
        :param custom_endpoint_host_info: the host info of the custom endpoint to monitor.
        :param endpoint_id: the custom endpoint identifier passed to the RDS API.
        :param region: the region used to create the RDS client.
        :param refresh_rate_ns: the target interval between refreshes, in nanoseconds.
        :param session: the boto3 session to use; a new default session is created when omitted.
        """
        self._plugin_service = plugin_service
        self._custom_endpoint_host_info = custom_endpoint_host_info
        self._endpoint_id = endpoint_id
        self._region = region
        self._refresh_rate_ns = refresh_rate_ns
        self._session = session if session else Session()
        self._client = self._session.client('rds', region_name=region)

        self._stop_event = Event()
        telemetry_factory = self._plugin_service.get_telemetry_factory()
        self._info_changed_counter = telemetry_factory.create_counter("customEndpoint.infoChanged.counter")

        self._thread = Thread(daemon=True, name="CustomEndpointMonitorThread", target=self._run)
        self._thread.start()

    def _run(self):
        """
        Thread target. Refreshes the custom endpoint info once per refresh interval until the stop event is set,
        then removes the cached info and closes the RDS client.
        """
        host = self._custom_endpoint_host_info.host
        logger.debug("CustomEndpointMonitor.StartingMonitor", host)

        try:
            while not self._stop_event.is_set():
                start_ns = perf_counter_ns()
                try:
                    self._refresh_custom_endpoint_info()
                except InterruptedError:
                    raise
                except Exception as e:
                    # If the exception is not an InterruptedError, log it and continue monitoring.
                    logger.error("CustomEndpointMonitor.Exception", host, e)

                # Every iteration is paced, including failed ones, so a persistent error cannot turn into a busy
                # loop. Waiting on the stop event rather than sleeping lets close() wake the thread right away.
                elapsed_ns = perf_counter_ns() - start_ns
                remaining_ns = max(0, self._refresh_rate_ns - elapsed_ns)
                self._stop_event.wait(remaining_ns / 1_000_000_000)
        except InterruptedError:
            logger.info("CustomEndpointMonitor.Interrupted", host)
        finally:
            CustomEndpointMonitor._custom_endpoint_info_cache.remove(host)
            self._stop_event.set()
            self._client.close()
            logger.debug("CustomEndpointMonitor.StoppedMonitor", host)

    def _refresh_custom_endpoint_info(self):
        """
        Fetch the custom endpoint description once and, if it differs from the cached info, update the plugin
        service's allowed/blocked hosts and the info cache.
        """
        host = self._custom_endpoint_host_info.host
        response = self._client.describe_db_cluster_endpoints(
            DBClusterEndpointIdentifier=self._endpoint_id,
            Filters=[
                {
                    "Name": "db-cluster-endpoint-type",
                    "Values": ["custom"]
                }
            ]
        )

        endpoints = response["DBClusterEndpoints"]
        if len(endpoints) != 1:
            # Exactly one endpoint is expected for the identifier; anything else is logged and skipped.
            endpoint_hostnames = [endpoint["Endpoint"] for endpoint in endpoints]
            logger.warning(
                "CustomEndpointMonitor.UnexpectedNumberOfEndpoints",
                self._endpoint_id,
                self._region,
                len(endpoints),
                endpoint_hostnames)
            return

        endpoint_info = CustomEndpointInfo.from_db_cluster_endpoint(endpoints[0])
        cached_info = CustomEndpointMonitor._custom_endpoint_info_cache.get(host)
        if cached_info is not None and cached_info == endpoint_info:
            # Nothing changed since the last refresh.
            return

        logger.debug("CustomEndpointMonitor.DetectedChangeInCustomEndpointInfo", host, endpoint_info)

        # The custom endpoint info has changed, so we need to update the set of allowed/blocked hosts.
        hosts = AllowedAndBlockedHosts(endpoint_info.static_members, endpoint_info.excluded_members)
        self._plugin_service.allowed_and_blocked_hosts = hosts
        CustomEndpointMonitor._custom_endpoint_info_cache.put(
            host,
            endpoint_info,
            CustomEndpointMonitor._CUSTOM_ENDPOINT_INFO_EXPIRATION_NS)
        self._info_changed_counter.inc()

    def has_custom_endpoint_info(self):
        """
        :return: True if the info cache currently returns an entry for this monitor's custom endpoint.
        """
        return CustomEndpointMonitor._custom_endpoint_info_cache.get(self._custom_endpoint_host_info.host) is not None

    def close(self):
        """
        Remove this endpoint's cached info and signal the monitoring thread to stop. Does not wait for the thread
        to finish.
        """
        logger.debug("CustomEndpointMonitor.StoppingMonitor", self._custom_endpoint_host_info.host)
        CustomEndpointMonitor._custom_endpoint_info_cache.remove(self._custom_endpoint_host_info.host)
        self._stop_event.set()
class CustomEndpointPlugin(Plugin):
    """
    A plugin that analyzes custom endpoints for custom endpoint information and custom endpoint changes, such as adding
    or removing an instance in the custom endpoint.
    """
    _SUBSCRIBED_METHODS: ClassVar[Set[str]] = {"connect"}
    # 1 minute, in nanoseconds. Written as a literal: `^` is bitwise XOR in Python, not exponentiation, so an
    # expression such as `6 * 10 ^ 10` evaluates to 54 rather than 60_000_000_000.
    _CACHE_CLEANUP_RATE_NS: ClassVar[int] = 60_000_000_000
    # Monitors are keyed by custom endpoint host and shared by all plugin instances; a monitor is closed when the
    # cache disposes of it.
    _monitors: ClassVar[SlidingExpirationCacheWithCleanupThread[str, CustomEndpointMonitor]] = \
        SlidingExpirationCacheWithCleanupThread(_CACHE_CLEANUP_RATE_NS,
                                                should_dispose_func=lambda _: True,
                                                item_disposal_func=lambda monitor: monitor.close())

    def __init__(self, plugin_service: PluginService, props: Properties):
        """
        :param plugin_service: the plugin service used for telemetry and passed on to the monitors.
        :param props: the connection properties from which the plugin settings are read.
        """
        self._plugin_service = plugin_service
        self._props = props

        self._should_wait_for_info: bool = WrapperProperties.WAIT_FOR_CUSTOM_ENDPOINT_INFO.get_bool(self._props)
        self._wait_for_info_timeout_ms: int = WrapperProperties.WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS.get_int(self._props)
        self._idle_monitor_expiration_ms: int = \
            WrapperProperties.CUSTOM_ENDPOINT_IDLE_MONITOR_EXPIRATION_MS.get_int(self._props)

        self._rds_utils = RdsUtils()
        self._region_utils = RegionUtils()
        self._region: Optional[str] = None
        self._custom_endpoint_host_info: Optional[HostInfo] = None
        self._custom_endpoint_id: Optional[str] = None
        telemetry_factory: TelemetryFactory = self._plugin_service.get_telemetry_factory()
        self._wait_for_info_counter: TelemetryCounter = telemetry_factory.create_counter("customEndpoint.waitForInfo.counter")

        # NOTE: this extends the class-level set, so the added methods are visible to every plugin instance.
        CustomEndpointPlugin._SUBSCRIBED_METHODS.update(self._plugin_service.network_bound_methods)

    @property
    def subscribed_methods(self) -> Set[str]:
        return CustomEndpointPlugin._SUBSCRIBED_METHODS

    def connect(
            self,
            target_driver_func: Callable,
            driver_dialect: DriverDialect,
            host_info: HostInfo,
            props: Properties,
            is_initial_connection: bool,
            connect_func: Callable) -> Connection:
        """
        Ensure a monitor exists for the requested custom endpoint before connecting. Hosts that are not custom
        cluster endpoints are passed straight to ``connect_func``.

        :return: the connection returned by ``connect_func``.
        :raises AwsWrapperError: if the endpoint identifier or region cannot be determined from the hostname, or if
            waiting for the custom endpoint info is enabled and the wait fails.
        """
        if not self._rds_utils.is_rds_custom_cluster_dns(host_info.host):
            return connect_func()

        self._custom_endpoint_host_info = host_info
        logger.debug("CustomEndpointPlugin.ConnectionRequestToCustomEndpoint", host_info.host)

        self._custom_endpoint_id = self._rds_utils.get_cluster_id(host_info.host)
        if not self._custom_endpoint_id:
            raise AwsWrapperError(
                Messages.get_formatted(
                    "CustomEndpointPlugin.ErrorParsingEndpointIdentifier", self._custom_endpoint_host_info.host))

        hostname = self._custom_endpoint_host_info.host
        self._region = self._region_utils.get_region_from_hostname(hostname)
        if not self._region:
            error_message = "RdsUtils.UnsupportedHostname"
            logger.debug(error_message, hostname)
            raise AwsWrapperError(Messages.get_formatted(error_message, hostname))

        monitor = self._create_monitor_if_absent(props)
        if self._should_wait_for_info:
            self._wait_for_info(monitor)

        return connect_func()

    def _create_monitor_if_absent(self, props: Properties) -> CustomEndpointMonitor:
        """
        Return the shared monitor for the current custom endpoint host, creating it if the cache has none.

        :param props: the properties from which the monitor refresh rate is read.
        :return: the monitor for the current custom endpoint.
        """
        host_info = cast('HostInfo', self._custom_endpoint_host_info)
        endpoint_id = cast('str', self._custom_endpoint_id)
        region = cast('str', self._region)
        monitor = CustomEndpointPlugin._monitors.compute_if_absent(
            host_info.host,
            lambda key: CustomEndpointMonitor(
                self._plugin_service,
                host_info,
                endpoint_id,
                region,
                # Milliseconds -> nanoseconds.
                WrapperProperties.CUSTOM_ENDPOINT_INFO_REFRESH_RATE_MS.get_int(props) * 1_000_000),
            self._idle_monitor_expiration_ms * 1_000_000)

        return cast('CustomEndpointMonitor', monitor)

    def _wait_for_info(self, monitor: CustomEndpointMonitor):
        """
        Poll the monitor until it has custom endpoint info or the configured timeout elapses.

        :param monitor: the monitor to poll.
        :raises AwsWrapperError: if the wait is interrupted or the info is still unavailable after the timeout.
        """
        has_info = monitor.has_custom_endpoint_info()
        if has_info:
            return

        self._wait_for_info_counter.inc()
        host_info = cast('HostInfo', self._custom_endpoint_host_info)
        hostname = host_info.host
        logger.debug("CustomEndpointPlugin.WaitingForCustomEndpointInfo", hostname, self._wait_for_info_timeout_ms)
        wait_for_info_timeout_ns = perf_counter_ns() + self._wait_for_info_timeout_ms * 1_000_000

        try:
            while not has_info and perf_counter_ns() < wait_for_info_timeout_ns:
                sleep(0.1)
                has_info = monitor.has_custom_endpoint_info()
        except InterruptedError as e:
            # Chain the original error so the cause is preserved in the traceback.
            raise AwsWrapperError(Messages.get_formatted("CustomEndpointPlugin.InterruptedThread", hostname)) from e

        if not has_info:
            raise AwsWrapperError(
                Messages.get_formatted(
                    "CustomEndpointPlugin.TimedOutWaitingForCustomEndpointInfo",
                    self._wait_for_info_timeout_ms, hostname))

    def execute(self, target: type, method_name: str, execute_func: Callable, *args: Any, **kwargs: Any) -> Any:
        """
        Ensure the monitor for the connected custom endpoint exists before running the wrapped call. If no custom
        endpoint connection has been made through this plugin, ``execute_func`` is invoked directly.

        :return: the result of ``execute_func``.
        :raises AwsWrapperError: if waiting for the custom endpoint info is enabled and the wait fails.
        """
        if self._custom_endpoint_host_info is None:
            return execute_func()

        monitor = self._create_monitor_if_absent(self._props)
        if self._should_wait_for_info:
            self._wait_for_info(monitor)

        return execute_func()
class CustomEndpointPluginFactory(PluginFactory):
    """
    Factory that creates :py:class:`CustomEndpointPlugin` instances.
    """

    def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin:
        """
        :param plugin_service: the plugin service handed to the new plugin.
        :param props: the connection properties handed to the new plugin.
        :return: a new :py:class:`CustomEndpointPlugin`.
        """
        plugin: Plugin = CustomEndpointPlugin(plugin_service=plugin_service, props=props)
        return plugin