aws_advanced_python_wrapper/utils/sliding_expiration_cache.py (108 lines of code) (raw):

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from concurrent.futures import Executor, ThreadPoolExecutor  # noqa: F401
from threading import Thread
from time import perf_counter_ns, sleep
from typing import Callable, Generic, ItemsView, KeysView, Optional, TypeVar

from aws_advanced_python_wrapper.utils.atomic import AtomicInt
from aws_advanced_python_wrapper.utils.concurrent import ConcurrentDict
from aws_advanced_python_wrapper.utils.log import Logger

K = TypeVar('K')
V = TypeVar('V')

logger = Logger(__name__)


class SlidingExpirationCache(Generic[K, V]):
    """Key/value cache, backed by a ConcurrentDict, whose entries expire after a per-item interval.

    Every ``compute_if_absent`` call resets the entry's expiration to "now + item_expiration_ns"
    (sliding expiration). Expired entries are removed lazily: ``compute_if_absent``, ``get`` and
    ``remove`` call ``_cleanup``, which scans the whole cache at most once per cleanup interval.
    """

    def __init__(
            self,
            cleanup_interval_ns: int = 10 * 60_000_000_000,  # 10 minutes
            should_dispose_func: Optional[Callable] = None,
            item_disposal_func: Optional[Callable] = None):
        """
        :param cleanup_interval_ns: minimum time between two cleanup scans, in nanoseconds.
        :param should_dispose_func: optional predicate ``f(item)``; when given, an expired entry is
            only removed during cleanup if the predicate is also truthy for its item.
        :param item_disposal_func: optional callback ``f(item)`` invoked on an item after its entry
            has been removed (by cleanup, ``remove`` or ``clear``).
        """
        self._cleanup_interval_ns = cleanup_interval_ns
        self._should_dispose_func = should_dispose_func
        self._item_disposal_func = item_disposal_func
        self._cdict: ConcurrentDict[K, CacheItem[V]] = ConcurrentDict()
        # Monotonic (perf_counter_ns) timestamp after which the next lazy cleanup scan may run.
        self._cleanup_time_ns: AtomicInt = AtomicInt(perf_counter_ns() + self._cleanup_interval_ns)

    def __len__(self) -> int:
        return len(self._cdict)

    def set_cleanup_interval_ns(self, interval_ns):
        """Change the cleanup interval.

        The already-scheduled next cleanup time is not changed; the new interval is used when the
        following cleanup is scheduled.
        """
        self._cleanup_interval_ns = interval_ns

    def keys(self) -> KeysView:
        return self._cdict.keys()

    def items(self) -> ItemsView:
        """Return the underlying (key, CacheItem) pairs; note the values are CacheItem wrappers."""
        return self._cdict.items()

    def compute_if_absent(self, key: K, mapping_func: Callable, item_expiration_ns: int) -> Optional[V]:
        """Return the item cached under ``key``, creating it with ``mapping_func(key)`` if absent.

        In both cases the entry's expiration is reset to now + ``item_expiration_ns``.

        :param key: cache key.
        :param mapping_func: called with ``key`` to produce the item when no entry exists.
        :param item_expiration_ns: lifetime of the entry from now, in nanoseconds.
        :return: the cached item, or None if the underlying dict returned no entry.
        """
        self._cleanup()
        cache_item = self._cdict.compute_if_absent(
            key, lambda k: CacheItem(mapping_func(k), perf_counter_ns() + item_expiration_ns))
        return None if cache_item is None else cache_item.update_expiration(item_expiration_ns).item

    def get(self, key: K) -> Optional[V]:
        """Return the item cached under ``key`` (without extending its expiration), or None."""
        self._cleanup()
        cache_item = self._cdict.get(key)
        return cache_item.item if cache_item is not None else None

    def remove(self, key: K):
        """Remove ``key`` (disposing its item if a disposal function is set), then run cleanup."""
        self._remove_and_dispose(key)
        self._cleanup()

    def _remove_and_dispose(self, key: K):
        cache_item = self._cdict.remove(key)
        if cache_item is not None and self._item_disposal_func is not None:
            self._item_disposal_func(cache_item.item)

    def _remove_if_expired(self, key: K):
        """Atomically remove ``key`` if its entry should be cleaned up, then dispose of its item."""
        # Track the removed CacheItem itself rather than its payload: a cached value of None is
        # legitimate and must still be passed to the disposal function, as the other removal
        # paths (_remove_and_dispose, clear) do.
        removed_item: Optional[CacheItem[V]] = None

        def _remove_if_expired_internal(_, cache_item):
            nonlocal removed_item
            if self._should_cleanup_item(cache_item):
                removed_item = cache_item
                # Returning None from the remapping function is relied on to drop the entry
                # (presumably ConcurrentDict.compute_if_present semantics — confirm).
                return None
            return cache_item

        self._cdict.compute_if_present(key, _remove_if_expired_internal)
        if removed_item is not None and self._item_disposal_func is not None:
            self._item_disposal_func(removed_item.item)

    def _should_cleanup_item(self, cache_item: CacheItem) -> bool:
        """True when the entry has expired and, if a should_dispose_func is set, it agrees."""
        if perf_counter_ns() <= cache_item.expiration_time:
            return False
        return self._should_dispose_func is None or bool(self._should_dispose_func(cache_item.item))

    def clear(self):
        """Dispose of every cached item (if a disposal function is set) and empty the cache."""
        if self._item_disposal_func is not None:
            for _, cache_item in self._cdict.items():
                if cache_item is not None:
                    self._item_disposal_func(cache_item.item)
        self._cdict.clear()

    def _cleanup(self):
        """Scan for and remove expired entries, at most once per cleanup interval."""
        current_time = perf_counter_ns()
        if self._cleanup_time_ns.get() > current_time:
            return
        self._cleanup_time_ns.set(current_time + self._cleanup_interval_ns)
        # Snapshot the keys first so entries can be removed while we iterate.
        keys = [key for key, _ in self._cdict.items()]
        for key in keys:
            self._remove_if_expired(key)


