alibabacloud_credentials/provider/refreshable.py (213 lines of code) (raw):

import random
import asyncio
import threading
import logging
import time
import atexit
from datetime import datetime
from enum import Enum
from typing import Callable, Generic, TypeVar, Coroutine, Any
from threading import Semaphore
from concurrent.futures.thread import ThreadPoolExecutor

from alibabacloud_credentials.exceptions import CredentialException
from alibabacloud_credentials_api import ICredentials

log = logging.getLogger(__name__)
log.setLevel(logging.DEBUG)

T = TypeVar('T')

INT64_MAX = 2 ** 63 - 1
MAX_CONCURRENT_REFRESHES = 100
# One lease is held per in-flight background refresh task.
CONCURRENT_REFRESH_LEASES = Semaphore(MAX_CONCURRENT_REFRESHES)
EXECUTOR = ThreadPoolExecutor(max_workers=INT64_MAX, thread_name_prefix='non-blocking-refresh')


def _shutdown_handler():
    """Shut the shared background executor down without waiting for running tasks."""
    log.debug("Shutting down executor...")
    EXECUTOR.shutdown(wait=False)


atexit.register(_shutdown_handler)


def _current_time() -> int:
    """Return the current time as whole seconds since the epoch."""
    return int(time.mktime(time.localtime()))


def _format_time(timestamp: int) -> str:
    """Render an epoch-seconds timestamp for log messages.

    ``datetime.fromtimestamp`` raises for values outside the platform's
    supported range (for example the ``INT64_MAX`` default of
    ``RefreshResult.stale_time``). Log formatting must never turn a successful
    refresh into a failure, so such values fall back to the raw number.
    """
    try:
        return str(datetime.fromtimestamp(timestamp))
    except (OverflowError, OSError, ValueError):
        return str(timestamp)


def _jitter_time(now: int, jitter_start: int, jitter_end: int) -> int:
    """Return ``now`` plus a random integer in ``[jitter_start, jitter_end]``."""
    jitter_amount = random.randint(jitter_start, jitter_end)
    return now + jitter_amount


def _max_stale_failure_jitter(num_failures: int) -> int:
    """Return the upper bound (milliseconds) of the retry jitter window.

    The result is the larger of 10 seconds and ``2 ** (num_failures - 1) * 100``
    milliseconds.
    """
    # NOTE(review): because this uses max(), the window has a 10 s floor and no
    # ceiling, so it grows without bound as failures accumulate. If a 10 s cap
    # was intended this should be min() -- confirm before changing, since the
    # caller passes a 1000 ms lower bound to random.randint.
    backoff_millis = max(10 * 1000, (1 << (num_failures - 1)) * 100)
    return backoff_millis


class Credentials(ICredentials):
    """Plain value holder for a set of credentials and their metadata."""

    def __init__(self, *,
                 access_key_id: str = None,
                 access_key_secret: str = None,
                 security_token: str = None,
                 expiration: int = None,
                 provider_name: str = None):
        self._access_key_id = access_key_id
        self._access_key_secret = access_key_secret
        self._security_token = security_token
        self._expiration = expiration
        self._provider_name = provider_name

    def get_access_key_id(self) -> str:
        return self._access_key_id

    def get_access_key_secret(self) -> str:
        return self._access_key_secret

    def get_security_token(self) -> str:
        return self._security_token

    def get_expiration(self) -> int:
        return self._expiration

    def get_provider_name(self) -> str:
        return self._provider_name


class StaleValueBehavior(Enum):
    """How a cached supplier treats a cached value whose stale time has passed."""

    # Strictly treat the stale time. Never return a stale cached value (except
    # when the supplier returns an expired value, in which case the supplier
    # will return the value but only for a very short period of time to prevent
    # overloading the underlying supplier).
    STRICT = 0

    # Allow stale values to be returned from the cache. Value retrieval will
    # never fail, as long as the cache has succeeded when calling the
    # underlying supplier at least once.
    ALLOW = 1


