aws_advanced_python_wrapper/host_selector.py (130 lines of code) (raw):
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import random
from re import search
from typing import TYPE_CHECKING, Dict, List, Optional, Protocol, Tuple
from .host_availability import HostAvailability
if TYPE_CHECKING:
from .hostinfo import HostInfo, HostRole
from aws_advanced_python_wrapper.errors import AwsWrapperError
from aws_advanced_python_wrapper.utils.cache_map import CacheMap
from .pep249 import Error
from .utils.messages import Messages
from .utils.properties import Properties, WrapperProperties
class HostSelector(Protocol):
    """
    Strategy interface for choosing a single host out of a collection of candidate hosts.
    """

    def get_host(self, hosts: Tuple[HostInfo, ...], role: HostRole, props: Optional[Properties] = None) -> HostInfo:
        """
        Choose one host from ``hosts`` for the requested ``role``.

        :param hosts: the candidate hosts.
        :param role: the requested host role.
        :param props: optional connection properties an implementation may consult.
        :return: the chosen host.
        """
class RandomHostSelector(HostSelector):
    """
    Host selector that picks uniformly at random among the eligible hosts.

    A host is eligible when its role equals the requested role and its availability is
    ``HostAvailability.AVAILABLE``.
    """

    def get_host(self, hosts: Tuple[HostInfo, ...], role: HostRole, props: Optional[Properties] = None) -> HostInfo:
        """
        Return a random eligible host.

        :param hosts: the candidate hosts.
        :param role: only hosts with this role are considered.
        :param props: unused by this selector.
        :return: one of the eligible hosts, chosen with ``random.choice``.
        :raises Error: if no host is eligible.
        """
        candidates: List[HostInfo] = []
        for candidate in hosts:
            # Role is checked first so availability is only queried for hosts with the right role.
            if candidate.role == role and candidate.get_availability() == HostAvailability.AVAILABLE:
                candidates.append(candidate)

        if not candidates:
            raise Error(Messages.get("HostSelector.NoEligibleHost"))
        return random.choice(candidates)
class RoundRobinClusterInfo:
    """
    Mutable round-robin selection state for one set of hosts.

    Holds the host most recently returned, the explicit per-host weights, the default weight for
    hosts without an explicit weight, and the remaining number of times the last host may be
    returned before the selection advances.
    """

    def __init__(self) -> None:
        self._last_host: Optional[HostInfo] = None
        # Created per instance so weights parsed for one cluster never leak into another.
        # A class-level ``{}`` default would be a single dict shared by every instance.
        self._cluster_weights_dict: Dict[str, int] = {}
        self._default_weight: int = 1
        self._weight_counter: int = 0

    @property
    def last_host(self) -> Optional[HostInfo]:
        """The host most recently selected, or None if nothing has been selected yet."""
        return self._last_host

    @last_host.setter
    def last_host(self, value: Optional[HostInfo]) -> None:
        self._last_host = value

    @property
    def cluster_weights_dict(self) -> Dict[str, int]:
        """Mapping of host name to its configured round-robin weight."""
        return self._cluster_weights_dict

    @cluster_weights_dict.setter
    def cluster_weights_dict(self, value: Dict[str, int]) -> None:
        self._cluster_weights_dict = value

    @property
    def default_weight(self) -> int:
        """Weight used for hosts that have no entry in ``cluster_weights_dict``."""
        return self._default_weight

    @default_weight.setter
    def default_weight(self, value: int) -> None:
        self._default_weight = value

    @property
    def weight_counter(self) -> int:
        """Remaining number of times ``last_host`` may be selected before advancing."""
        return self._weight_counter

    @weight_counter.setter
    def weight_counter(self, value: int) -> None:
        self._weight_counter = value
