aws_advanced_python_wrapper/sql_alchemy_connection_provider.py (137 lines of code) (raw):
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import TYPE_CHECKING, Callable, ClassVar, Dict, Optional, Tuple
if TYPE_CHECKING:
from aws_advanced_python_wrapper.database_dialect import DatabaseDialect
from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole
from aws_advanced_python_wrapper.driver_dialect import DriverDialect
from sqlalchemy import QueuePool, pool
from aws_advanced_python_wrapper.connection_provider import ConnectionProvider
from aws_advanced_python_wrapper.errors import AwsWrapperError
from aws_advanced_python_wrapper.host_selector import (HostSelector,
RandomHostSelector,
RoundRobinHostSelector)
from aws_advanced_python_wrapper.plugin import CanReleaseResources
from aws_advanced_python_wrapper.utils.messages import Messages
from aws_advanced_python_wrapper.utils.properties import (Properties,
WrapperProperties)
from aws_advanced_python_wrapper.utils.rds_url_type import RdsUrlType
from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils
from aws_advanced_python_wrapper.utils.sliding_expiration_cache import \
SlidingExpirationCache
class SqlAlchemyPooledConnectionProvider(ConnectionProvider, CanReleaseResources):
    """
    This class can be passed to :py:method:`ConnectionProviderManager.connection_provider` to enable internal connection pools
    for each database instance in a cluster. By maintaining internal connection pools,
    the driver can improve performance by reusing old connection objects.

    The pools live in a class-level cache, so they are shared by every instance of this
    provider. Each pool is keyed by a :py:class:`PoolKey` built from the host URL and an
    extra key (by default the connecting user, see :py:meth:`_get_extra_key`).
    """

    # Expiration value handed to the pool cache whenever a pool is looked up or created.
    _POOL_EXPIRATION_CHECK_NS: ClassVar[int] = 30 * 60_000_000_000  # 30 minutes
    # Strategy name handled directly by this provider rather than by a HostSelector.
    _LEAST_CONNECTIONS: ClassVar[str] = "least_connections"
    # Strategies delegated to a HostSelector implementation.
    _accepted_strategies: ClassVar[Dict[str, HostSelector]] = {
        "random": RandomHostSelector(),
        "round_robin": RoundRobinHostSelector(),
    }
    _rds_utils: ClassVar[RdsUtils] = RdsUtils()
    # Callbacks given to the cache: a pool is reported as disposable only while it has no
    # checked-out connections, and disposing an entry calls QueuePool.dispose().
    _database_pools: ClassVar[SlidingExpirationCache[PoolKey, QueuePool]] = SlidingExpirationCache(
        should_dispose_func=lambda queue_pool: queue_pool.checkedout() == 0,
        item_disposal_func=lambda queue_pool: queue_pool.dispose()
    )

    def __init__(
            self,
            pool_configurator: Optional[Callable] = None,
            pool_mapping: Optional[Callable] = None,
            accept_url_func: Optional[Callable] = None,
            pool_expiration_check_ns: int = -1,
            pool_cleanup_interval_ns: int = -1):
        """
        Creates a pooled connection provider.

        :param pool_configurator: optional callable invoked as ``pool_configurator(host_info, props)``;
            its result is used as the keyword arguments for the :py:class:`QueuePool` constructor.
        :param pool_mapping: optional callable invoked as ``pool_mapping(host_info, props)``;
            its result replaces the default extra pool key (the user name).
        :param accept_url_func: optional callable invoked as ``accept_url_func(host_info, props)``;
            its result replaces the default check performed by :py:meth:`accepts_host_info`.
        :param pool_expiration_check_ns: when greater than -1, overrides the class-level pool
            expiration value. Because the value is stored on the class, it affects all instances.
        :param pool_cleanup_interval_ns: when greater than -1, is forwarded to the shared pool
            cache via ``set_cleanup_interval_ns``. The cache is shared, so this affects all instances.
        """
        self._pool_configurator = pool_configurator
        self._pool_mapping = pool_mapping
        self._accept_url_func = accept_url_func

        if pool_expiration_check_ns > -1:
            SqlAlchemyPooledConnectionProvider._POOL_EXPIRATION_CHECK_NS = pool_expiration_check_ns

        if pool_cleanup_interval_ns > -1:
            SqlAlchemyPooledConnectionProvider._database_pools.set_cleanup_interval_ns(pool_cleanup_interval_ns)

    @property
    def num_pools(self):
        """Number of entries currently held in the shared pool cache."""
        return len(self._database_pools)

    @property
    def pool_urls(self):
        """Set of distinct host URLs that currently have an entry in the shared pool cache."""
        return {pool_key.url for pool_key in self._database_pools.keys()}

    def keys(self):
        """Returns the keys of the shared pool cache."""
        return self._database_pools.keys()

    def accepts_host_info(self, host_info: HostInfo, props: Properties) -> bool:
        """
        Indicates whether this provider is willing to connect to the given host.

        :param host_info: the host being considered.
        :param props: the connection properties.
        :return: the result of the user-supplied ``accept_url_func`` if one was given;
            otherwise ``True`` only when the host name is identified as an RDS instance URL.
        """
        if self._accept_url_func:
            return self._accept_url_func(host_info, props)

        url_type = SqlAlchemyPooledConnectionProvider._rds_utils.identify_rds_type(host_info.host)
        return RdsUrlType.RDS_INSTANCE == url_type

    def accepts_strategy(self, role: HostRole, strategy: str) -> bool:
        """
        Indicates whether the given host selection strategy is supported.

        :param role: the requested host role (not used by this check).
        :param strategy: the strategy name.
        :return: ``True`` for ``"least_connections"`` or any key of the accepted selector strategies.
        """
        return strategy == SqlAlchemyPooledConnectionProvider._LEAST_CONNECTIONS or strategy in self._accepted_strategies

    def get_host_info_by_strategy(self, hosts: Tuple[HostInfo, ...], role: HostRole, strategy: str, props: Optional[Properties]) -> HostInfo:
        """
        Picks a host with the requested role using the requested strategy.

        :param hosts: the candidate hosts.
        :param role: the role the selected host must have.
        :param strategy: the selection strategy name.
        :param props: the connection properties, forwarded to the host selector.
        :return: the selected host.
        :raises AwsWrapperError: if the strategy is unsupported, or if the least-connections
            strategy is requested and no host has the requested role.
        """
        if not self.accepts_strategy(role, strategy):
            # Use the class's own __name__; `Class.__class__.__name__` would report the metaclass.
            raise AwsWrapperError(Messages.get_formatted(
                "ConnectionProvider.UnsupportedHostSelectorStrategy",
                strategy, SqlAlchemyPooledConnectionProvider.__name__))

        if strategy == SqlAlchemyPooledConnectionProvider._LEAST_CONNECTIONS:
            valid_hosts = [host for host in hosts if host.role == role]
            if not valid_hosts:
                raise AwsWrapperError(Messages.get_formatted("HostSelector.NoHostsMatchingRole", role))
            # min() returns the first minimal element, matching a stable sort followed by [0].
            return min(valid_hosts, key=self._num_connections)

        return self._accepted_strategies[strategy].get_host(hosts, role, props)

    def _num_connections(self, host_info: HostInfo) -> int:
        """
        Returns the number of active pooled connections to a specific host.

        The count is the sum of checked-out connections over every cached pool whose key URL
        equals the host URL, regardless of the extra key.

        :param host_info: the host to analyze.
        :return: number of connections opened in the connection pool to the given host.
        """
        return sum(
            cache_item.item.checkedout()
            for pool_key, cache_item in SqlAlchemyPooledConnectionProvider._database_pools.items()
            if pool_key.url == host_info.url
        )

    def connect(
            self,
            target_func: Callable,
            driver_dialect: DriverDialect,
            database_dialect: DatabaseDialect,
            host_info: HostInfo,
            props: Properties):
        """
        Returns a connection from the pool associated with the given host and extra key,
        creating the pool first if the cache has no entry for that key.

        :param target_func: the driver connect function used by the pool to open new connections.
        :param driver_dialect: the driver dialect used to prepare the connection arguments.
        :param database_dialect: the database dialect used to prepare the connection arguments.
        :param host_info: the host to connect to.
        :param props: the connection properties.
        :return: the result of ``QueuePool.connect()``.
        :raises AwsWrapperError: if the cache returns no pool for the key, or if the pool key
            cannot be built (see :py:meth:`_get_extra_key`).
        """
        queue_pool: Optional[QueuePool] = SqlAlchemyPooledConnectionProvider._database_pools.compute_if_absent(
            PoolKey(host_info.url, self._get_extra_key(host_info, props)),
            lambda _: self._create_pool(target_func, driver_dialect, database_dialect, host_info, props),
            SqlAlchemyPooledConnectionProvider._POOL_EXPIRATION_CHECK_NS
        )

        if queue_pool is None:
            raise AwsWrapperError(Messages.get_formatted("SqlAlchemyPooledConnectionProvider.PoolNone", host_info.url))

        return queue_pool.connect()

    # The pool key should always be retrieved using this method, because the username
    # must always be included to avoid sharing privileged connections with other users.
    def _get_extra_key(self, host_info: HostInfo, props: Properties) -> str:
        """
        Builds the extra component of the pool key.

        :param host_info: the host being connected to.
        :param props: the connection properties.
        :return: the result of the user-supplied ``pool_mapping`` if one was given;
            otherwise the value of the user property.
        :raises AwsWrapperError: if no ``pool_mapping`` was given and the user property is
            missing or empty.
        """
        if self._pool_mapping is not None:
            return self._pool_mapping(host_info, props)

        # Otherwise use default map key
        user = props.get(WrapperProperties.USER.name, None)
        if user is None or user == "":
            raise AwsWrapperError(Messages.get("SqlAlchemyPooledConnectionProvider.UnableToCreateDefaultKey"))
        return user

    def _create_pool(
            self,
            target_func: Callable,
            driver_dialect: DriverDialect,
            database_dialect: DatabaseDialect,
            host_info: HostInfo,
            props: Properties):
        """
        Creates a new pool whose ``creator`` opens connections to the given host.

        :param target_func: the driver connect function.
        :param driver_dialect: used to turn the host and properties into connect arguments.
        :param database_dialect: given the prepared connect arguments via ``prepare_conn_props``.
        :param host_info: the host the pool connects to.
        :param props: the connection properties.
        :return: the pool built by :py:meth:`_create_sql_alchemy_pool`.
        """
        if self._pool_configurator is None:
            kwargs = {}
        else:
            # Copy so that adding "creator" below never mutates a dict owned by the configurator.
            kwargs = dict(self._pool_configurator(host_info, props))

        prepared_properties = driver_dialect.prepare_connect_info(host_info, props)
        database_dialect.prepare_conn_props(prepared_properties)
        kwargs["creator"] = self._get_connection_func(target_func, prepared_properties)
        return self._create_sql_alchemy_pool(**kwargs)

    def _get_connection_func(self, target_connect_func: Callable, props: Properties):
        """Returns a zero-argument callable that calls ``target_connect_func(**props)``."""
        return lambda: target_connect_func(**props)

    def _create_sql_alchemy_pool(self, **kwargs):
        """Constructs the :py:class:`QueuePool` from the given keyword arguments."""
        return pool.QueuePool(**kwargs)

    def release_resources(self):
        """Disposes every cached pool and then empties the shared pool cache."""
        for _, cache_item in SqlAlchemyPooledConnectionProvider._database_pools.items():
            cache_item.item.dispose()
        SqlAlchemyPooledConnectionProvider._database_pools.clear()
class PoolKey:
    """
    Hashable key identifying a connection pool.

    A key is made of a URL and an optional extra key. Two keys are equal when the other
    object is an instance of this key's type and both components match.
    """

    def __init__(self, url: str, extra_key: Optional[str] = None):
        self._url, self._extra_key = url, extra_key

    @property
    def url(self):
        """The URL component of the key."""
        return self._url

    @property
    def extra_key(self):
        """The extra component of the key, or ``None`` if none was given."""
        return self._extra_key

    def __members(self) -> Tuple[str, Optional[str]]:
        # The tuple of fields that defines both equality and the hash.
        return (self._url, self._extra_key)

    def __hash__(self):
        return hash(self.__members())

    def __eq__(self, other):
        # Guard clause: anything that is not an instance of this key's type is never equal.
        if not isinstance(other, type(self)):
            return False
        return self.__members() == other.__members()