class SlidingExpirationCacheWithCleanupThread(SlidingExpirationCache, Generic[K, V]):
    """SlidingExpirationCache whose expired entries are purged by a background thread.

    Cache operations never scan for expired entries themselves (``_cleanup`` is a no-op); instead a
    daemon thread wakes up once per cleanup interval, for the lifetime of the process, and removes
    them.
    """

    def __init__(
            self,
            cleanup_interval_ns: int = 10 * 60_000_000_000,  # 10 minutes
            should_dispose_func: Optional[Callable] = None,
            item_disposal_func: Optional[Callable] = None):
        super().__init__(cleanup_interval_ns, should_dispose_func, item_disposal_func)
        self.init_cleanup_thread()

    def init_cleanup_thread(self) -> None:
        """Start the background cleanup thread.

        The thread is a daemon so that its long sleeps never block interpreter shutdown (worker
        threads of a ThreadPoolExecutor, by contrast, are joined when the interpreter exits).
        """
        self._cleanup_thread = Thread(
            target=self._cleanup_thread_internal,
            name="SlidingExpirationCacheWithCleanupThread",
            daemon=True)
        self._cleanup_thread.start()

    def _cleanup_thread_internal(self):
        """Cleanup-thread body: after every cleanup interval, remove all expired entries. Never returns."""
        while True:
            # Re-read the interval each round so set_cleanup_interval_ns takes effect.
            sleep(self._cleanup_interval_ns / 1_000_000_000)
            try:
                self._remove_expired_items()
            except Exception:
                pass  # best effort: keep the thread alive; the next round retries

    def _remove_expired_items(self):
        """Perform one full cleanup pass over a snapshot of the current keys."""
        logger.debug("SlidingExpirationCache.CleaningUp")
        # Timestamp taken after waking up, so the recorded next-cleanup time is not stale.
        self._cleanup_time_ns.set(perf_counter_ns() + self._cleanup_interval_ns)
        keys = [key for key, _ in self._cdict.items()]
        for key in keys:
            try:
                self._remove_if_expired(key)
            except Exception:
                pass  # ignore: one failing entry/disposal must not stop the rest of the pass

    def _cleanup(self):
        pass  # do nothing, cleanup thread does the job


class CacheItem(Generic[V]):
    """A cached value together with its absolute expiration time (perf_counter_ns timestamp)."""

    def __init__(self, item: V, expiration_time: int):
        self.item = item
        self.expiration_time = expiration_time

    def __str__(self):
        return f"CacheItem [item={str(self.item)}, expiration_time={self.expiration_time}]"

    def update_expiration(self, expiration_interval_ns: int) -> CacheItem:
        """Set the expiration to now + ``expiration_interval_ns`` and return self (for chaining)."""
        self.expiration_time = perf_counter_ns() + expiration_interval_ns
        return self