class RoundRobinHostSelector(HostSelector):
    """
    Host selector that returns eligible hosts in weighted round-robin order.

    Eligible hosts (matching role and ``HostAvailability.AVAILABLE``) are sorted by host name and
    visited in that order. Each host is returned ``weight`` consecutive times before the selector
    advances. ``weight`` comes from ``ROUND_ROBIN_HOST_WEIGHT_PAIRS``, or from
    ``ROUND_ROBIN_DEFAULT_WEIGHT`` when the host is not listed there.

    The selection state lives in a class-level cache shared by all selector instances. The cache
    is keyed by host name, and every host of an eligible set points to the same
    RoundRobinClusterInfo.
    """

    _DEFAULT_WEIGHT: int = 1
    _DEFAULT_ROUND_ROBIN_CACHE_EXPIRE_NANOS = 60000000000 * 10  # 10 minutes
    # Matches one "<host>:<weight>" entry of ROUND_ROBIN_HOST_WEIGHT_PAIRS.
    _HOST_WEIGHT_PAIRS_PATTERN = r"((?P<host>[^:/?#]*):(?P<weight>.*))"
    _round_robin_cache: CacheMap[str, Optional[RoundRobinClusterInfo]] = CacheMap()

    def get_host(self, hosts: Tuple[HostInfo, ...], role: HostRole, props: Optional[Properties] = None) -> HostInfo:
        """
        Return the next host in weighted round-robin order.

        :param hosts: the candidate hosts.
        :param role: only hosts with this role are considered.
        :param props: connection properties; only read when a new cluster info entry is created.
        :return: the selected host.
        :raises AwsWrapperError: if no host is eligible, if the cluster info is missing from the
            cache, or if the round-robin weight properties are invalid.
        """
        eligible_hosts: List[HostInfo] = sorted(
            (host for host in hosts
             if host.role == role and host.get_availability() == HostAvailability.AVAILABLE),
            key=lambda host: host.host)
        if not eligible_hosts:
            raise AwsWrapperError(Messages.get_formatted("HostSelector.NoHostsMatchingRole", role))

        # Create new cache entries for provided hosts if necessary. All hosts point to the same cluster info.
        self._create_cache_entry_for_hosts(eligible_hosts, props)
        cluster_info: Optional[RoundRobinClusterInfo] = \
            RoundRobinHostSelector._round_robin_cache.get(eligible_hosts[0].host)
        if cluster_info is None:
            raise AwsWrapperError(Messages.get("RoundRobinHostSelector.ClusterInfoNone"))

        last_host_index = self._find_last_host_index(eligible_hosts, cluster_info.last_host)
        if cluster_info.weight_counter > 0 and last_host_index != -1:
            # The previously selected host is still eligible and has picks left.
            target_host_index = last_host_index
        else:
            # Advance to the next host. The index wraps to 0 after the last host, and also when
            # the previous host is no longer eligible (-1 + 1 == 0). The budget is then reset to
            # the new host's weight.
            target_host_index = (last_host_index + 1) % len(eligible_hosts)
            weight = cluster_info.cluster_weights_dict.get(eligible_hosts[target_host_index].host)
            cluster_info.weight_counter = cluster_info.default_weight if weight is None else weight

        cluster_info.weight_counter -= 1
        target_host = eligible_hosts[target_host_index]
        cluster_info.last_host = target_host
        return target_host

    @staticmethod
    def _find_last_host_index(eligible_hosts: List[HostInfo], last_host: Optional[HostInfo]) -> int:
        """Return the index of the host whose name equals ``last_host.host``, or -1 if absent."""
        last_host_index = -1
        if last_host is not None:
            # No early break: if host names repeat, the last matching index wins.
            for i, host in enumerate(eligible_hosts):
                if host.host == last_host.host:
                    last_host_index = i
        return last_host_index

    def _create_cache_entry_for_hosts(self, hosts: List[HostInfo], props: Optional[Properties]) -> None:
        """
        Make every host in ``hosts`` map to one shared RoundRobinClusterInfo in the cache.

        The first cached entry found for any of the hosts is reused. Otherwise a new entry is
        created and initialized from ``props``. In both cases the entry is (re)stored for every
        host, which refreshes its expiration time.
        """
        cluster_info: Optional[RoundRobinClusterInfo] = None
        for host in hosts:
            cluster_info = self._round_robin_cache.get(host.host)
            if cluster_info is not None:
                break

        if cluster_info is None:
            cluster_info = RoundRobinClusterInfo()
            self._update_cache_properties_for_round_robin_cluster_info(cluster_info, props)

        for host in hosts:
            self._round_robin_cache.put(
                host.host, cluster_info, RoundRobinHostSelector._DEFAULT_ROUND_ROBIN_CACHE_EXPIRE_NANOS)

    def _update_cache_properties_for_round_robin_cluster_info(
            self, round_robin_cluster_info: RoundRobinClusterInfo, props: Optional[Properties]) -> None:
        """
        Initialize the default weight and per-host weights of ``round_robin_cluster_info``.

        :param round_robin_cluster_info: the cluster info to populate.
        :param props: connection properties holding ``ROUND_ROBIN_DEFAULT_WEIGHT`` and
            ``ROUND_ROBIN_HOST_WEIGHT_PAIRS``. When None, defaults are used.
        :raises AwsWrapperError: if the default weight is below 1, or if a host weight pair is
            malformed or has a weight below 1.
        """
        cluster_default_weight: int = RoundRobinHostSelector._DEFAULT_WEIGHT
        # Built locally and then assigned, rather than mutating the dict already held by the
        # cluster info. That dict may be shared with other cluster infos.
        cluster_weights: Dict[str, int] = {}
        if props is not None:
            props_weight = WrapperProperties.ROUND_ROBIN_DEFAULT_WEIGHT.get_int(props)
            if props_weight < RoundRobinHostSelector._DEFAULT_WEIGHT:
                raise AwsWrapperError(Messages.get("RoundRobinHostSelector.RoundRobinInvalidDefaultWeight"))
            cluster_default_weight = props_weight

            host_weights: Optional[str] = WrapperProperties.ROUND_ROBIN_HOST_WEIGHT_PAIRS.get(props)
            if host_weights:
                for pair in host_weights.split(","):
                    host_name, weight = RoundRobinHostSelector._parse_host_weight_pair(pair)
                    cluster_weights[host_name] = weight

        round_robin_cluster_info.default_weight = cluster_default_weight
        round_robin_cluster_info.cluster_weights_dict = cluster_weights

    @staticmethod
    def _parse_host_weight_pair(pair: str) -> Tuple[str, int]:
        """
        Parse one ``<host>:<weight>`` entry into ``(host, weight)``.

        Surrounding whitespace is ignored, so ``"a:1, b:2"`` yields the host ``"b"``.

        :raises AwsWrapperError: if the entry does not match the pattern, has an empty host or
            weight, the weight is not an integer, or the weight is below 1.
        """
        match = search(RoundRobinHostSelector._HOST_WEIGHT_PAIRS_PATTERN, pair.strip())
        if match is None:
            raise AwsWrapperError(Messages.get("RoundRobinHostSelector.RoundRobinInvalidHostWeightPairs"))

        host_name = match.group("host").strip()
        host_weight = match.group("weight").strip()
        if not host_name or not host_weight:
            raise AwsWrapperError(Messages.get("RoundRobinHostSelector.RoundRobinInvalidHostWeightPairs"))

        try:
            weight = int(host_weight)
        except ValueError as e:
            raise AwsWrapperError(Messages.get("RoundRobinHostSelector.RoundRobinInvalidHostWeightPairs")) from e

        if weight < RoundRobinHostSelector._DEFAULT_WEIGHT:
            raise AwsWrapperError(Messages.get("RoundRobinHostSelector.RoundRobinInvalidHostWeightPairs"))
        return host_name, weight

    def clear_cache(self):
        """Remove all entries from the class-level round-robin cache shared by all instances."""
        RoundRobinHostSelector._round_robin_cache.clear()