class RefreshResult(Generic[T]):
    """A fetched value together with its stale time and prefetch time.

    Both times are compared against epoch seconds by ``RefreshCachedSupplier``
    and default to ``INT64_MAX``.
    """

    def __init__(self, *,
                 value: T,
                 stale_time: int = INT64_MAX,
                 prefetch_time: int = INT64_MAX):
        self._value = value
        self._stale_time = stale_time
        self._prefetch_time = prefetch_time

    def value(self) -> T:
        return self._value

    def stale_time(self) -> int:
        return self._stale_time

    def prefetch_time(self) -> int:
        return self._prefetch_time


class PrefetchStrategy:
    """Interface deciding how a prefetch (refresh before staleness) is executed."""

    def prefetch(self, action: Callable):
        raise NotImplementedError

    async def prefetch_async(self, action: Callable):
        raise NotImplementedError


class NonBlocking(PrefetchStrategy):
    """Run prefetches on the shared background executor.

    At most ``MAX_CONCURRENT_REFRESHES`` background tasks hold a lease at any
    time; further prefetch requests are skipped with a warning.
    """

    def prefetch(self, action: Callable):
        """Submit ``action`` to the background executor if a lease is available."""
        if not CONCURRENT_REFRESH_LEASES.acquire(False):
            log.warning('Skipping a background refresh task because there are too many other tasks running.')
            return

        release_guard = threading.Lock()
        released = []

        def release_lease_once():
            # The lease may be handed back from the worker (task finished) or
            # from the submitter (submission failed); never release it twice.
            with release_guard:
                if released:
                    return
                released.append(True)
            CONCURRENT_REFRESH_LEASES.release()

        def run_and_release():
            # Hold the lease for the whole lifetime of the task so the
            # semaphore really bounds the number of concurrent refreshes.
            try:
                action()
            except Exception:
                # Nobody reads the Future, so log instead of losing the error.
                log.warning('Exception occurred in background refresh task.', exc_info=True)
            finally:
                release_lease_once()

        submitted = False
        try:
            EXECUTOR.submit(run_and_release)
            submitted = True
        except KeyboardInterrupt:
            _shutdown_handler()
        except Exception:
            log.warning('Exception occurred when submitting background task.', exc_info=True)
        finally:
            if not submitted:
                release_lease_once()

    async def prefetch_async(self, action: Callable):
        """Run the coroutine function ``action`` on a background thread.

        The worker thread drives the coroutine on its own, freshly created
        event loop.
        """

        def run_asyncio_loop():
            loop = asyncio.new_event_loop()
            try:
                asyncio.set_event_loop(loop)
                loop.run_until_complete(action())
            finally:
                # Worker threads are reused: do not leave a closed loop behind.
                asyncio.set_event_loop(None)
                loop.close()

        self.prefetch(run_asyncio_loop)


class OneCallerBlocks(PrefetchStrategy):
    """Run the prefetch inline, in the caller that triggered it."""

    def prefetch(self, action: Callable):
        action()

    async def prefetch_async(self, action: Callable):
        await action()


