# perfkitbenchmarker/linux_packages/redis_enterprise.py

# Copyright 2019 PerfKitBenchmarker Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module containing redis enterprise installation and cleanup functions.

TODO(user): Flags should be unified with memtier.py.
"""

import dataclasses
import itertools
import json
import logging
import posixpath
from typing import Any, Dict, List, Tuple

from absl import flags
from perfkitbenchmarker import background_tasks
from perfkitbenchmarker import data
from perfkitbenchmarker import errors
from perfkitbenchmarker import os_types
from perfkitbenchmarker import sample
from perfkitbenchmarker import virtual_machine
from perfkitbenchmarker import vm_util
import requests

FLAGS = flags.FLAGS

_LICENSE = flags.DEFINE_string(
    'enterprise_redis_license_file',
    'enterprise_redis_license',
    'Name of the redis enterprise license file to use.',
)
_LICENSE_PATH = flags.DEFINE_string(
    'enterprise_redis_license_path',
    None,
    'If none, defaults to local data directory joined with _LICENSE.',
)
_TUNE_ON_STARTUP = flags.DEFINE_boolean(
    'enterprise_redis_tune_on_startup',
    True,
    'Whether to tune core config during startup.',
)
_PROXY_THREADS = flags.DEFINE_integer(
    'enterprise_redis_proxy_threads',
    None,
    'Number of redis proxy threads to use.',
)
_SHARDS = flags.DEFINE_integer(
    'enterprise_redis_shard_count',
    None,
    'Number of redis shards per database. Each shard is a redis thread.',
)
_LOAD_RECORDS = flags.DEFINE_integer(
    'enterprise_redis_load_records',
    1000000,
    'Number of keys to pre-load into Redis. Use'
    ' --enterprise_redis_data_size_bytes to calculate how much space will be'
    ' used. Due to overhead, a ballpark estimate is that a 1KB record takes'
    ' 1.5KB of utilization. See'
    ' https://docs.redis.com/latest/rs/concepts/memory-performance/ for more'
    ' info.',
)
_RUN_RECORDS = flags.DEFINE_integer(
    'enterprise_redis_run_records',
    1000000,
    'Number of requests per loadgen client to send to the Redis server.',
)
_PIPELINES = flags.DEFINE_integer(
    'enterprise_redis_pipeline', 9, 'Number of pipelines to use.'
)
_LOADGEN_CLIENTS = flags.DEFINE_integer(
    'enterprise_redis_loadgen_clients', 24, 'Number of clients per loadgen vm.'
)
_MAX_THREADS = flags.DEFINE_integer(
    'enterprise_redis_max_threads',
    40,
    'Maximum number of memtier threads to use.',
)
_MIN_THREADS = flags.DEFINE_integer(
    'enterprise_redis_min_threads',
    18,
    'Minimum number of memtier threads to use.',
)
_THREAD_INCREMENT = flags.DEFINE_integer(
    'enterprise_redis_thread_increment',
    1,
    'Number of memtier threads to increment by.',
)
_LATENCY_THRESHOLD = flags.DEFINE_integer(
    'enterprise_redis_latency_threshold',
    1100,
    'The latency threshold in microseconds until the test stops.',
)
_PIN_WORKERS = flags.DEFINE_boolean(
    'enterprise_redis_pin_workers',
    False,
    'Whether to pin the proxy threads after startup.',
)
_DISABLE_CPU_IDS = flags.DEFINE_list(
    'enterprise_redis_disable_cpu_ids', None, 'List of cpus to disable by id.'
)
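
# Illustrative PKB invocation exercising these flags (a sketch; the benchmark
# name and values are examples, not taken from this module):
#
#   ./pkb.py --benchmarks=redis_enterprise \
#       --enterprise_redis_shard_count=6 \
#       --enterprise_redis_proxy_threads=24 \
#       --enterprise_redis_data_size_bytes=100
#
# Leaving both --enterprise_redis_shard_count and
# --enterprise_redis_proxy_threads unset lets ThroughputOptimizer (below)
# search for the best combination.
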
_DATA_SIZE = flags.DEFINE_integer(
    'enterprise_redis_data_size_bytes',
    100,
    'The size of the data to write to redis enterprise.',
)
_NUM_DATABASES = flags.DEFINE_integer(
    'enterprise_redis_db_count',
    1,
    'The number of databases to create on the cluster.',
)
_REPLICATION = flags.DEFINE_bool(
    'enterprise_redis_db_replication',
    False,
    'If true, replicates each database to another node. Doubles the amount of '
    'memory used by the database records.',
)
_MEMORY_SIZE_PERCENTAGE = flags.DEFINE_float(
    'enterprise_redis_memory_size_percentage',
    0.80,
    'The percentage amount of memory to use out of all the available memory '
    'reported by rladmin for provisioning databases. 1 means use all '
    'available memory, which in practice tends to be error-prone.',
)

_VM = virtual_machine.VirtualMachine
_ThroughputSampleTuple = Tuple[float, List[sample.Sample]]
_ThroughputSampleMatrix = List[List[_ThroughputSampleTuple]]
_Json = Dict[str, Any]
_UidToJsonDict = Dict[int, _Json]

_VERSION = '6.2.4-54'
_PACKAGE_NAME = 'redis_enterprise'
_WORKING_DIR = '~/redislabs'
_RHEL_TAR = f'redislabs-{_VERSION}-rhel7-x86_64.tar'
_XENIAL_TAR = f'redislabs-{_VERSION}-xenial-amd64.tar'
_BIONIC_TAR = f'redislabs-{_VERSION}-bionic-amd64.tar'
_USERNAME = 'user@google.com'
_ONE_KILOBYTE = 1000
_ONE_MEGABYTE = _ONE_KILOBYTE * 1000
_ONE_GIGABYTE = _ONE_MEGABYTE * 1000
PREPROVISIONED_DATA = {
    # These checksums correspond to version 6.2.4-54. To update, run
    # 'sha256sum <redislabs-{VERSION}-rhel7-x86_64.tar>' and replace the
    # values below.
    _RHEL_TAR: (
        'fb0b7aa5f115eb0bc2ac4fb958aaa7ad92bb260f2251a221a15b01fbdf4d2d14'
    ),
    _XENIAL_TAR: (
        'f78a6bb486f3dfb3e5ba9b5be86b1c880d0c76a08eb0dc4bd3aaaf9cc210406d'
    ),
    _BIONIC_TAR: (
        'dfe568958b243368c1f1c08c9cce9f660fa06e1bce38fa88f90503e344466927'
    ),
}
_THREAD_OPTIMIZATION_RATIO = 0.75


def _GetTarName() -> str | None:
  """Returns the Redis Enterprise package to use depending on the os.

  For information about available packages, see
  https://redislabs.com/redis-enterprise/software/downloads/.
  """
  if FLAGS.os_type in [os_types.RHEL, os_types.AMAZONLINUX2, os_types.CENTOS7]:
    return _RHEL_TAR
  if FLAGS.os_type in [os_types.UBUNTU1604, os_types.DEBIAN, os_types.DEBIAN9]:
    return _XENIAL_TAR
  if FLAGS.os_type == os_types.UBUNTU1804:
    return _BIONIC_TAR


def Install(vm: _VM) -> None:
  """Installs Redis Enterprise package on the VM."""
  vm.InstallPackages('wget')
  vm.RemoteCommand(f'mkdir -p {_WORKING_DIR}')

  # Check for the license in the data directory if a path isn't specified.
  license_path = _LICENSE_PATH.value
  if not license_path:
    license_path = data.ResourcePath(_LICENSE.value)
  vm.PushFile(license_path, posixpath.join(_WORKING_DIR, _LICENSE.value))

  # Check for the tarfile in the data directory first.
  vm.InstallPreprovisionedPackageData(
      _PACKAGE_NAME, [_GetTarName()], _WORKING_DIR
  )
  vm.RemoteCommand(
      'cd {dir} && sudo tar xvf {tar}'.format(
          dir=_WORKING_DIR, tar=_GetTarName()
      )
  )

  if FLAGS.os_type == os_types.UBUNTU1804:
    # Fix Ubuntu 18.04 DNS conflict
    vm.RemoteCommand(
        'echo "DNSStubListener=no" | sudo tee -a /etc/systemd/resolved.conf'
    )
    vm.RemoteCommand('sudo mv /etc/resolv.conf /etc/resolv.conf.orig')
    vm.RemoteCommand(
        'sudo ln -s /run/systemd/resolve/resolv.conf /etc/resolv.conf'
    )
    vm.RemoteCommand('sudo service systemd-resolved restart')

  install_cmd = './install.sh -y'
  if not _TUNE_ON_STARTUP.value:
    install_cmd = 'CONFIG_systune=no ./install.sh -y -n'
  vm.RemoteCommand(
      'cd {dir} && sudo {install}'.format(dir=_WORKING_DIR, install=install_cmd)
  )
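
# Illustrative helper (not part of the original module): a rough estimate of
# the dataset footprint implied by --enterprise_redis_load_records and
# --enterprise_redis_data_size_bytes, using the ~1.5x overhead ballpark from
# the _LOAD_RECORDS flag text above. The multiplier is a heuristic only.
def _EstimateDatasetBytes(records: int, data_size_bytes: int) -> int:
  """Approximates the memory needed to hold the pre-loaded keyspace."""
  overhead_ratio = 1.5  # Ballpark factor from the flag description.
  return int(records * data_size_bytes * overhead_ratio)


# Example: the defaults (1M records of 100 bytes) need roughly
# _EstimateDatasetBytes(1000000, 100) == 150000000 bytes (~150 MB).
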
def _JoinCluster(server_vm: _VM, vm: _VM) -> None:
  """Joins a Redis Enterprise cluster."""
  logging.info('Joining redis enterprise cluster.')
  vm.RemoteCommand(
      'sudo /opt/redislabs/bin/rladmin cluster join '
      'nodes {server_vm_ip} '
      'username {username} '
      'password {password} '.format(
          server_vm_ip=server_vm.internal_ip,
          username=_USERNAME,
          password=FLAGS.run_uri,
      )
  )


def CreateCluster(vms: List[_VM]) -> None:
  """Creates a Redis Enterprise cluster on the VM."""
  logging.info('Creating redis enterprise cluster.')
  vms[0].RemoteCommand(
      'sudo /opt/redislabs/bin/rladmin cluster create '
      'license_file {license_file} '
      'name redis-cluster '
      'username {username} '
      'password {password} '.format(
          license_file=posixpath.join(_WORKING_DIR, _LICENSE.value),
          username=_USERNAME,
          password=FLAGS.run_uri,
      )
  )
  for vm in vms[1:]:
    _JoinCluster(vms[0], vm)


def OfflineCores(vms: List[_VM]) -> None:
  """Offlines specific cores."""

  def _Offline(vm):
    for cpu_id in _DISABLE_CPU_IDS.value or []:
      vm.RemoteCommand(
          'sudo bash -c "echo 0 > /sys/devices/system/cpu/cpu%s/online"'
          % cpu_id
      )

  background_tasks.RunThreaded(_Offline, vms)


def TuneProxy(vm: _VM, proxy_threads: int | None = None) -> None:
  """Tunes the number of Redis proxies on the cluster."""
  proxy_threads = proxy_threads or _PROXY_THREADS.value
  vm.RemoteCommand(
      'sudo /opt/redislabs/bin/rladmin tune '
      'proxy all '
      f'max_threads {proxy_threads} '
      f'threads {proxy_threads} '
  )
  vm.RemoteCommand('sudo /opt/redislabs/bin/dmc_ctl restart')


def PinWorkers(vms: List[_VM], proxy_threads: int | None = None) -> None:
  """Splits the Redis worker threads across the NUMA nodes evenly.

  This function is a no-op if --enterprise_redis_pin_workers is not set.

  Args:
    vms: The VMs with the Redis workers to pin.
    proxy_threads: The number of proxy threads per VM.
  """
  if not _PIN_WORKERS.value:
    return

  proxy_threads = proxy_threads or _PROXY_THREADS.value

  def _Pin(vm):
    numa_nodes = vm.CheckLsCpu().numa_node_count
    proxies_per_node = proxy_threads // numa_nodes
    for node in range(numa_nodes):
      node_cpu_list = vm.RemoteCommand(
          'cat /sys/devices/system/node/node%d/cpulist' % node
      )[0].strip()
      # List the PIDs of the Redis worker processes and pin a sliding window
      # of `proxies_per_node` workers to the NUMA nodes in increasing order.
      vm.RemoteCommand(
          r'sudo /opt/redislabs/bin/dmc-cli -ts root list | '
          r'grep worker | '
          r'head -n -{proxies_already_partitioned} | '
          r'tail -n {proxies_per_node} | '
          r"awk '"
          r'{{printf "%i\n",$3}}'
          r"' | "
          r'xargs -i sudo taskset -pc {node_cpu_list} {{}} '.format(
              proxies_already_partitioned=proxies_per_node * node,
              proxies_per_node=proxies_per_node,
              node_cpu_list=node_cpu_list,
          )
      )

  background_tasks.RunThreaded(_Pin, vms)
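
# Worked example of the pinning window in _Pin (illustrative, assuming 8
# worker PIDs and 2 NUMA nodes, so proxies_per_node == 4):
#   node 0: head -n -0 | tail -n 4  -> the last 4 workers in the listing
#   node 1: head -n -4 | tail -n 4  -> the 4 workers before those
# Each node therefore gets a disjoint slice of 4 workers, pinned to that
# node's CPU list via taskset.
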
def GetDatabaseMemorySize(vm: _VM) -> int:
  """Gets the available memory (bytes) usable for provisioning databases."""
  output, _ = vm.RemoteCommand('sudo /opt/redislabs/bin/rladmin status')
  # See tests/data/redis_enterprise_cluster_output.txt
  node_output = output.splitlines()[2]
  provisional_ram = node_output.split()[7]
  size_gb = float(provisional_ram.split('/')[0].strip('GB')) * (
      _MEMORY_SIZE_PERCENTAGE.value
  )
  return int(size_gb * _ONE_GIGABYTE)


@dataclasses.dataclass(frozen=True)
class LoadRequest:
  key_minimum: int
  key_maximum: int
  redis_port: int
  cluster_mode: bool
  server_ip: str


def _BuildLoadCommand(request: LoadRequest) -> str:
  """Returns the command used to load the database."""
  command = (
      'sudo /opt/redislabs/bin/memtier_benchmark '
      f'-s {request.server_ip} '
      f'-a {FLAGS.run_uri} '
      f'-p {str(request.redis_port)} '
      '-t 1 '  # Set -t and -c to 1 to avoid duplicated work in writing the
      '-c 1 '  # same key/value pairs repeatedly.
      '--ratio 1:0 '
      '--pipeline 100 '
      f'-d {str(_DATA_SIZE.value)} '
      '--key-pattern S:S '
      f'--key-minimum {request.key_minimum} '
      f'--key-maximum {request.key_maximum} '
      '-n allkeys '
  )
  if request.cluster_mode:
    command += '--cluster-mode'
  return command


def _LoadDatabaseSingleVM(load_vm: _VM, request: LoadRequest) -> None:
  """Loads the DB from a single VM."""
  command = _BuildLoadCommand(request)
  logging.info('Loading database with %s', request)
  load_vm.RemoteCommand(command)


def LoadDatabases(
    redis_vms: List[_VM],
    load_vms: List[_VM],
    endpoints: List[Tuple[str, int]],
    shards: int | None = None,
) -> None:
  """Loads the databases before performing tests."""
  vms = load_vms + redis_vms
  load_requests = []
  load_records_per_vm = _LOAD_RECORDS.value // len(vms)
  shards = shards or _SHARDS.value
  cluster_mode = shards > 1
  for endpoint, port in endpoints:
    for i, _ in enumerate(vms):
      load_requests.append((
          vms[i],
          LoadRequest(
              key_minimum=max(i * load_records_per_vm, 1),
              key_maximum=(i + 1) * load_records_per_vm,
              redis_port=port,
              cluster_mode=cluster_mode,
              server_ip=endpoint,
          ),
      ))

  background_tasks.RunThreaded(
      _LoadDatabaseSingleVM, [(arg, {}) for arg in load_requests]
  )
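
# Worked example of the key partitioning in LoadDatabases (illustrative):
# with the default 1,000,000 records and, say, 4 VMs total (load + server),
# each VM loads a contiguous 250,000-key slice per endpoint:
#   vm 0: keys [1, 250000]  (the minimum is clamped to 1)
#   vm 1: keys [250000, 500000]
#   vm 2: keys [500000, 750000]
#   vm 3: keys [750000, 1000000]
# Adjacent slices share their boundary key, so that key is simply written
# twice during the load phase.
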
class HttpClient:
  """HTTP client for interacting with the Redis REST API."""

  def __init__(self, server_vms: List[_VM]):
    self.vms = server_vms
    self.api_base_url = f'https://{server_vms[0].ip_address}:9443'
    self.session = requests.Session()
    self.session.auth = (_USERNAME, FLAGS.run_uri)
    self.provisional_memory = 0

  def _LogCurlifiedCommand(self, response: requests.Response) -> None:
    """Logs the version of the request that can be run from curl."""
    request = response.request
    command = f'curl -v -k -u {_USERNAME}:{FLAGS.run_uri} -X {request.method} '
    if request.body:
      body = request.body.decode('UTF-8')
      command += f'-d "{body}" -H "Content-type: application/json" '
    command += f'{request.url}'
    logging.info(
        'Making API call equivalent to this curl command: %s', command
    )

  def GetDatabases(self) -> _UidToJsonDict:
    """Gets the database object(s) running in the cluster.

    Returns:
      A dictionary of database objects by uid, per
      https://docs.redis.com/latest/rs/references/rest-api/objects/bdb/.
    """
    logging.info('Getting Redis Enterprise databases.')
    r = self.session.get(f'{self.api_base_url}/v1/bdbs', verify=False)
    self._LogCurlifiedCommand(r)
    results = r.json()
    return {result['uid']: result for result in results}

  def GetEndpoints(self) -> List[Tuple[str, int]]:
    """Returns a list of (ip, port) tuples."""
    endpoints = []
    for db in self.GetDatabases().values():
      host = db['endpoints'][0]['addr'][0]
      port = db['endpoints'][0]['port']
      endpoints.append((host, port))
    logging.info('Database endpoints: %s', endpoints)
    return endpoints

  def GetDatabase(self, uid: int) -> _Json | None:
    """Returns the database JSON object corresponding to uid."""
    logging.info('Getting Redis Enterprise database (uid: %s).', uid)
    all_databases = self.GetDatabases()
    return all_databases.get(uid, None)

  @vm_util.Retry(
      poll_interval=5,
      retryable_exceptions=(errors.Resource.RetryableCreationError,),
  )
  def WaitForDatabaseUp(self, uid: int) -> None:
    """Waits for the Redis Enterprise database to become active."""
    db = self.GetDatabase(uid)
    if not db:
      raise errors.Benchmarks.RunError(
          f'Database {uid} does not exist, expected to be waiting for startup.'
      )
    if db['status'] != 'active':
      raise errors.Resource.RetryableCreationError()

  def CreateDatabase(self, shards: int | None = None) -> _Json:
    """Creates a new Redis Enterprise database.

    See https://docs.redis.com/latest/rs/references/rest-api/objects/bdb/.

    Args:
      shards: Number of shards for the database. In a clustered setup, shards
        will be distributed evenly across nodes.

    Returns:
      The JSON object corresponding to the database that was created.
    """
    db_shards = shards or _SHARDS.value
    if not self.provisional_memory:
      self.provisional_memory = GetDatabaseMemorySize(self.vms[0])
    per_db_memory_size = int(
        self.provisional_memory * len(self.vms) / _NUM_DATABASES.value
    )
    content = {
        'name': 'redisdb',
        'type': 'redis',
        'memory_size': per_db_memory_size,
        'proxy_policy': 'all-master-shards',
        'sharding': False,
        'authentication_redis_pass': FLAGS.run_uri,
        'replication': _REPLICATION.value,
    }
    if db_shards > 1:
      content.update({
          'sharding': True,
          'shards_count': db_shards,
          'shards_placement': 'sparse',
          'oss_cluster': True,
          'shard_key_regex': [
              {'regex': '.*\\{(?<tag>.*)\\}.*'},
              {'regex': '(?<tag>.*)'},
          ],
      })

    logging.info('Creating Redis Enterprise database with %s.', content)
    r = self.session.post(
        f'{self.api_base_url}/v1/bdbs', json=content, verify=False
    )
    self._LogCurlifiedCommand(r)
    if r.status_code != 200:
      raise errors.Benchmarks.RunError(
          f'Unable to create database: status code: {r.status_code}, '
          f'reason: {r.reason}.'
      )
    self.WaitForDatabaseUp(r.json()['uid'])
    logging.info('Finished creating Redis Enterprise database %s.', r.json())
    return r.json()
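
  # Illustrative example (an assumption based on _LogCurlifiedCommand's format
  # string above, with placeholder run_uri and IP): a failed CreateDatabase
  # call can be replayed by hand with something like
  #
  #   curl -v -k -u user@google.com:<run_uri> -X POST \
  #       -d "{...bdb json...}" -H "Content-type: application/json" \
  #       https://<server_ip>:9443/v1/bdbs
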
  def CreateDatabases(self, shards: int | None = None) -> None:
    """Creates all databases with the specified number of shards."""
    for _ in range(_NUM_DATABASES.value):
      self.CreateDatabase(shards)

  @vm_util.Retry(
      poll_interval=5,
      retryable_exceptions=(errors.Resource.RetryableDeletionError,),
  )
  def DeleteDatabase(self, uid: int) -> None:
    """Deletes the database from the cluster."""
    logging.info('Deleting Redis Enterprise database (uid: %s).', uid)
    r = self.session.delete(f'{self.api_base_url}/v1/bdbs/{uid}', verify=False)
    self._LogCurlifiedCommand(r)
    if self.GetDatabase(uid) is not None:
      logging.info('Waiting for DB to finish deleting.')
      raise errors.Resource.RetryableDeletionError()

  def DeleteDatabases(self) -> None:
    """Deletes all databases."""
    for uid in self.GetDatabases().keys():
      self.DeleteDatabase(uid)


def _BuildRunCommand(
    host: str, threads: int, port: int, shards: int | None = None
) -> str | None:
  """Builds a memtier_benchmark command to run on the load_vm against host:port.

  Args:
    host: The target IP of the memtier_benchmark.
    threads: The number of threads to run in this memtier_benchmark process.
    port: The port to target on the redis_vm.
    shards: The number of shards per database.

  Returns:
    The command to issue to the redis server, or None if threads is 0.
  """
  if threads == 0:
    return None

  result = (
      'sudo /opt/redislabs/bin/memtier_benchmark '
      f'-s {host} '
      f'-a {FLAGS.run_uri} '
      f'-p {str(port)} '
      f'-t {str(threads)} '
      '--ratio 1:1 '
      f'--pipeline {str(_PIPELINES.value)} '
      f'-c {str(_LOADGEN_CLIENTS.value)} '
      f'-d {str(_DATA_SIZE.value)} '
      '--test-time 30 '
      '--key-minimum 1 '
      f'--key-maximum {str(_LOAD_RECORDS.value)} '
  )
  shards = shards or _SHARDS.value
  if shards > 1:
    result += '--cluster-mode '
  return result


@dataclasses.dataclass(frozen=True)
class Result:
  """Individual throughput and latency result."""

  throughput: int
  latency_usec: int
  metadata: Dict[str, Any]

  def ToSample(self, metadata: Dict[str, Any]) -> sample.Sample:
    """Returns a throughput sample attached with the given metadata."""
    self.metadata.update(metadata)
    return sample.Sample('throughput', self.throughput, 'ops/s', self.metadata)


def ParseResults(output: str) -> List[Result]:
  """Parses the result from the cluster statistics API."""
  output_json = json.loads(output)
  results = []
  for interval in output_json.get('intervals'):
    results.append(
        Result(interval.get('total_req'), interval.get('avg_latency'), interval)
    )
  return results
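
# Illustrative sketch of the payload shape ParseResults expects, inferred from
# the .get() calls above rather than from the REST API docs:
#
#   ParseResults('{"intervals": [{"total_req": 850000, "avg_latency": 950}]}')
#
# would yield one Result with throughput=850000, latency_usec=950, and the
# whole interval dict as metadata.
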
def Run(
    redis_vms: List[_VM],
    load_vms: List[_VM],
    shards: int | None = None,
    proxy_threads: int | None = None,
    memtier_threads: int | None = None,
) -> _ThroughputSampleTuple:
  """Runs memtier against enterprise redis and measures latency and throughput.

  This function runs memtier against the redis server vm with increasing
  memtier threads until one of the following conditions is reached:
  - FLAGS.enterprise_redis_max_threads is reached
  - FLAGS.enterprise_redis_latency_threshold is reached

  Args:
    redis_vms: Redis server vms.
    load_vms: Memtier load vms.
    shards: The per-DB shard count for this run.
    proxy_threads: The per-VM proxy thread count for this run.
    memtier_threads: If provided, overrides --enterprise_redis_min_threads.

  Returns:
    A tuple of (max_throughput_under_1ms, list of sample.Sample objects).
  """
  # TODO(liubrandon): Break up this function as it's getting rather long.
  results = []
  cur_max_latency = 0.0
  latency_threshold = _LATENCY_THRESHOLD.value
  shards = shards or _SHARDS.value
  threads = memtier_threads or _MIN_THREADS.value
  max_threads = _MAX_THREADS.value
  max_throughput_for_completion_latency_under_1ms = 0.0
  client = HttpClient(redis_vms)
  endpoints = client.GetEndpoints()
  redis_vm = redis_vms[0]

  # Validate before running.
  if len(set(endpoints)) < _NUM_DATABASES.value:
    raise errors.Benchmarks.RunError(
        f'Wrong number of unique endpoints {endpoints}, '
        f'expected {_NUM_DATABASES.value}.'
    )
  if threads > max_threads:
    raise errors.Benchmarks.RunError(
        f'Min threads {threads} is higher than max threads {max_threads}; '
        'raise --enterprise_redis_max_threads.'
    )

  while cur_max_latency < latency_threshold and threads <= max_threads:
    # Set up run commands.
    run_cmds = [
        _BuildRunCommand(endpoint, threads, port, shards)
        for endpoint, port in endpoints
    ]
    args = [(arg, {}) for arg in itertools.product(load_vms, run_cmds)]
    # 30 sec for throughput to stabilize and 10 sec of data.
    measurement_command = (
        'sleep 25 && curl -v -k -u {user}:{password} '
        'https://localhost:9443/v1/cluster/stats?interval=1sec > ~/output'
        .format(
            user=_USERNAME,
            password=FLAGS.run_uri,
        )
    )
    args += [((redis_vm, measurement_command), {})]

    # Run.
    background_tasks.RunThreaded(
        lambda vm, command: vm.RemoteCommand(command), args
    )
    stdout, _ = redis_vm.RemoteCommand('cat ~/output')

    # Parse results and iterate.
    metadata = GetMetadata(shards, threads, proxy_threads)
    run_results = ParseResults(stdout)
    for result in run_results:
      results.append(result.ToSample(metadata))
      latency = result.latency_usec
      cur_max_latency = max(cur_max_latency, latency)
      if latency < 1000:
        max_throughput_for_completion_latency_under_1ms = max(
            max_throughput_for_completion_latency_under_1ms, result.throughput
        )

      logging.info(
          'Threads : %d (%f ops/sec, %f us latency) < %f us latency threshold',
          threads,
          result.throughput,
          latency,
          latency_threshold,
      )

    threads += _THREAD_INCREMENT.value

  if cur_max_latency >= 1000:
    results.append(
        sample.Sample(
            'max_throughput_for_completion_latency_under_1ms',
            max_throughput_for_completion_latency_under_1ms,
            'ops/s',
            metadata,
        )
    )

  logging.info(
      'Max throughput under 1ms: %s ops/sec.',
      max_throughput_for_completion_latency_under_1ms,
  )
  return max_throughput_for_completion_latency_under_1ms, results


def GetMetadata(
    shards: int, threads: int, proxy_threads: int
) -> Dict[str, Any]:
  """Returns metadata associated with the run.

  Args:
    shards: The shard count per database.
    threads: The thread count used by the memtier client.
    proxy_threads: The proxy thread count used on the redis cluster.

  Returns:
    A dictionary of metadata that can be attached to the run sample.
  """
  return {
      'redis_tune_on_startup': _TUNE_ON_STARTUP.value,
      'redis_pipeline': _PIPELINES.value,
      'threads': threads,
      'db_shard_count': shards,
      'total_shard_count': shards * _NUM_DATABASES.value,
      'db_count': _NUM_DATABASES.value,
      'redis_proxy_threads': proxy_threads or _PROXY_THREADS.value,
      'redis_loadgen_clients': _LOADGEN_CLIENTS.value,
      'pin_workers': _PIN_WORKERS.value,
      'disable_cpus': _DISABLE_CPU_IDS.value,
      'redis_enterprise_version': _VERSION,
      'memtier_data_size': _DATA_SIZE.value,
      'memtier_key_maximum': _LOAD_RECORDS.value,
      'replication': _REPLICATION.value,
  }
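
# Illustrative ramp-up under the default flags (min 18 threads, +1 per round,
# max 40, threshold 1100 usec): Run() measures at 18, 19, 20, ... threads and
# stops once a 1-second interval's average latency reaches 1100 usec or the
# thread cap is hit, recording the best throughput seen under 1000 usec.
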
class ThroughputOptimizer:
  """Searches for the shard/proxy thread combination with the best throughput.

  Attributes:
    client: Client that interacts with the Redis API.
    server_vms: List of server VMs.
    client_vms: List of client VMs.
    results: Matrix that records the search space of shards and proxy threads.
    min_threads: Keeps track of the optimal thread count used in the previous
      run.
  """

  def __init__(self, server_vms: List[_VM], client_vms: List[_VM]):
    self.server_vms: List[_VM] = server_vms
    self.client_vms: List[_VM] = client_vms
    self.min_threads: int = _MIN_THREADS.value
    self.client = HttpClient(server_vms)

    # Determines the search space for the optimization algorithm. We multiply
    # the size by 2, which should be a large enough search space.
    matrix_size = (
        max(
            server_vms[0].num_cpus,
            FLAGS.enterprise_redis_proxy_threads or 0,
            FLAGS.enterprise_redis_shard_count or 0,
        )
        * 2
    )
    self.results: _ThroughputSampleMatrix = [
        [() for i in range(matrix_size)] for i in range(matrix_size)
    ]

  def _CreateAndLoadDatabases(self, shards: int) -> None:
    """Creates and loads all the databases needed for the run."""
    self.client.CreateDatabases(shards)
    LoadDatabases(
        self.server_vms, self.client_vms, self.client.GetEndpoints(), shards
    )

  def _FullRun(
      self, shard_count: int, proxy_thread_count: int
  ) -> _ThroughputSampleTuple:
    """Recreates databases if needed, then runs the test."""
    logging.info(
        'Starting new run with %s shards, %s proxy threads',
        shard_count,
        proxy_thread_count,
    )
    server_vm = self.server_vms[0]

    # Recreate the DB if needed.
    dbs = self.client.GetDatabases()
    if not dbs:
      self._CreateAndLoadDatabases(shard_count)
    elif shard_count != list(dbs.values())[0]['shards_count']:
      self.client.DeleteDatabases()
      self._CreateAndLoadDatabases(shard_count)
    TuneProxy(server_vm, proxy_thread_count)
    PinWorkers(self.server_vms, proxy_thread_count)

    # Optimize the number of threads and run the test.
    results = Run(
        self.server_vms,
        self.client_vms,
        shard_count,
        proxy_thread_count,
        self.min_threads,
    )
    self.min_threads = max(
        self.min_threads,
        int(results[1][-1].metadata['threads'] * _THREAD_OPTIMIZATION_RATIO),
    )
    return results

  def _GetResult(
      self, shards: int, proxy_threads: int
  ) -> _ThroughputSampleTuple:
    if not self.results[shards - 1][proxy_threads - 1]:
      self.results[shards - 1][proxy_threads - 1] = self._FullRun(
          shards, proxy_threads
      )
    return self.results[shards - 1][proxy_threads - 1]

  def _GetNeighborsToCheck(
      self, shards: int, proxy_threads: int
  ) -> List[Tuple[int, int]]:
    """Returns the shards/proxy_threads neighbors to check."""
    vary_proxy_threads = [
        (shards, proxy_threads - 1),
        (shards, proxy_threads + 1),
    ]
    vary_shards = [(shards - 1, proxy_threads), (shards + 1, proxy_threads)]
    return vary_proxy_threads + vary_shards

  def _GetOptimalNeighbor(
      self, shards: int, proxy_threads: int
  ) -> Tuple[int, int]:
    """Returns the shards/proxy_threads neighbor with the best throughput."""
    optimal_shards = shards
    optimal_proxy_threads = proxy_threads
    optimal_throughput, _ = self._GetResult(shards, proxy_threads)

    for shards_count, proxy_threads_count in self._GetNeighborsToCheck(
        shards, proxy_threads
    ):
      if (
          shards_count < 1
          or shards_count > len(self.results)
          or proxy_threads_count < 1
          or proxy_threads_count > len(self.results)
      ):
        continue

      throughput, _ = self._GetResult(shards_count, proxy_threads_count)
      if throughput > optimal_throughput:
        optimal_throughput = throughput
        optimal_shards = shards_count
        optimal_proxy_threads = proxy_threads_count

    return optimal_shards, optimal_proxy_threads
  def DoGraphSearch(self) -> _ThroughputSampleTuple:
    """Performs a graph search for the optimal shard AND proxy thread count.

    Performs a graph search through shards and proxy threads. If the current
    combination is a local maximum, finish and return the result.

    The DB needs to be recreated since shards can only be resized in multiples
    once created.

    Returns:
      Tuple of (optimal_throughput, samples).
    """
    # Uses a heuristic for the number of shards and proxy threads per VM.
    # Usually the optimal number is somewhere close to this.
    num_cpus = self.server_vms[0].num_cpus
    shard_count = max(num_cpus // 5, 1)  # Per VM on 1 database
    proxy_thread_count = num_cpus - shard_count

    while True:
      logging.info(
          'Checking shards: %s, proxy_threads: %s',
          shard_count,
          proxy_thread_count,
      )
      optimal_shards, optimal_proxies = self._GetOptimalNeighbor(
          shard_count, proxy_thread_count
      )
      if (shard_count, proxy_thread_count) == (
          optimal_shards,
          optimal_proxies,
      ):
        break
      shard_count = optimal_shards
      proxy_thread_count = optimal_proxies

    return self._GetResult(shard_count, proxy_thread_count)

  def DoLinearSearch(self) -> _ThroughputSampleTuple:
    """Performs a linear search using either shards or proxy threads."""
    logging.info('Performing linear search through proxy threads OR shards.')
    max_throughput_tuple = (0, None)
    num_cpus = self.server_vms[0].num_cpus
    for i in range(1, num_cpus):
      if _SHARDS.value:
        result = self._FullRun(_SHARDS.value, i)
      else:
        result = self._FullRun(i, _PROXY_THREADS.value)
      if result[0] > max_throughput_tuple[0]:
        max_throughput_tuple = result
    return max_throughput_tuple

  def GetOptimalThroughput(self) -> _ThroughputSampleTuple:
    """Gets the optimal throughput for the Redis Enterprise cluster.

    Returns:
      Tuple of (optimal_throughput, samples).
    """
    # If exactly one of shards/proxy threads is fixed by a flag, do a linear
    # search over the other; otherwise do a full graph search.
    if (_SHARDS.value and not _PROXY_THREADS.value) or (
        _PROXY_THREADS.value and not _SHARDS.value
    ):
      return self.DoLinearSearch()
    return self.DoGraphSearch()
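
# Illustrative end-to-end flow (a sketch; assumes `server_vms`/`client_vms`
# are provisioned PKB VM lists and that a benchmark module drives these steps
# from its Prepare/Run phases):
#
#   for vm in server_vms:
#     Install(vm)
#   CreateCluster(server_vms)
#   OfflineCores(server_vms)
#   optimizer = ThroughputOptimizer(server_vms, client_vms)
#   best_throughput, samples = optimizer.GetOptimalThroughput()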