# Copyright 2019 PerfKitBenchmarker Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module containing redis enterprise installation and cleanup functions.
TODO(user): Flags should be unified with memtier.py.
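
Typical flow (a sketch; assumes the benchmark has already provisioned
redis_vms for the Redis Enterprise cluster and load_vms for memtier):

  for vm in redis_vms:
    Install(vm)
  CreateCluster(redis_vms)
  client = HttpClient(redis_vms)
  client.CreateDatabases()
  LoadDatabases(redis_vms, load_vms, client.GetEndpoints())
  throughput, samples = Run(redis_vms, load_vms)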
"""
import dataclasses
import itertools
import json
import logging
import posixpath
from typing import Any, Dict, List, Tuple
from absl import flags
from perfkitbenchmarker import background_tasks
from perfkitbenchmarker import data
from perfkitbenchmarker import errors
from perfkitbenchmarker import os_types
from perfkitbenchmarker import sample
from perfkitbenchmarker import virtual_machine
from perfkitbenchmarker import vm_util
import requests
FLAGS = flags.FLAGS
_LICENSE = flags.DEFINE_string(
'enterprise_redis_license_file',
'enterprise_redis_license',
'Name of the redis enterprise license file to use.',
)
_LICENSE_PATH = flags.DEFINE_string(
'enterprise_redis_license_path',
None,
    'If unset, defaults to the local data directory joined with'
    ' --enterprise_redis_license_file.',
)
_TUNE_ON_STARTUP = flags.DEFINE_boolean(
'enterprise_redis_tune_on_startup',
True,
'Whether to tune core config during startup.',
)
_PROXY_THREADS = flags.DEFINE_integer(
'enterprise_redis_proxy_threads',
None,
'Number of redis proxy threads to use.',
)
_SHARDS = flags.DEFINE_integer(
'enterprise_redis_shard_count',
None,
'Number of redis shards per database. Each shard is a redis thread.',
)
_LOAD_RECORDS = flags.DEFINE_integer(
'enterprise_redis_load_records',
1000000,
    'Number of keys to pre-load into Redis. Use'
    ' --enterprise_redis_data_size_bytes to calculate how much space will be'
    ' used. Due to overhead, a ballpark estimate is that a 1KB record takes'
    ' 1.5KB of memory. See'
    ' https://docs.redis.com/latest/rs/concepts/memory-performance/ for more'
    ' info.',
)
_RUN_RECORDS = flags.DEFINE_integer(
'enterprise_redis_run_records',
1000000,
'Number of requests per loadgen client to send to the Redis server.',
)
_PIPELINES = flags.DEFINE_integer(
'enterprise_redis_pipeline', 9, 'Number of pipelines to use.'
)
_LOADGEN_CLIENTS = flags.DEFINE_integer(
'enterprise_redis_loadgen_clients', 24, 'Number of clients per loadgen vm.'
)
_MAX_THREADS = flags.DEFINE_integer(
'enterprise_redis_max_threads',
40,
'Maximum number of memtier threads to use.',
)
_MIN_THREADS = flags.DEFINE_integer(
'enterprise_redis_min_threads',
18,
'Minimum number of memtier threads to use.',
)
_THREAD_INCREMENT = flags.DEFINE_integer(
'enterprise_redis_thread_increment',
1,
'Number of memtier threads to increment by.',
)
_LATENCY_THRESHOLD = flags.DEFINE_integer(
'enterprise_redis_latency_threshold',
1100,
'The latency threshold in microseconds until the test stops.',
)
_PIN_WORKERS = flags.DEFINE_boolean(
'enterprise_redis_pin_workers',
False,
'Whether to pin the proxy threads after startup.',
)
_DISABLE_CPU_IDS = flags.DEFINE_list(
'enterprise_redis_disable_cpu_ids', None, 'List of cpus to disable by id.'
)
_DATA_SIZE = flags.DEFINE_integer(
'enterprise_redis_data_size_bytes',
100,
'The size of the data to write to redis enterprise.',
)
_NUM_DATABASES = flags.DEFINE_integer(
'enterprise_redis_db_count',
1,
'The number of databases to create on the cluster.',
)
_REPLICATION = flags.DEFINE_bool(
'enterprise_redis_db_replication',
False,
'If true, replicates each database to another node. Doubles the amount of '
'memory used by the database records.',
)
_MEMORY_SIZE_PERCENTAGE = flags.DEFINE_float(
'enterprise_redis_memory_size_percentage',
0.80,
'The percentage amount of memory to use out of all the available memory '
'reported by rladmin for provisioning databases. 1 means use all available '
'memory, which in practice tends to be error-prone.',
)
_VM = virtual_machine.VirtualMachine
_ThroughputSampleTuple = Tuple[float, List[sample.Sample]]
_ThroughputSampleMatrix = List[List[_ThroughputSampleTuple]]
_Json = Dict[str, Any]
_UidToJsonDict = Dict[int, _Json]
_VERSION = '6.2.4-54'
_PACKAGE_NAME = 'redis_enterprise'
_WORKING_DIR = '~/redislabs'
_RHEL_TAR = f'redislabs-{_VERSION}-rhel7-x86_64.tar'
_XENIAL_TAR = f'redislabs-{_VERSION}-xenial-amd64.tar'
_BIONIC_TAR = f'redislabs-{_VERSION}-bionic-amd64.tar'
_USERNAME = 'user@google.com'
_ONE_KILOBYTE = 1000
_ONE_MEGABYTE = _ONE_KILOBYTE * 1000
_ONE_GIGABYTE = _ONE_MEGABYTE * 1000
PREPROVISIONED_DATA = {
# These checksums correspond to version 6.2.4-54. To update, run
# 'sha256sum <redislabs-{VERSION}-rhel7-x86_64.tar>' and replace the values
# below.
_RHEL_TAR: (
'fb0b7aa5f115eb0bc2ac4fb958aaa7ad92bb260f2251a221a15b01fbdf4d2d14'
),
_XENIAL_TAR: (
'f78a6bb486f3dfb3e5ba9b5be86b1c880d0c76a08eb0dc4bd3aaaf9cc210406d'
),
_BIONIC_TAR: (
'dfe568958b243368c1f1c08c9cce9f660fa06e1bce38fa88f90503e344466927'
),
}
_THREAD_OPTIMIZATION_RATIO = 0.75
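# Fraction of the previous run's best memtier thread count used as the
# starting thread count for the next run (see ThroughputOptimizer._FullRun).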
def _GetTarName() -> str | None:
"""Returns the Redis Enterprise package to use depending on the os.
For information about available packages, see
https://redislabs.com/redis-enterprise/software/downloads/.
"""
if FLAGS.os_type in [os_types.RHEL, os_types.AMAZONLINUX2, os_types.CENTOS7]:
return _RHEL_TAR
if FLAGS.os_type in [os_types.UBUNTU1604, os_types.DEBIAN, os_types.DEBIAN9]:
return _XENIAL_TAR
if FLAGS.os_type == os_types.UBUNTU1804:
return _BIONIC_TAR
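  # No package is available for this OS type.
  return None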
def Install(vm: _VM) -> None:
"""Installs Redis Enterprise package on the VM."""
vm.InstallPackages('wget')
vm.RemoteCommand(f'mkdir -p {_WORKING_DIR}')
# Check for the license in the data directory if a path isn't specified.
license_path = _LICENSE_PATH.value
if not license_path:
license_path = data.ResourcePath(_LICENSE.value)
vm.PushFile(license_path, posixpath.join(_WORKING_DIR, _LICENSE.value))
# Check for the tarfile in the data directory first.
vm.InstallPreprovisionedPackageData(
_PACKAGE_NAME, [_GetTarName()], _WORKING_DIR
)
vm.RemoteCommand(
'cd {dir} && sudo tar xvf {tar}'.format(
dir=_WORKING_DIR, tar=_GetTarName()
)
)
if FLAGS.os_type == os_types.UBUNTU1804:
# Fix Ubuntu 18.04 DNS conflict
vm.RemoteCommand(
'echo "DNSStubListener=no" | sudo tee -a /etc/systemd/resolved.conf'
)
vm.RemoteCommand('sudo mv /etc/resolv.conf /etc/resolv.conf.orig')
vm.RemoteCommand(
'sudo ln -s /run/systemd/resolve/resolv.conf /etc/resolv.conf'
)
vm.RemoteCommand('sudo service systemd-resolved restart')
install_cmd = './install.sh -y'
if not _TUNE_ON_STARTUP.value:
install_cmd = 'CONFIG_systune=no ./install.sh -y -n'
vm.RemoteCommand(
'cd {dir} && sudo {install}'.format(dir=_WORKING_DIR, install=install_cmd)
)
def _JoinCluster(server_vm: _VM, vm: _VM) -> None:
"""Joins a Redis Enterprise cluster."""
logging.info('Joining redis enterprise cluster.')
vm.RemoteCommand(
'sudo /opt/redislabs/bin/rladmin cluster join '
'nodes {server_vm_ip} '
'username {username} '
'password {password} '.format(
server_vm_ip=server_vm.internal_ip,
username=_USERNAME,
password=FLAGS.run_uri,
)
)
def CreateCluster(vms: List[_VM]) -> None:
"""Creates a Redis Enterprise cluster on the VM."""
logging.info('Creating redis enterprise cluster.')
vms[0].RemoteCommand(
'sudo /opt/redislabs/bin/rladmin cluster create '
'license_file {license_file} '
'name redis-cluster '
'username {username} '
'password {password} '.format(
license_file=posixpath.join(_WORKING_DIR, _LICENSE.value),
username=_USERNAME,
password=FLAGS.run_uri,
)
)
for vm in vms[1:]:
_JoinCluster(vms[0], vm)
def OfflineCores(vms: List[_VM]) -> None:
"""Offlines specific cores."""
def _Offline(vm):
for cpu_id in _DISABLE_CPU_IDS.value or []:
vm.RemoteCommand(
'sudo bash -c "echo 0 > /sys/devices/system/cpu/cpu%s/online"'
% cpu_id
)
background_tasks.RunThreaded(_Offline, vms)
def TuneProxy(vm: _VM, proxy_threads: int | None = None) -> None:
"""Tunes the number of Redis proxies on the cluster."""
proxy_threads = proxy_threads or _PROXY_THREADS.value
vm.RemoteCommand(
'sudo /opt/redislabs/bin/rladmin tune '
'proxy all '
f'max_threads {proxy_threads} '
f'threads {proxy_threads} '
)
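  # Restart the proxies so the new thread counts take effect.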
vm.RemoteCommand('sudo /opt/redislabs/bin/dmc_ctl restart')
def PinWorkers(vms: List[_VM], proxy_threads: int | None = None) -> None:
"""Splits the Redis worker threads across the NUMA nodes evenly.
  This function is a no-op if --enterprise_redis_pin_workers is not set.
Args:
vms: The VMs with the Redis workers to pin.
proxy_threads: The number of proxy threads per VM.
"""
if not _PIN_WORKERS.value:
return
proxy_threads = proxy_threads or _PROXY_THREADS.value
def _Pin(vm):
numa_nodes = vm.CheckLsCpu().numa_node_count
proxies_per_node = proxy_threads // numa_nodes
for node in range(numa_nodes):
node_cpu_list = vm.RemoteCommand(
'cat /sys/devices/system/node/node%d/cpulist' % node
)[0].strip()
# List the PIDs of the Redis worker processes and pin a sliding window of
# `proxies_per_node` workers to the NUMA nodes in increasing order.
vm.RemoteCommand(
r'sudo /opt/redislabs/bin/dmc-cli -ts root list | '
r'grep worker | '
r'head -n -{proxies_already_partitioned} | '
r'tail -n {proxies_per_node} | '
r"awk '"
r'{{printf "%i\n",$3}}'
r"' | "
r'xargs -i sudo taskset -pc {node_cpu_list} {{}} '.format(
proxies_already_partitioned=proxies_per_node * node,
proxies_per_node=proxies_per_node,
node_cpu_list=node_cpu_list,
)
)
background_tasks.RunThreaded(_Pin, vms)
def GetDatabaseMemorySize(vm: _VM) -> int:
"""Gets the available memory (bytes) that can be used to provision databases."""
output, _ = vm.RemoteCommand('sudo /opt/redislabs/bin/rladmin status')
# See tests/data/redis_enterprise_cluster_output.txt
node_output = output.splitlines()[2]
provisional_ram = node_output.split()[7]
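  # provisional_ram looks like e.g. '25.1GB/28.8GB'; use the first figure.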
size_gb = float(provisional_ram.split('/')[0].strip('GB')) * (
_MEMORY_SIZE_PERCENTAGE.value
)
return int(size_gb * _ONE_GIGABYTE)
@dataclasses.dataclass(frozen=True)
class LoadRequest:
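  """Parameters for one memtier load command against a database endpoint."""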
key_minimum: int
key_maximum: int
redis_port: int
cluster_mode: bool
server_ip: str
def _BuildLoadCommand(request: LoadRequest) -> str:
"""Returns the command used to load the database."""
command = (
'sudo /opt/redislabs/bin/memtier_benchmark '
f'-s {request.server_ip} '
f'-a {FLAGS.run_uri} '
      f'-p {request.redis_port} '
'-t 1 ' # Set -t and -c to 1 to avoid duplicated work in writing the same
'-c 1 ' # key/value pairs repeatedly.
'--ratio 1:0 '
'--pipeline 100 '
      f'-d {_DATA_SIZE.value} '
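      # 'S:S' writes keys sequentially so each loader covers its own range.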
'--key-pattern S:S '
f'--key-minimum {request.key_minimum} '
f'--key-maximum {request.key_maximum} '
'-n allkeys '
)
if request.cluster_mode:
command += '--cluster-mode'
return command
def _LoadDatabaseSingleVM(load_vm: _VM, request: LoadRequest) -> None:
"""Loads the DB from a single VM."""
command = _BuildLoadCommand(request)
logging.info('Loading database with %s', request)
load_vm.RemoteCommand(command)
def LoadDatabases(
redis_vms: List[_VM],
load_vms: List[_VM],
endpoints: List[Tuple[str, int]],
shards: int | None = None,
) -> None:
"""Loads the databases before performing tests."""
vms = load_vms + redis_vms
load_requests = []
load_records_per_vm = _LOAD_RECORDS.value // len(vms)
shards = shards or _SHARDS.value
cluster_mode = shards > 1
for endpoint, port in endpoints:
for i, _ in enumerate(vms):
load_requests.append((
vms[i],
LoadRequest(
key_minimum=max(i * load_records_per_vm, 1),
key_maximum=(i + 1) * load_records_per_vm,
redis_port=port,
cluster_mode=cluster_mode,
server_ip=endpoint,
),
))
background_tasks.RunThreaded(
_LoadDatabaseSingleVM, [(arg, {}) for arg in load_requests]
)
class HttpClient:
"""HTTP Client for interacting with Redis REST API."""
def __init__(self, server_vms: List[_VM]):
self.vms = server_vms
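    # The Redis Enterprise REST API listens on port 9443.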
self.api_base_url = f'https://{server_vms[0].ip_address}:9443'
self.session = requests.Session()
self.session.auth = (_USERNAME, FLAGS.run_uri)
self.provisional_memory = 0
def _LogCurlifiedCommand(self, response: requests.Response) -> None:
"""Logs the version of the request that can be run from curl."""
request = response.request
command = f'curl -v -k -u {_USERNAME}:{FLAGS.run_uri} -X {request.method} '
if request.body:
body = request.body.decode('UTF-8')
command += f'-d "{body}" -H "Content-type: application/json" '
command += f'{request.url}'
logging.info('Making API call equivalent to this curl command: %s', command)
  def GetDatabases(self) -> _UidToJsonDict:
"""Gets the database object(s) running in the cluster.
Returns:
A dictionary of database objects by uid, per
https://docs.redis.com/latest/rs/references/rest-api/objects/bdb/.
"""
logging.info('Getting Redis Enterprise databases.')
r = self.session.get(f'{self.api_base_url}/v1/bdbs', verify=False)
self._LogCurlifiedCommand(r)
results = r.json()
return {result['uid']: result for result in results}
def GetEndpoints(self) -> List[Tuple[str, int]]:
"""Returns a list of (ip, port) tuples."""
endpoints = []
for db in self.GetDatabases().values():
host = db['endpoints'][0]['addr'][0]
port = db['endpoints'][0]['port']
endpoints.append((host, port))
logging.info('Database endpoints: %s', endpoints)
return endpoints
def GetDatabase(self, uid: int) -> _Json | None:
"""Returns the database JSON object corresponding to uid."""
logging.info('Getting Redis Enterprise database (uid: %s).', uid)
all_databases = self.GetDatabases()
return all_databases.get(uid, None)
@vm_util.Retry(
poll_interval=5,
retryable_exceptions=(errors.Resource.RetryableCreationError,),
)
def WaitForDatabaseUp(self, uid: int) -> None:
"""Waits for the Redis Enterprise database to become active."""
db = self.GetDatabase(uid)
if not db:
      raise errors.Benchmarks.RunError(
          f'Database {uid} does not exist, expected to be waiting for startup.'
      )
if db['status'] != 'active':
raise errors.Resource.RetryableCreationError()
def CreateDatabase(self, shards: int | None = None) -> _Json:
"""Creates a new Redis Enterprise database.
See https://docs.redis.com/latest/rs/references/rest-api/objects/bdb/.
Args:
shards: Number of shards for the database. In a clustered setup, shards
will be distributed evenly across nodes.
Returns:
Returns the JSON object corresponding to the database that was created.
"""
db_shards = shards or _SHARDS.value
if not self.provisional_memory:
self.provisional_memory = GetDatabaseMemorySize(self.vms[0])
per_db_memory_size = int(
self.provisional_memory * len(self.vms) / _NUM_DATABASES.value
)
content = {
'name': 'redisdb',
'type': 'redis',
'memory_size': per_db_memory_size,
'proxy_policy': 'all-master-shards',
'sharding': False,
'authentication_redis_pass': FLAGS.run_uri,
'replication': _REPLICATION.value,
}
if db_shards > 1:
content.update({
'sharding': True,
'shards_count': db_shards,
'shards_placement': 'sparse',
'oss_cluster': True,
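          # These are the standard Redis Enterprise hashing rules: hash on
          # the '{tag}' portion of the key if present, else the whole key.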
'shard_key_regex': [
{'regex': '.*\\{(?<tag>.*)\\}.*'},
{'regex': '(?<tag>.*)'}],
}) # pyformat: disable
logging.info('Creating Redis Enterprise database with %s.', content)
r = self.session.post(
f'{self.api_base_url}/v1/bdbs', json=content, verify=False
)
self._LogCurlifiedCommand(r)
if r.status_code != 200:
raise errors.Benchmarks.RunError(
f'Unable to create database: status code: {r.status_code}, '
f'reason {r.reason}.'
)
self.WaitForDatabaseUp(r.json()['uid'])
logging.info('Finished creating Redis Enterprise database %s.', r.json())
return r.json()
def CreateDatabases(self, shards: int | None = None) -> None:
"""Creates all databases with the specified number of shards."""
for _ in range(_NUM_DATABASES.value):
self.CreateDatabase(shards)
@vm_util.Retry(
poll_interval=5,
retryable_exceptions=(errors.Resource.RetryableDeletionError,),
)
def DeleteDatabase(self, uid: int) -> None:
"""Deletes the database from the cluster."""
logging.info('Deleting Redis Enterprise database (uid: %s).', uid)
r = self.session.delete(f'{self.api_base_url}/v1/bdbs/{uid}', verify=False)
self._LogCurlifiedCommand(r)
if self.GetDatabase(uid) is not None:
logging.info('Waiting for DB to finish deleting.')
raise errors.Resource.RetryableDeletionError()
def DeleteDatabases(self) -> None:
"""Deletes all databases."""
for uid in self.GetDatabases().keys():
self.DeleteDatabase(uid)
def _BuildRunCommand(
host: str, threads: int, port: int, shards: int | None = None
) -> str | None:
  """Builds a memtier_benchmark command to run against the given host:port.
Args:
host: The target IP of the memtier_benchmark
threads: The number of threads to run in this memtier_benchmark process.
port: The port to target on the redis_vm.
shards: The number of shards per database.
Returns:
    Command for a load VM to issue against the redis server, or None if
    threads is 0.
"""
if threads == 0:
return None
result = (
'sudo /opt/redislabs/bin/memtier_benchmark '
f'-s {host} '
f'-a {FLAGS.run_uri} '
      f'-p {port} '
      f'-t {threads} '
'--ratio 1:1 '
      f'--pipeline {_PIPELINES.value} '
      f'-c {_LOADGEN_CLIENTS.value} '
      f'-d {_DATA_SIZE.value} '
'--test-time 30 '
'--key-minimum 1 '
      f'--key-maximum {_LOAD_RECORDS.value} '
)
shards = shards or _SHARDS.value
if shards > 1:
result += '--cluster-mode '
return result
@dataclasses.dataclass(frozen=True)
class Result:
"""Individual throughput and latency result."""
throughput: int
latency_usec: int
metadata: Dict[str, Any]
def ToSample(self, metadata: Dict[str, Any]) -> sample.Sample:
"""Returns throughput sample attached with the given metadata."""
self.metadata.update(metadata)
return sample.Sample('throughput', self.throughput, 'ops/s', self.metadata)
def ParseResults(output: str) -> List[Result]:
"""Parses the result from the cluster statistics API."""
output_json = json.loads(output)
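  # The stats API returns a list of per-interval snapshots; total_req is the
  # cluster-wide throughput (ops/s) and avg_latency is in microseconds.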
results = []
  for interval in output_json.get('intervals', []):
results.append(
Result(interval.get('total_req'), interval.get('avg_latency'), interval)
)
return results
def Run(
redis_vms: List[_VM],
load_vms: List[_VM],
shards: int | None = None,
proxy_threads: int | None = None,
memtier_threads: int | None = None,
) -> _ThroughputSampleTuple:
"""Run memtier against enterprise redis and measure latency and throughput.
This function runs memtier against the redis server vm with increasing memtier
  threads until one of the following conditions is reached:
- FLAGS.enterprise_redis_max_threads is reached
- FLAGS.enterprise_redis_latency_threshold is reached
Args:
redis_vms: Redis server vms.
load_vms: Memtier load vms.
shards: The per-DB shard count for this run.
proxy_threads: The per-VM proxy thread count for this run.
memtier_threads: If provided, overrides --enterprise_redis_min_threads.
Returns:
A tuple of (max_throughput_under_1ms, list of sample.Sample objects).
"""
# TODO(liubrandon): Break up this function as it's getting rather long.
results = []
cur_max_latency = 0.0
latency_threshold = _LATENCY_THRESHOLD.value
shards = shards or _SHARDS.value
threads = memtier_threads or _MIN_THREADS.value
max_threads = _MAX_THREADS.value
max_throughput_for_completion_latency_under_1ms = 0.0
client = HttpClient(redis_vms)
endpoints = client.GetEndpoints()
redis_vm = redis_vms[0]
# Validate before running
if len(set(endpoints)) < _NUM_DATABASES.value:
raise errors.Benchmarks.RunError(
f'Wrong number of unique endpoints {endpoints}, '
f'expected {_NUM_DATABASES.value}.'
)
if threads > max_threads:
    raise errors.Benchmarks.RunError(
        f'min threads {threads} higher than max threads {max_threads}, '
        'raise --enterprise_redis_max_threads'
    )
while cur_max_latency < latency_threshold and threads <= max_threads:
# Set up run commands
run_cmds = [
_BuildRunCommand(endpoint, threads, port, shards)
for endpoint, port in endpoints
]
args = [(arg, {}) for arg in itertools.product(load_vms, run_cmds)]
    # Sleep through most of the 30-second memtier run so throughput can
    # stabilize, then sample the per-second cluster stats.
measurement_command = (
'sleep 25 && curl -v -k -u {user}:{password} '
'https://localhost:9443/v1/cluster/stats?interval=1sec > ~/output'
.format(
user=_USERNAME,
password=FLAGS.run_uri,
)
)
args += [((redis_vm, measurement_command), {})]
# Run
background_tasks.RunThreaded(
lambda vm, command: vm.RemoteCommand(command), args
)
stdout, _ = redis_vm.RemoteCommand('cat ~/output')
# Parse results and iterate
metadata = GetMetadata(shards, threads, proxy_threads)
run_results = ParseResults(stdout)
for result in run_results:
results.append(result.ToSample(metadata))
latency = result.latency_usec
cur_max_latency = max(cur_max_latency, latency)
if latency < 1000:
max_throughput_for_completion_latency_under_1ms = max(
max_throughput_for_completion_latency_under_1ms, result.throughput
)
logging.info(
        'Threads : %d (%f ops/sec, %f usec latency) < %f usec latency',
threads,
result.throughput,
latency,
latency_threshold,
)
threads += _THREAD_INCREMENT.value
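  # If latency crossed 1 ms at some point, report the best throughput that
  # was achieved while average latency was still under 1 ms (1000 usec).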
if cur_max_latency >= 1000:
results.append(
sample.Sample(
'max_throughput_for_completion_latency_under_1ms',
max_throughput_for_completion_latency_under_1ms,
'ops/s',
metadata,
)
)
logging.info(
'Max throughput under 1ms: %s ops/sec.',
max_throughput_for_completion_latency_under_1ms,
)
return max_throughput_for_completion_latency_under_1ms, results
def GetMetadata(
    shards: int, threads: int, proxy_threads: int | None
) -> Dict[str, Any]:
"""Returns metadata associated with the run.
Args:
shards: The shard count per database.
threads: The thread count used by the memtier client.
proxy_threads: The proxy thread count used on the redis cluster.
Returns:
A dictionary of metadata that can be attached to the run sample.
"""
return {
'redis_tune_on_startup': _TUNE_ON_STARTUP.value,
'redis_pipeline': _PIPELINES.value,
'threads': threads,
'db_shard_count': shards,
'total_shard_count': shards * _NUM_DATABASES.value,
'db_count': _NUM_DATABASES.value,
'redis_proxy_threads': proxy_threads or _PROXY_THREADS.value,
'redis_loadgen_clients': _LOADGEN_CLIENTS.value,
'pin_workers': _PIN_WORKERS.value,
'disable_cpus': _DISABLE_CPU_IDS.value,
'redis_enterprise_version': _VERSION,
'memtier_data_size': _DATA_SIZE.value,
'memtier_key_maximum': _LOAD_RECORDS.value,
'replication': _REPLICATION.value,
}
class ThroughputOptimizer:
"""Class that searches for the shard/proxy_thread count for best throughput.
Attributes:
client: Client that interacts with the Redis API.
server_vms: List of server VMs.
client_vms: List of client VMs.
results: Matrix that records the search space of shards and proxy threads.
    min_threads: Keeps track of the optimal thread count used in the
      previous run.
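
  Example (a sketch; assumes provisioned server and client VMs):
    optimizer = ThroughputOptimizer(server_vms, client_vms)
    throughput, samples = optimizer.GetOptimalThroughput()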
"""
def __init__(self, server_vms: List[_VM], client_vms: List[_VM]):
self.server_vms: List[_VM] = server_vms
self.client_vms: List[_VM] = client_vms
self.min_threads: int = _MIN_THREADS.value
self.client = HttpClient(server_vms)
# Determines the search space for the optimization algorithm. We multiply
# the size by 2 which should be a large enough search space.
matrix_size = (
max(
server_vms[0].num_cpus,
FLAGS.enterprise_redis_proxy_threads or 0,
FLAGS.enterprise_redis_shard_count or 0,
)
* 2
)
self.results: _ThroughputSampleMatrix = [
        [() for _ in range(matrix_size)] for _ in range(matrix_size)
]
def _CreateAndLoadDatabases(self, shards: int) -> None:
"""Creates and loads all the databases needed for the run."""
self.client.CreateDatabases(shards)
LoadDatabases(
self.server_vms, self.client_vms, self.client.GetEndpoints(), shards
)
def _FullRun(
self, shard_count: int, proxy_thread_count: int
) -> _ThroughputSampleTuple:
"""Recreates databases if needed, then runs the test."""
logging.info(
'Starting new run with %s shards, %s proxy threads',
shard_count,
proxy_thread_count,
)
server_vm = self.server_vms[0]
# Recreate the DB if needed
dbs = self.client.GetDatabases()
if not dbs:
self._CreateAndLoadDatabases(shard_count)
elif shard_count != list(dbs.values())[0]['shards_count']:
self.client.DeleteDatabases()
self._CreateAndLoadDatabases(shard_count)
TuneProxy(server_vm, proxy_thread_count)
PinWorkers(self.server_vms, proxy_thread_count)
# Optimize the number of threads and run the test
results = Run(self.server_vms,
self.client_vms,
shard_count,
proxy_thread_count,
self.min_threads) # pyformat: disable
self.min_threads = max(
self.min_threads,
int(results[1][-1].metadata['threads'] * _THREAD_OPTIMIZATION_RATIO),
)
return results
def _GetResult(
self, shards: int, proxy_threads: int
) -> _ThroughputSampleTuple:
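    """Memoizes runs: each (shards, proxy_threads) point is measured once."""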
if not self.results[shards - 1][proxy_threads - 1]:
self.results[shards - 1][proxy_threads - 1] = self._FullRun(
shards, proxy_threads
)
return self.results[shards - 1][proxy_threads - 1]
def _GetNeighborsToCheck(
self, shards: int, proxy_threads: int
  ) -> List[Tuple[int, int]]:
    """Returns the shards/proxy_threads neighbors to check."""
vary_proxy_threads = [
(shards, proxy_threads - 1),
(shards, proxy_threads + 1),
]
vary_shards = [(shards - 1, proxy_threads), (shards + 1, proxy_threads)]
return vary_proxy_threads + vary_shards
def _GetOptimalNeighbor(
self, shards: int, proxy_threads: int
) -> Tuple[int, int]:
"""Returns the shards/proxy_threads neighbor with the best throughput."""
optimal_shards = shards
optimal_proxy_threads = proxy_threads
optimal_throughput, _ = self._GetResult(shards, proxy_threads)
for shards_count, proxy_threads_count in self._GetNeighborsToCheck(
shards, proxy_threads
):
if (
shards_count < 1
or shards_count > len(self.results)
or proxy_threads_count < 1
or proxy_threads_count > len(self.results)
):
continue
throughput, _ = self._GetResult(shards_count, proxy_threads_count)
if throughput > optimal_throughput:
optimal_throughput = throughput
optimal_shards = shards_count
optimal_proxy_threads = proxy_threads_count
return optimal_shards, optimal_proxy_threads
  def DoGraphSearch(self) -> _ThroughputSampleTuple:
    """Performs a graph search for the optimal shard AND proxy thread count.
Performs a graph search through shards and proxy threads. If the current
combination is a local maximum, finish and return the result.
The DB needs to be recreated since shards can only be resized in multiples
once created.
Returns:
Tuple of (optimal_throughput, samples).
"""
    # Uses a heuristic for the number of shards and proxy threads per VM;
    # usually the optimal numbers are somewhere close to this.
num_cpus = self.server_vms[0].num_cpus
shard_count = max(num_cpus // 5, 1) # Per VM on 1 database
proxy_thread_count = num_cpus - shard_count
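    # e.g. on a 30-vCPU VM the search starts at 6 shards, 24 proxy threads.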
while True:
logging.info(
'Checking shards: %s, proxy_threads: %s',
shard_count,
proxy_thread_count,
)
optimal_shards, optimal_proxies = self._GetOptimalNeighbor(
shard_count, proxy_thread_count
)
if (shard_count, proxy_thread_count) == (optimal_shards, optimal_proxies):
break
shard_count = optimal_shards
proxy_thread_count = optimal_proxies
return self._GetResult(shard_count, proxy_thread_count)
def DoLinearSearch(self) -> _ThroughputSampleTuple:
"""Performs a linear search using either shards or proxy threads."""
logging.info('Performing linear search through proxy threads OR shards.')
max_throughput_tuple = (0, None)
num_cpus = self.server_vms[0].num_cpus
for i in range(1, num_cpus):
if _SHARDS.value:
result = self._FullRun(_SHARDS.value, i)
else:
result = self._FullRun(i, _PROXY_THREADS.value)
if result[0] > max_throughput_tuple[0]:
max_throughput_tuple = result
return max_throughput_tuple
def GetOptimalThroughput(self) -> _ThroughputSampleTuple:
"""Gets the optimal throughput for the Redis Enterprise cluster.
Returns:
Tuple of (optimal_throughput, samples).
"""
    # If exactly one of shards/proxy threads is fixed by flags, linearly
    # search over the other.
if (_SHARDS.value and not _PROXY_THREADS.value) or (
_PROXY_THREADS.value and not _SHARDS.value
):
return self.DoLinearSearch()
return self.DoGraphSearch()