class RefreshCachedSupplier(Generic[T]):
    """Cache the result of a refresh callable and refresh it when needed.

    A cached value is refreshed synchronously once its stale time has passed,
    and refreshed through the configured ``PrefetchStrategy`` once its prefetch
    time has passed.
    """

    STALE_TIME = 15 * 60  # seconds
    REFRESH_BLOCKING_MAX_WAIT = 5  # seconds

    def __init__(self, refresh_callable: Callable[[], RefreshResult[T]],
                 refresh_callable_async: Callable[[], Coroutine[Any, Any, RefreshResult[T]]],
                 stale_value_behavior: StaleValueBehavior = StaleValueBehavior.STRICT,
                 prefetch_strategy: PrefetchStrategy = OneCallerBlocks()):
        self._refresh_callable = refresh_callable
        self._refresh_callable_async = refresh_callable_async
        self._stale_value_behavior = stale_value_behavior
        self._prefetch_strategy = prefetch_strategy
        self._consecutive_refresh_failures = 0
        self._cached_value = None
        self._refresh_lock = threading.Lock()

    def _sync_call(self) -> T:
        """Return the cached value, refreshing or prefetching it first if due."""
        if self._cache_is_stale():
            log.debug('Refreshing synchronously')
            self._refresh_cache()
        elif self._should_initiate_cache_prefetch():
            log.debug(f'Prefetching using strategy: {self._prefetch_strategy.__class__.__name__}')
            self._prefetch_cache()
        return self._cached_value.value()

    async def _async_call(self) -> T:
        """Coroutine counterpart of ``_sync_call``."""
        if self._cache_is_stale():
            log.debug('Refreshing synchronously')
            await self._refresh_cache_async()
        elif self._should_initiate_cache_prefetch():
            log.debug(f'Prefetching using strategy: {self._prefetch_strategy.__class__.__name__}')
            await self._prefetch_cache_async()
        return self._cached_value.value()

    def _cache_is_stale(self) -> bool:
        """Return True when nothing is cached or the stale time has been reached."""
        if self._cached_value is None:
            return True
        return _current_time() >= self._cached_value.stale_time()

    def _should_initiate_cache_prefetch(self) -> bool:
        """Return True when nothing is cached or the prefetch time has been reached."""
        if self._cached_value is None:
            return True
        return _current_time() >= self._cached_value.prefetch_time()

    def _prefetch_cache(self):
        self._prefetch_strategy.prefetch(self._refresh_cache)

    def _refresh_cache(self):
        """Refresh the cached value, waiting a bounded time for the refresh lock.

        If the lock cannot be obtained within ``REFRESH_BLOCKING_MAX_WAIT``
        seconds the refresh proceeds without it rather than blocking forever.
        """
        acquired = self._refresh_lock.acquire(timeout=RefreshCachedSupplier.REFRESH_BLOCKING_MAX_WAIT)
        try:
            # Re-check: another caller may have refreshed while we waited.
            if self._cache_is_stale() or self._should_initiate_cache_prefetch():
                try:
                    self._cached_value = self._handle_fetched_success(self._refresh_callable())
                except Exception as ex:
                    self._cached_value = self._handle_fetched_failure(ex)
        finally:
            if acquired:
                self._refresh_lock.release()

    async def _prefetch_cache_async(self):
        await self._prefetch_strategy.prefetch_async(self._refresh_cache_async)

    async def _acquire_refresh_lock_async(self) -> bool:
        """Try to take the refresh lock without blocking the event loop.

        Polls the lock for up to ``REFRESH_BLOCKING_MAX_WAIT`` seconds, yielding
        to the event loop between attempts so that a coroutine holding the lock
        on the same loop can make progress and release it.

        Returns:
            True if the lock was acquired, False if the wait timed out.
        """
        deadline = time.monotonic() + RefreshCachedSupplier.REFRESH_BLOCKING_MAX_WAIT
        while True:
            if self._refresh_lock.acquire(False):
                return True
            if time.monotonic() >= deadline:
                return False
            await asyncio.sleep(0.01)

    async def _refresh_cache_async(self):
        """Coroutine counterpart of ``_refresh_cache``."""
        acquired = await self._acquire_refresh_lock_async()
        try:
            # Re-check: another caller may have refreshed while we waited.
            if self._cache_is_stale() or self._should_initiate_cache_prefetch():
                try:
                    self._cached_value = self._handle_fetched_success(await self._refresh_callable_async())
                except Exception as ex:
                    self._cached_value = self._handle_fetched_failure(ex)
        finally:
            if acquired:
                self._refresh_lock.release()

    def _handle_fetched_success(self, value: RefreshResult[T]) -> RefreshResult[T]:
        """Decide what to cache after the refresh callable returned ``value``.

        Raises:
            CredentialException: the fetched value is already past
                ``stale_time + STALE_TIME`` and there is no cached value to
                fall back to.
        """
        log.debug(f'Refresh credentials successfully, retrieved value is {value}, cached value is {self._cached_value}')
        self._consecutive_refresh_failures = 0
        now = _current_time()
        # The stale time has not been reached yet: use the value as is.
        if now < value.stale_time():
            log.debug(
                f'Retrieved value stale time is {_format_time(value.stale_time())}. '
                f'Using staleTime of {_format_time(value.stale_time())}')
            return value
        # Already stale, but less than STALE_TIME past the stale time: use it
        # now and mark it stale immediately so the next call refreshes again.
        if now < value.stale_time() + RefreshCachedSupplier.STALE_TIME:
            log.warning(
                f'Retrieved value stale time is in the past ({_format_time(value.stale_time())}). '
                f'Using staleTime of {_format_time(now)}')
            return RefreshResult(value=value.value(), stale_time=now, prefetch_time=value.prefetch_time())

        log.warning(
            f'Retrieved value expiration time of the credential is in the past '
            f'({_format_time(value.stale_time() + RefreshCachedSupplier.STALE_TIME)}). Trying use the cached value.')
        # The fetched value is unusable, so fall back to the cache. A cached
        # value that is not yet stale is returned unchanged; otherwise the
        # stale-value behavior decides whether to retry almost immediately
        # (STRICT) or somewhat later (ALLOW).
        if self._cached_value is None:
            raise CredentialException('No cached value was found.')
        elif now < self._cached_value.stale_time():
            log.warning(
                f'Cached value staleTime is {_format_time(self._cached_value.stale_time())}. '
                f'Using staleTime of {_format_time(self._cached_value.stale_time())}')
            return self._cached_value
        elif self._stale_value_behavior == StaleValueBehavior.STRICT:
            log.warning(
                f'Cached value expiration is in the past ({_format_time(self._cached_value.stale_time())}). '
                f'Using expiration of {_format_time(now + 1)}')
            return RefreshResult(value=self._cached_value.value(), stale_time=now + 1,
                                 prefetch_time=self._cached_value.prefetch_time())
        else:  # ALLOW
            # Extend by 50 seconds plus a random 0..20001 milliseconds.
            extended_stale_time = now + int((50 * 1000 + random.randint(0, 20 * 1000 + 1)) / 1000)
            log.warning(
                f'Cached value expiration has been extended to {_format_time(extended_stale_time)} '
                f'because the downstream service returned a time in the past: '
                f'{_format_time(self._cached_value.stale_time())}')
            return RefreshResult(value=self._cached_value.value(), stale_time=extended_stale_time,
                                 prefetch_time=self._cached_value.prefetch_time())

    def _handle_fetched_failure(self, exception: Exception) -> RefreshResult[T]:
        """Decide what to cache after the refresh callable raised ``exception``.

        Raises:
            Exception: ``exception`` itself, when nothing is cached, or when the
                cached value is stale and the behavior is ``STRICT``.
        """
        log.warning(f'Refresh credentials failed, cached value is {self._cached_value}, error: {exception}')
        if self._cached_value is None:
            log.exception(exception)
            raise exception
        now = _current_time()
        # The cached value is still usable; keep serving it.
        if now < self._cached_value.stale_time():
            return self._cached_value
        self._consecutive_refresh_failures += 1
        if self._stale_value_behavior == StaleValueBehavior.STRICT:
            log.exception(exception)
            raise exception
        else:  # ALLOW
            # Jitter is computed in milliseconds, then converted back to seconds.
            new_stale_time = int(
                _jitter_time(now * 1000, 1000, _max_stale_failure_jitter(self._consecutive_refresh_failures)) / 1000)
            log.warning(
                f'Cached value expiration has been extended to {_format_time(new_stale_time)} '
                f'because calling the downstream service failed '
                f'(consecutive failures: {self._consecutive_refresh_failures}).')
            return RefreshResult(value=self._cached_value.value(), stale_time=new_stale_time,
                                 prefetch_time=self._cached_value.prefetch_time())