pathology/transformation_pipeline/ingestion_lib/mock_redis_client.py (155 lines of code) (raw):

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Mock Redis Client."""
import contextlib
import dataclasses
import threading
import time
from typing import Dict, Optional
from unittest import mock
import uuid

from absl.testing import flagsaver
import redis

from pathology.transformation_pipeline.ingestion_lib import redis_client


@dataclasses.dataclass
class _MockRedisLockState:
  """Mock Redis Lock State.

  Attributes:
    token: Token of the lock object that most recently acquired the lock.
    owned: True while the lock is held; set to False by release().
    expire_time: Wall-clock time in seconds since the epoch (time.time()
      based) after which clean_expired_locks() drops the entry.
  """

  token: str = ''
  owned: bool = False
  expire_time: float = 0.0


class _MockRedis:
  """Mock Redis Client.

  In-memory replacement for redis.Redis. It keeps a table of named lock
  states and a simulated "server connected" flag, both guarded by
  `self.lock`.
  """

  def __init__(
      self,
      host: str,
      port: int,
      db: int = 0,
      username: Optional[str] = None,
      password: Optional[str] = None,
  ):
    # Connection arguments are accepted only so the signature matches the
    # redis.Redis constructor this class is patched in for; they are unused.
    del host, port, db, username, password
    self._connected = True
    self._redis_locks: Dict[str, _MockRedisLockState] = {}
    # Mutex guarding the lock table and the connection flag.
    self.lock = threading.Lock()

  def ping(self) -> bool:
    """Returns the simulated server connection state."""
    return self.redis_server_connected

  def clean_expired_locks(self) -> None:
    """Removes locks whose expire_time is earlier than the current time.

    Every caller in this module holds `self.lock` when calling this method.
    """
    current_time = time.time()
    expired_locks = [
        name
        for name, state in self._redis_locks.items()
        if state.expire_time < current_time
    ]
    # Names are collected first so the dict is not mutated while iterating.
    for name in expired_locks:
      del self._redis_locks[name]

  @property
  def redis_server_connected(self) -> bool:
    """True if the simulated Redis server is reachable."""
    return self._connected

  @redis_server_connected.setter
  def redis_server_connected(self, val: bool) -> None:
    with self.lock:
      self._connected = val

  def get_lock(self, name: str) -> Optional[_MockRedisLockState]:
    """Returns the state of lock `name`, or None if it is not tracked."""
    return self._redis_locks.get(name)

  def set_lock(self, name: str, state: _MockRedisLockState) -> None:
    """Stores (or replaces) the state of lock `name`."""
    self._redis_locks[name] = state

  def get_lock_expire_time(self, name: str) -> float:
    """Returns the expire time of lock `name`; raises KeyError if missing."""
    return self._redis_locks[name].expire_time


class _MockRedisLock:
  """Mock Redis Lock.

  Replacement for redis.lock.Lock backed by a _MockRedis instance. Supports
  only non-blocking acquisition, TTL replacement on extend, and locks that do
  not use thread-local token storage.
  """

  def __init__(
      self, instance: _MockRedis, name: str, timeout: int, thread_local: bool
  ):
    """Constructor.

    Args:
      instance: Mock Redis instance that stores the lock state.
      name: Lock name; must be non-empty.
      timeout: Lock TTL in seconds; must be > 0.
      thread_local: Must be False; thread-local tokens are not supported.

    Raises:
      ValueError: If any argument is invalid or unsupported by the mock.
    """
    if not name:
      raise ValueError('Undefined name')
    if timeout <= 0:
      raise ValueError('Invalid timeout')
    if thread_local:
      raise ValueError(
          'Redis instance running with thread local storage not supported by'
          ' mock.'
      )
    self._name = name
    # Placeholder token; replaced by the token passed to acquire().
    self._token = str(uuid.uuid4())
    self._redis_instance = instance
    self._ttl = timeout
    with self._redis_instance.lock:
      self._redis_instance.clean_expired_locks()

  def extend(self, additional_time: int, replace_ttl: bool = False) -> None:
    """Extends the Redis Lock TTL.

    Only replace_ttl=True is supported: the lock's expire time is set to
    now + additional_time seconds.

    Args:
      additional_time: New TTL in seconds; must be >= 0.
      replace_ttl: Must be True; adding to the existing TTL is unsupported.

    Raises:
      redis.exceptions.ConnectionError: If the mock server is disconnected.
      ValueError: If additional_time is negative, replace_ttl is False, or the
        lock is not held by this lock object.
    """
    with self._redis_instance.lock:
      if not self._redis_instance.redis_server_connected:
        raise redis.exceptions.ConnectionError()
      if additional_time < 0:
        raise ValueError('Invalid lock extend amount.')
      self._redis_instance.clean_expired_locks()
      lock_state = self._redis_instance.get_lock(self._name)
      if (
          lock_state is not None
          and lock_state.owned
          and lock_state.token == self._token
      ):
        if replace_ttl:
          lock_state.expire_time = time.time() + additional_time
          return
        raise ValueError(
            'Cannot extend existing TTL is not supported by the Redis mock'
        )
      raise ValueError('Cannot extend lock is not owned.', lock_state)

  def acquire(self, blocking: bool, token: str) -> bool:
    """Returns True if Redis Lock is acquired.

    The lock is granted if it is untracked (or expired), released, or already
    held under the same token. On success this lock object adopts `token` and
    the lock expires `timeout` seconds from now.

    Args:
      blocking: Must be False; blocking acquisition is not supported.
      token: Non-empty token identifying the holder.

    Raises:
      redis.exceptions.ConnectionError: If the mock server is disconnected.
      ValueError: If blocking is requested or token is empty.
    """
    with self._redis_instance.lock:
      if not self._redis_instance.redis_server_connected:
        raise redis.exceptions.ConnectionError()
      if blocking:
        raise ValueError('Acquired blocking lock not supported by mock.')
      if not token:
        raise ValueError('Undefined token.')
      self._redis_instance.clean_expired_locks()
      lock_state = self._redis_instance.get_lock(self._name)
      if (
          lock_state is None
          or not lock_state.owned
          or lock_state.token == token
      ):
        self._token = token
        self._redis_instance.set_lock(
            self._name,
            _MockRedisLockState(token, True, self._ttl + time.time()),
        )
        return True
      return False

  def release(self) -> None:
    """Releases the Redis Lock.

    Raises:
      redis.exceptions.ConnectionError: If the mock server is disconnected.
      redis.exceptions.LockError: If the lock is not held by this lock object
        (including when it has already expired).
    """
    with self._redis_instance.lock:
      if not self._redis_instance.redis_server_connected:
        raise redis.exceptions.ConnectionError()
      self._redis_instance.clean_expired_locks()
      lock_state = self._redis_instance.get_lock(self._name)
      if (
          lock_state is not None
          and lock_state.owned
          and lock_state.token == self._token
      ):
        lock_state.owned = False
        return
      raise redis.exceptions.LockError('Cannot release an unlocked lock')

  def owned(self) -> bool:
    """Returns True if the Redis Lock is owned by this lock object.

    Raises:
      redis.exceptions.ConnectionError: If the mock server is disconnected.
    """
    with self._redis_instance.lock:
      if not self._redis_instance.redis_server_connected:
        raise redis.exceptions.ConnectionError()
      self._redis_instance.clean_expired_locks()
      lock_state = self._redis_instance.get_lock(self._name)
      return bool(
          lock_state is not None
          and lock_state.owned
          and lock_state.token == self._token
      )


class MockRedisClient(contextlib.ExitStack):
  """Creates Context managed block to init Redis Client Mock.

  While active, redis.Redis and redis.lock.Lock are patched with the in-memory
  mocks above, and the redis_server_ip / transform_pod_uid flags are
  overridden. Note that the patches are entered in the constructor, not in
  __enter__; they are undone when the stack is closed (e.g. on exiting the
  `with` block).
  """

  def __init__(
      self, redis_server_ip: Optional[str] = None, pod_uid: str = 'MOCK_POD_UID'
  ):
    """Constructor.

    Args:
      redis_server_ip: Value for the redis_server_ip flag.
      pod_uid: Value for the transform_pod_uid flag.
    """
    super().__init__()
    try:
      self.enter_context(
          mock.patch('redis.Redis', autospec=True, side_effect=_MockRedis)
      )
      self.enter_context(
          mock.patch(
              'redis.lock.Lock', autospec=True, side_effect=_MockRedisLock
          )
      )
      self.enter_context(
          flagsaver.flagsaver(
              redis_server_ip=redis_server_ip, transform_pod_uid=pod_uid
          )
      )
      # Reset redis_client's module-level state under the patched classes and
      # flags (presumed purpose; confirm against redis_client).
      redis_client.RedisClient.init_fork_module_state()
    except BaseException:
      # The contexts above are global patches entered eagerly. If construction
      # fails, the caller never receives this object and cannot close it, so
      # unwind them here rather than leak patches/flag overrides into other
      # tests.
      self.close()
      raise