# esrally/driver/driver.py
# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import asyncio
import collections
import concurrent.futures
import datetime
import itertools
import logging
import math
import multiprocessing
import queue
import sys
import threading
import time
from dataclasses import dataclass
from enum import Enum
from io import BytesIO
from typing import Callable, Optional
import thespian.actors
from esrally import (
PROGRAM_NAME,
actor,
client,
config,
exceptions,
metrics,
paths,
telemetry,
track,
types,
)
from esrally.client import delete_api_keys
from esrally.driver import runner, scheduler
from esrally.track import TrackProcessorRegistry, load_track, load_track_plugins
from esrally.utils import console, convert, net
##################################
#
# Messages sent between drivers
#
##################################
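#
# Roughly, the flow implemented below is: the benchmark coordinator sends ``PrepareBenchmark`` and later
# ``StartBenchmark`` to the ``DriverActor``. During preparation, ``Bootstrap`` and ``PrepareTrack`` go to one
# ``TrackPreparationActor`` per host, which farms out ``WorkerTask``s to its ``TaskExecutionActor``s via
# ``StartTaskLoop``/``DoTask`` and reports back with ``TrackPrepared``. During the benchmark, each ``Worker``
# receives ``StartWorker``, ``Drive`` and ``CompleteCurrentTask`` and reports back with ``UpdateSamples`` and
# ``JoinPointReached`` until the driver finally sends ``BenchmarkComplete``.
#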
class PrepareBenchmark:
"""
Initiates preparation steps for a benchmark. The benchmark should only be started after StartBenchmark is sent.
"""
def __init__(self, config: types.Config, track):
"""
:param config: Rally internal configuration object.
:param track: The track to use.
"""
self.config = config
self.track = track
class StartBenchmark:
pass
class Bootstrap:
"""
Prompts loading of track code on new actors
"""
def __init__(self, cfg: types.Config, worker_id=None):
self.config = cfg
self.worker_id = worker_id
class PrepareTrack:
"""
Initiates preparation of a track.
"""
def __init__(self, track):
"""
:param track: The track to use.
"""
self.track = track
class TrackPrepared:
pass
class StartTaskLoop:
def __init__(self, track_name, cfg: types.Config):
self.track_name = track_name
self.cfg = cfg
class DoTask:
def __init__(self, task, cfg: types.Config):
self.task = task
self.cfg = cfg
@dataclass(frozen=True)
class WorkerTask:
"""
Unit of work that should be completed by the low-level TaskExecutionActor
"""
func: Callable
params: dict
class ReadyForWork:
pass
class WorkerIdle:
pass
class PreparationComplete:
def __init__(self, distribution_flavor, distribution_version, revision):
self.distribution_flavor = distribution_flavor
self.distribution_version = distribution_version
self.revision = revision
class StartWorker:
"""
Starts a worker.
"""
def __init__(self, worker_id, config: types.Config, track, client_allocations, client_contexts):
"""
:param worker_id: Unique (numeric) id of the worker.
:param config: Rally internal configuration object.
:param track: The track to use.
:param client_allocations: A structure describing which clients need to run which tasks.
        :param client_contexts: A dict of ``ClientContext`` objects keyed by client ID.
"""
self.worker_id = worker_id
self.config = config
self.track = track
self.client_allocations = client_allocations
self.client_contexts = client_contexts
class Drive:
"""
Tells a load generator to drive (either after a join point or initially).
"""
def __init__(self, client_start_timestamp):
self.client_start_timestamp = client_start_timestamp
class CompleteCurrentTask:
"""
Tells a load generator to prematurely complete its current task. This is used to model task dependencies for parallel tasks (i.e. if a
specific task that is marked accordingly in the track finishes, it will also signal termination of all other tasks in the same parallel
element).
"""
class UpdateSamples:
"""
Used to send samples from a load generator node to the master.
"""
def __init__(self, client_id, samples):
self.client_id = client_id
self.samples = samples
class JoinPointReached:
"""
Tells the master that a load generator has reached a join point. Used for coordination across multiple load generators.
"""
def __init__(self, worker_id, task):
self.worker_id = worker_id
# Using perf_counter here is fine even in the distributed case. Although we "leak" this value to other
# machines, we will only ever interpret this value on the same machine (see `Drive` and the implementation
# in `Driver#joinpoint_reached()`).
self.worker_timestamp = time.perf_counter()
self.task = task
class BenchmarkComplete:
"""
Indicates that the benchmark is complete.
"""
def __init__(self, metrics):
self.metrics = metrics
class TaskFinished:
def __init__(self, metrics, next_task_scheduled_in):
self.metrics = metrics
self.next_task_scheduled_in = next_task_scheduled_in
class DriverActor(actor.RallyActor):
    """
    Coordinates all workers. This is actually only a thin actor wrapper layer around ``Driver`` which does the actual work.
    """
    RESET_RELATIVE_TIME_MARKER = "reset_relative_time"
    WAKEUP_INTERVAL_SECONDS = 1
    # post-process request metrics every N seconds and send them to the metrics store
    POST_PROCESS_INTERVAL_SECONDS = 30
def __init__(self):
super().__init__()
self.benchmark_actor = None
self.driver = None
self.status = "init"
self.post_process_timer = 0
self.cluster_details = {}
def receiveMsg_PoisonMessage(self, poisonmsg, sender):
self.logger.error("Main driver received a fatal indication from a load generator (%s). Shutting down.", poisonmsg.details)
self.driver.close()
self.send(self.benchmark_actor, actor.BenchmarkFailure("Fatal track or load generator indication", poisonmsg.details))
def receiveMsg_BenchmarkFailure(self, msg, sender):
self.logger.error("Main driver received a fatal exception from a load generator. Shutting down.")
self.driver.close()
self.send(self.benchmark_actor, msg)
def receiveMsg_BenchmarkCancelled(self, msg, sender):
self.logger.info("Main driver received a notification that the benchmark has been cancelled.")
self.driver.close()
self.send(self.benchmark_actor, msg)
def receiveMsg_ActorExitRequest(self, msg, sender):
self.logger.info("Main driver received ActorExitRequest and will terminate all load generators.")
self.status = "exiting"
def receiveMsg_ChildActorExited(self, msg, sender):
# is it a worker?
if msg.childAddress in self.driver.workers:
worker_index = self.driver.workers.index(msg.childAddress)
if self.status == "exiting":
self.logger.debug("Worker [%d] has exited.", worker_index)
else:
self.logger.error("Worker [%d] has exited prematurely. Aborting benchmark.", worker_index)
self.send(self.benchmark_actor, actor.BenchmarkFailure(f"Worker [{worker_index}] has exited prematurely."))
else:
self.logger.debug("A track preparator has exited.")
def receiveUnrecognizedMessage(self, msg, sender):
self.logger.debug("Main driver received unknown message [%s] (ignoring).", str(msg))
@actor.no_retry("driver") # pylint: disable=no-value-for-parameter
def receiveMsg_PrepareBenchmark(self, msg, sender):
self.benchmark_actor = sender
self.driver = Driver(self, msg.config)
self.driver.prepare_benchmark(msg.track)
@actor.no_retry("driver") # pylint: disable=no-value-for-parameter
def receiveMsg_StartBenchmark(self, msg, sender):
self.benchmark_actor = sender
self.driver.start_benchmark()
self.wakeupAfter(datetime.timedelta(seconds=DriverActor.WAKEUP_INTERVAL_SECONDS))
@actor.no_retry("driver") # pylint: disable=no-value-for-parameter
def receiveMsg_TrackPrepared(self, msg, track_preparation_actor):
self.transition_when_all_children_responded(
track_preparation_actor, msg, expected_status=None, new_status=None, transition=self._after_track_prepared
)
@actor.no_retry("driver") # pylint: disable=no-value-for-parameter
def receiveMsg_JoinPointReached(self, msg, sender):
self.driver.joinpoint_reached(msg.worker_id, msg.worker_timestamp, msg.task)
@actor.no_retry("driver") # pylint: disable=no-value-for-parameter
def receiveMsg_UpdateSamples(self, msg, sender):
self.driver.update_samples(msg.samples)
@actor.no_retry("driver") # pylint: disable=no-value-for-parameter
def receiveMsg_WakeupMessage(self, msg, sender):
if msg.payload == DriverActor.RESET_RELATIVE_TIME_MARKER:
self.driver.reset_relative_time()
elif not self.driver.finished():
self.post_process_timer += DriverActor.WAKEUP_INTERVAL_SECONDS
if self.post_process_timer >= DriverActor.POST_PROCESS_INTERVAL_SECONDS:
self.post_process_timer = 0
self.driver.post_process_samples()
self.driver.update_progress_message()
self.wakeupAfter(datetime.timedelta(seconds=DriverActor.WAKEUP_INTERVAL_SECONDS))
def create_client(self, host, cfg: types.Config, worker_id):
worker = self.createActor(Worker, targetActorRequirements=self._requirements(host))
self.send(worker, Bootstrap(cfg, worker_id))
return worker
def start_worker(self, driver, worker_id, cfg: types.Config, track, allocations, client_contexts=None):
self.send(driver, StartWorker(worker_id, cfg, track, allocations, client_contexts))
def drive_at(self, driver, client_start_timestamp):
self.send(driver, Drive(client_start_timestamp))
def complete_current_task(self, driver):
self.send(driver, CompleteCurrentTask())
def on_task_finished(self, metrics, next_task_scheduled_in):
if next_task_scheduled_in > 0:
self.wakeupAfter(datetime.timedelta(seconds=next_task_scheduled_in), payload=DriverActor.RESET_RELATIVE_TIME_MARKER)
else:
self.driver.reset_relative_time()
self.send(self.benchmark_actor, TaskFinished(metrics, next_task_scheduled_in))
def _requirements(self, host):
if host == "localhost":
return {"coordinator": True}
else:
return {"ip": host}
def prepare_track(self, hosts, cfg: types.Config, track):
self.track = track
self.logger.info("Starting prepare track process on hosts [%s]", hosts)
self.children = [self._create_track_preparator(h) for h in hosts]
msg = Bootstrap(cfg)
for child in self.children:
self.send(child, msg)
@actor.no_retry("driver") # pylint: disable=no-value-for-parameter
def receiveMsg_ReadyForWork(self, msg, task_preparation_actor):
msg = PrepareTrack(self.track)
self.send(task_preparation_actor, msg)
def _create_track_preparator(self, host):
return self.createActor(TrackPreparationActor, targetActorRequirements=self._requirements(host))
def _after_track_prepared(self):
cluster_version = self.cluster_details["version"] if self.cluster_details else {}
# manually compiled versions don't expose build_flavor but Rally expects a value in telemetry devices
# we should default to trial/basic, but let's default to oss for now to avoid breaking the charts
build_flavor = cluster_version.get("build_flavor", "oss")
build_version = cluster_version.get("number", build_flavor)
build_hash = cluster_version.get("build_hash", build_flavor)
for child in self.children:
self.send(child, thespian.actors.ActorExitRequest())
self.children = []
self.send(
self.benchmark_actor,
PreparationComplete(
build_flavor,
build_version,
build_hash,
),
)
def on_benchmark_complete(self, metrics):
self.send(self.benchmark_actor, BenchmarkComplete(metrics))
def load_local_config(coordinator_config) -> types.Config:
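    """
    Derives a node-local configuration from the coordinator's configuration. Only the sections that workers and
    track preparators need are copied, and the root path is set as the main entry point would normally do.
    """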
cfg = config.auto_load_local_config(
coordinator_config,
additional_sections=[
# only copy the relevant bits
"track",
"driver",
"client",
# due to distribution version...
"mechanic",
"telemetry",
],
)
# set root path (normally done by the main entry point)
cfg.add(config.Scope.application, "node", "rally.root", paths.rally_root())
return cfg
class TaskExecutionActor(actor.RallyActor):
"""
This class should be used for long-running tasks, as it ensures they do not block the actor's messaging system
"""
def __init__(self):
super().__init__()
self.pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
self.executor_future = None
self.wakeup_interval = 5
self.task_preparation_actor = None
self.logger = logging.getLogger(__name__)
self.track_name = None
self.cfg: Optional[types.Config] = None
@actor.no_retry("task executor") # pylint: disable=no-value-for-parameter
def receiveMsg_StartTaskLoop(self, msg, sender):
self.task_preparation_actor = sender
self.track_name = msg.track_name
self.cfg = load_local_config(msg.cfg)
if self.cfg.opts("track", "test.mode.enabled"):
self.wakeup_interval = 0.5
track.load_track_plugins(self.cfg, self.track_name)
self.send(self.task_preparation_actor, ReadyForWork())
@actor.no_retry("task executor") # pylint: disable=no-value-for-parameter
def receiveMsg_DoTask(self, msg, sender):
# actor can arbitrarily execute code based on these messages. if anyone besides our parent sends a task, ignore
if sender != self.task_preparation_actor:
            msg = (
                f"TaskExecutionActor expected message from [{self.task_preparation_actor}]"
                f" but received the following from [{sender}]: {vars(msg)}"
            )
            raise exceptions.RallyError(msg)
task = msg.task
if self.executor_future is not None:
msg = f"TaskExecutionActor received DoTask message [{vars(msg)}], but was already busy"
raise exceptions.RallyError(msg)
if task is None:
self.send(self.task_preparation_actor, WorkerIdle())
else:
            # this is a potentially long-running operation, so we offload it to a background thread so that we don't block
            # the actor (e.g. logging works properly as log messages are forwarded in a timely manner).
self.executor_future = self.pool.submit(task.func, **task.params)
self.wakeupAfter(datetime.timedelta(seconds=self.wakeup_interval))
@actor.no_retry("task executor") # pylint: disable=no-value-for-parameter
def receiveMsg_WakeupMessage(self, msg, sender):
if self.executor_future is not None and self.executor_future.done():
e = self.executor_future.exception(timeout=0)
if e:
self.logger.exception("Worker failed. Notifying parent...", exc_info=e)
# the exception might be user-defined and not be on the load path of the original sender. Hence, it
# cannot be deserialized on the receiver so we convert it here to a plain string.
self.send(self.task_preparation_actor, actor.BenchmarkFailure("Error in task executor", str(e)))
else:
self.executor_future = None
self.send(self.task_preparation_actor, ReadyForWork())
else:
self.wakeupAfter(datetime.timedelta(seconds=self.wakeup_interval))
def receiveMsg_BenchmarkFailure(self, msg, sender):
# sent by our no_retry infrastructure; forward to master
self.send(self.task_preparation_actor, msg)
class TrackPreparationActor(actor.RallyActor):
class Status(Enum):
INITIALIZING = "initializing"
PROCESSOR_RUNNING = "processor running"
PROCESSOR_COMPLETE = "processor complete"
def __init__(self):
super().__init__()
self.processors = queue.Queue()
self.driver_actor = None
self.logger.info("Track Preparator started")
self.status = self.Status.INITIALIZING
self.children = []
self.tasks = []
self.cfg: Optional[types.Config] = None
self.data_root_dir = None
self.track = None
def receiveMsg_PoisonMessage(self, poisonmsg, sender):
self.logger.error("Track Preparator received a fatal indication from a load generator (%s). Shutting down.", poisonmsg.details)
self.send(self.driver_actor, actor.BenchmarkFailure("Fatal track preparation indication", poisonmsg.details))
@actor.no_retry("track preparator") # pylint: disable=no-value-for-parameter
def receiveMsg_Bootstrap(self, msg, sender):
self.driver_actor = sender
# load node-specific config to have correct paths available
self.cfg = load_local_config(msg.config)
        # the track is loaded once per host here; dependencies are not installed at this point
load_track(self.cfg, install_dependencies=False)
self.send(self.driver_actor, ReadyForWork())
@actor.no_retry("track preparator") # pylint: disable=no-value-for-parameter
def receiveMsg_ActorExitRequest(self, msg, sender):
self.logger.debug("ActorExitRequest received. Forwarding to children")
for child in self.children:
self.send(child, msg)
@actor.no_retry("track preparator") # pylint: disable=no-value-for-parameter
def receiveMsg_BenchmarkFailure(self, msg, sender):
# sent by our generic worker; forward to parent
self.send(self.driver_actor, msg)
@actor.no_retry("track preparator") # pylint: disable=no-value-for-parameter
def receiveMsg_PrepareTrack(self, msg, sender):
assert self.cfg is not None
self.data_root_dir = self.cfg.opts("benchmarks", "local.dataset.cache")
tpr = TrackProcessorRegistry(self.cfg)
self.track = msg.track
self.logger.info("Preparing track [%s]", self.track.name)
self.logger.info("Reloading track [%s] to ensure plugins are up-to-date.", self.track.name)
# the track might have been loaded on a different machine (the coordinator machine) so we force a track
# update to ensure we use the latest version of plugins.
load_track(self.cfg)
load_track_plugins(self.cfg, self.track.name, register_track_processor=tpr.register_track_processor, force_update=True)
        # we expect that on_prepare_track can take a long time, so we seed a queue of tasks and delegate them to child workers
self.children = [self._create_task_executor() for _ in range(num_cores(self.cfg))]
for processor in tpr.processors:
self.processors.put(processor)
self._seed_tasks(self.processors.get())
self.send_to_children_and_transition(
self, StartTaskLoop(self.track.name, self.cfg), self.Status.INITIALIZING, self.Status.PROCESSOR_RUNNING
)
def resume(self):
assert self.cfg is not None
if not self.processors.empty():
self._seed_tasks(self.processors.get())
self.send_to_children_and_transition(
self, StartTaskLoop(self.track.name, self.cfg), self.Status.PROCESSOR_COMPLETE, self.Status.PROCESSOR_RUNNING
)
else:
self.send(self.driver_actor, TrackPrepared())
def _seed_tasks(self, processor):
self.tasks = list(WorkerTask(func, params) for func, params in processor.on_prepare_track(self.track, self.data_root_dir))
def _create_task_executor(self):
return self.createActor(TaskExecutionActor)
@actor.no_retry("track preparator") # pylint: disable=no-value-for-parameter
def receiveMsg_ReadyForWork(self, msg, task_execution_actor):
assert self.cfg is not None
if self.tasks:
next_task = self.tasks.pop()
else:
next_task = None
new_msg = DoTask(next_task, self.cfg)
self.logger.debug("Track Preparator sending %s to %s", vars(new_msg), task_execution_actor)
self.send(task_execution_actor, new_msg)
@actor.no_retry("track preparator") # pylint: disable=no-value-for-parameter
def receiveMsg_WorkerIdle(self, msg, sender):
self.transition_when_all_children_responded(sender, msg, self.Status.PROCESSOR_RUNNING, self.Status.PROCESSOR_COMPLETE, self.resume)
def num_cores(cfg: types.Config):
return int(cfg.opts("system", "available.cores", mandatory=False, default_value=multiprocessing.cpu_count()))
ApiKey = collections.namedtuple("ApiKey", ["id", "secret"])
@dataclass
class ClientContext:
client_id: int
parent_worker_id: int
api_key: Optional[ApiKey] = None
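# Note: when ``create_api_key_per_client`` is set in the default client options, the driver creates one API key per
# client (see ``Driver.start_benchmark``), hands it to the worker via ``ClientContext`` and deletes all generated
# keys again once the benchmark completes.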
class Driver:
def __init__(self, driver_actor, config: types.Config, es_client_factory_class=client.EsClientFactory):
"""
Coordinates all workers. It is technology-agnostic, i.e. it does not know anything about actors. To allow us to hook in an actor,
we provide a ``target`` parameter which will be called whenever some event has occurred. The ``target`` can use this to send
appropriate messages.
:param target: A target that will be notified of important events.
:param config: The current config object.
"""
self.logger = logging.getLogger(__name__)
self.driver_actor = driver_actor
self.config = config
self.es_client_factory = es_client_factory_class
self.default_sync_es_client = None
self.track = None
self.challenge = None
self.metrics_store = None
self.load_driver_hosts = []
self.workers = []
# which client ids are assigned to which workers?
self.clients_per_worker = {}
self.client_contexts = {}
self.generated_api_key_ids = []
self.progress_reporter = console.progress()
self.progress_counter = 0
self.quiet = False
self.allocations = None
self.raw_samples = []
self.most_recent_sample_per_client = {}
self.sample_post_processor = None
self.number_of_steps = 0
self.currently_completed = 0
self.workers_completed_current_step = {}
self.current_step = -1
self.tasks_per_join_point = None
self.complete_current_task_sent = False
self.telemetry = None
def create_es_clients(self):
all_hosts = self.config.opts("client", "hosts").all_hosts
distribution_version = self.config.opts("mechanic", "distribution.version", mandatory=False)
distribution_flavor = self.config.opts("mechanic", "distribution.flavor", mandatory=False)
es = {}
for cluster_name, cluster_hosts in all_hosts.items():
all_client_options = self.config.opts("client", "options").all_client_options
cluster_client_options = dict(all_client_options[cluster_name])
# Use retries to avoid aborts on long living connections for telemetry devices
cluster_client_options["retry_on_timeout"] = True
es[cluster_name] = self.es_client_factory(
cluster_hosts, cluster_client_options, distribution_version=distribution_version, distribution_flavor=distribution_flavor
).create()
return es
def prepare_telemetry(self, es, enable, index_names, data_stream_names, build_hash, serverless_mode, serverless_operator):
enabled_devices = self.config.opts("telemetry", "devices")
telemetry_params = self.config.opts("telemetry", "params")
log_root = paths.race_root(self.config)
es_default = es["default"]
if enable:
devices = [
telemetry.NodeStats(telemetry_params, es, self.metrics_store),
telemetry.ExternalEnvironmentInfo(es_default, self.metrics_store),
telemetry.ClusterEnvironmentInfo(es_default, self.metrics_store, build_hash),
telemetry.JvmStatsSummary(es_default, self.metrics_store),
telemetry.IndexStats(es_default, self.metrics_store),
telemetry.MlBucketProcessingTime(es_default, self.metrics_store),
telemetry.MasterNodeStats(telemetry_params, es_default, self.metrics_store),
telemetry.SegmentStats(log_root, es_default),
telemetry.CcrStats(telemetry_params, es, self.metrics_store),
telemetry.RecoveryStats(telemetry_params, es, self.metrics_store),
telemetry.ShardStats(telemetry_params, es, self.metrics_store),
telemetry.TransformStats(telemetry_params, es, self.metrics_store),
telemetry.SearchableSnapshotsStats(telemetry_params, es, self.metrics_store),
telemetry.DataStreamStats(telemetry_params, es, self.metrics_store),
telemetry.IngestPipelineStats(es, self.metrics_store),
telemetry.DiskUsageStats(telemetry_params, es_default, self.metrics_store, index_names, data_stream_names),
telemetry.BlobStoreStats(telemetry_params, es, self.metrics_store),
telemetry.GeoIpStats(es_default, self.metrics_store),
]
else:
devices = []
self.telemetry = telemetry.Telemetry(
enabled_devices,
devices=devices,
serverless_mode=serverless_mode,
serverless_operator=serverless_operator,
)
def wait_for_rest_api(self, es):
es_default = es["default"]
self.logger.info("Checking if REST API is available.")
if client.wait_for_rest_layer(es_default, max_attempts=40):
self.logger.info("REST API is available.")
else:
self.logger.error("REST API layer is not yet available. Stopping benchmark.")
raise exceptions.SystemSetupError("Elasticsearch REST API layer is not available.")
def retrieve_cluster_info(self, es):
try:
return es["default"].info()
except BaseException:
self.logger.exception("Could not retrieve cluster info on benchmark start")
return None
def retrieve_build_hash_from_nodes_info(self, es):
try:
nodes_info = es["default"].nodes.info(filter_path="**.build_hash")
nodes = nodes_info["nodes"]
# assumption: build hash is the same across all the nodes
first_node_id = next(iter(nodes))
return nodes[first_node_id]["build_hash"]
except BaseException:
self.logger.exception("Could not retrieve build hash from nodes info")
return None
def create_api_key(self, es, client_id):
self.logger.debug("Creating ES API key for client [%s].", client_id)
try:
api_key = client.create_api_key(es, client_id)
self.logger.debug("ES API key created for client [%s].", client_id)
# Store the API key ID for deletion upon benchmark completion
self.generated_api_key_ids.append(api_key["id"])
return api_key
except Exception as e:
self.logger.error("Unable to create API keys. Stopping benchmark.")
            raise exceptions.SystemSetupError(str(e))
def prepare_benchmark(self, t):
self.track = t
self.challenge = select_challenge(self.config, self.track)
self.quiet = self.config.opts("system", "quiet.mode", mandatory=False, default_value=False)
downsample_factor = int(self.config.opts("reporting", "metrics.request.downsample.factor", mandatory=False, default_value=1))
self.metrics_store = metrics.metrics_store(cfg=self.config, track=self.track.name, challenge=self.challenge.name, read_only=False)
self.sample_post_processor = SamplePostprocessor(
self.metrics_store, downsample_factor, self.track.meta_data, self.challenge.meta_data
)
es_clients = self.create_es_clients()
self.default_sync_es_client = es_clients["default"]
skip_rest_api_check = self.config.opts("mechanic", "skip.rest.api.check")
uses_static_responses = self.config.opts("client", "options").uses_static_responses
serverless_mode = convert.to_bool(self.config.opts("driver", "serverless.mode", mandatory=False, default_value=False))
serverless_operator = convert.to_bool(self.config.opts("driver", "serverless.operator", mandatory=False, default_value=False))
build_hash = None
if skip_rest_api_check:
self.logger.info("Skipping REST API check as requested explicitly.")
elif uses_static_responses:
self.logger.info("Skipping REST API check as static responses are used.")
else:
if serverless_mode and not serverless_operator:
self.logger.info("Skipping REST API check while targetting serverless cluster with a public user.")
else:
self.wait_for_rest_api(es_clients)
self.driver_actor.cluster_details = self.retrieve_cluster_info(es_clients)
if serverless_mode:
# overwrite static serverless version number
self.driver_actor.cluster_details["version"]["number"] = "serverless"
if serverless_operator:
# overwrite build hash if running as operator
build_hash = self.retrieve_build_hash_from_nodes_info(es_clients)
self.logger.info("Retrieved actual build hash [%s] from serverless cluster.", build_hash)
self.driver_actor.cluster_details["version"]["build_hash"] = build_hash
# Avoid issuing any requests to the target cluster when static responses are enabled. The results
# are not useful and attempts to connect to a non-existing cluster just lead to exception traces in logs.
self.prepare_telemetry(
es_clients,
enable=not uses_static_responses,
index_names=self.track.index_names(),
data_stream_names=self.track.data_stream_names(),
build_hash=build_hash,
serverless_mode=serverless_mode,
serverless_operator=serverless_operator,
)
for host in self.config.opts("driver", "load_driver_hosts"):
host_config = {
# for simplicity we assume that all benchmark machines have the same specs
"cores": num_cores(self.config)
}
if host != "localhost":
host_config["host"] = net.resolve(host)
else:
host_config["host"] = host
self.load_driver_hosts.append(host_config)
self.driver_actor.prepare_track([h["host"] for h in self.load_driver_hosts], self.config, self.track)
def start_benchmark(self):
self.logger.info("Benchmark is about to start.")
# ensure relative time starts when the benchmark starts.
self.reset_relative_time()
self.logger.info("Attaching cluster-level telemetry devices.")
self.telemetry.on_benchmark_start()
self.logger.info("Cluster-level telemetry devices are now attached.")
allocator = Allocator(self.challenge.schedule)
self.allocations = allocator.allocations
self.number_of_steps = len(allocator.join_points) - 1
self.tasks_per_join_point = allocator.tasks_per_joinpoint
self.logger.info(
"Benchmark consists of [%d] steps executed by [%d] clients.",
self.number_of_steps,
len(self.allocations), # type: ignore[arg-type] # TODO remove the below ignore when introducing type hints
)
# avoid flooding the log if there are too many clients
if allocator.clients < 128:
self.logger.debug("Allocation matrix:\n%s", "\n".join([str(a) for a in self.allocations]))
create_api_keys = self.config.opts("client", "options").all_client_options["default"].get("create_api_key_per_client", None)
worker_assignments = calculate_worker_assignments(self.load_driver_hosts, allocator.clients)
worker_id = 0
for assignment in worker_assignments:
host = assignment["host"]
for clients in assignment["workers"]:
# don't assign workers without any clients
if len(clients) > 0:
self.logger.debug("Allocating worker [%d] on [%s] with [%d] clients.", worker_id, host, len(clients))
worker = self.driver_actor.create_client(host, self.config, worker_id)
client_allocations = ClientAllocations()
worker_client_contexts = {}
for client_id in clients:
client_allocations.add(client_id, self.allocations[client_id])
self.clients_per_worker[client_id] = worker_id
client_context = ClientContext(client_id=client_id, parent_worker_id=worker_id)
if create_api_keys:
resp = self.create_api_key(self.default_sync_es_client, client_id)
client_context.api_key = ApiKey(id=resp["id"], secret=resp["api_key"])
worker_client_contexts[client_id] = client_context
self.client_contexts[worker_id] = worker_client_contexts
self.driver_actor.start_worker(
worker, worker_id, self.config, self.track, client_allocations, client_contexts=worker_client_contexts
)
self.workers.append(worker)
worker_id += 1
self.update_progress_message()
def joinpoint_reached(self, worker_id, worker_local_timestamp, task_allocations):
self.currently_completed += 1
self.workers_completed_current_step[worker_id] = (worker_local_timestamp, time.perf_counter())
self.logger.debug(
"[%d/%d] workers reached join point [%d/%d].",
self.currently_completed,
len(self.workers),
self.current_step + 1,
self.number_of_steps,
)
if self.currently_completed == len(self.workers):
self.logger.info("All workers completed their tasks until join point [%d/%d].", self.current_step + 1, self.number_of_steps)
# we can go on to the next step
self.currently_completed = 0
self.complete_current_task_sent = False
            # make a copy and reset early to avoid race conditions with clients that already reach the next join point while we are still sending
workers_curr_step = self.workers_completed_current_step
self.workers_completed_current_step = {}
self.update_progress_message(task_finished=True)
# clear per step
self.most_recent_sample_per_client = {}
self.current_step += 1
self.logger.debug("Postprocessing samples...")
self.post_process_samples()
if self.finished():
self.telemetry.on_benchmark_stop()
self.logger.info("All steps completed.")
# Some metrics store implementations return None because no external representation is required.
# pylint: disable=assignment-from-none
m = self.metrics_store.to_externalizable(clear=True)
self.logger.debug("Closing metrics store...")
self.metrics_store.close()
# immediately clear as we don't need it anymore and it can consume a significant amount of memory
self.metrics_store = None
if self.generated_api_key_ids:
self.logger.debug("Deleting auto-generated client API keys...")
try:
delete_api_keys(self.default_sync_es_client, self.generated_api_key_ids)
except exceptions.RallyError:
console.warn(
"Unable to delete auto-generated API keys. You may need to manually delete them. "
"Please check the logs for details."
)
self.logger.debug("Sending benchmark results...")
self.driver_actor.on_benchmark_complete(m)
else:
self.move_to_next_task(workers_curr_step)
else:
self.may_complete_current_task(task_allocations)
def move_to_next_task(self, workers_curr_step):
if self.config.opts("track", "test.mode.enabled"):
            # if test mode is enabled, don't wait and start the next task immediately.
waiting_period = 0
else:
# start the next task in one second (relative to master's timestamp)
#
# Assumption: We don't have a lot of clock skew between reaching the join point and sending the next task
# (it doesn't matter too much if we're a few ms off).
waiting_period = 1.0
# Some metrics store implementations return None because no external representation is required.
# pylint: disable=assignment-from-none
m = self.metrics_store.to_externalizable(clear=True)
self.driver_actor.on_task_finished(m, waiting_period)
# Using a perf_counter here is fine also in the distributed case as we subtract it from `master_received_msg_at` making it
# a relative instead of an absolute value.
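        # Example (illustrative values): a worker reached the join point at 12.0s on its own clock and the master
        # received that message at 100.0s on its clock. If the next task is scheduled for 101.0s (master clock),
        # the worker is told to start at 12.0 + (101.0 - 100.0) = 13.0s on its own clock, so clocks never need to
        # be synchronized across machines.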
start_next_task = time.perf_counter() + waiting_period
for worker_id, worker in enumerate(self.workers):
worker_ended_task_at, master_received_msg_at = workers_curr_step[worker_id]
worker_start_timestamp = worker_ended_task_at + (start_next_task - master_received_msg_at)
self.logger.debug(
"Scheduling next task for worker id [%d] at their timestamp [%f] (master timestamp [%f])",
worker_id,
worker_start_timestamp,
start_next_task,
)
self.driver_actor.drive_at(worker, worker_start_timestamp)
def may_complete_current_task(self, task_allocations):
any_joinpoints_completing_parent = [a for a in task_allocations if a.task.any_task_completes_parent]
joinpoints_completing_parent = [a for a in task_allocations if a.task.preceding_task_completes_parent]
# If 'completed-by' is set to 'any', then we *do* want to check for completion by
# any client and *not* wait until the remaining runner has completed. This way the 'parallel' task will exit
# on the completion of _any_ client for any task, i.e. given a contrived track with two tasks to execute inside
# a parallel block:
# * parallel:
# * bulk-1, with clients 8
# * bulk-2, with clients: 8
#
# 1. Both 'bulk-1' and 'bulk-2' execute in parallel
# 2. 'bulk-1' client[0]'s runner is first to complete and reach the next joinpoint successfully
# 3. 'bulk-1' will now cause the parent task to complete and _not_ wait for all 8 clients' runner to complete,
# or for 'bulk-2' at all
#
# The reasoning for the distinction between 'any_joinpoints_completing_parent' & 'joinpoints_completing_parent'
# is to simplify logic, otherwise we'd need to implement some form of state machine involving actor-to-actor
# communication.
if len(any_joinpoints_completing_parent) > 0 and not self.complete_current_task_sent:
self.logger.info(
"Any task before join point [%s] is able to complete the parent structure. Telling all clients to exit immediately.",
any_joinpoints_completing_parent[0].task,
)
self.complete_current_task_sent = True
for worker in self.workers:
self.driver_actor.complete_current_task(worker)
# If we have a specific 'completed-by' task specified, then we want to make sure that all clients for that task
# are able to complete their runners as expected before completing the parent
elif len(joinpoints_completing_parent) > 0 and not self.complete_current_task_sent:
# while this list could contain multiple items, it should always be the same task (but multiple
# different clients) so any item is sufficient.
current_join_point = joinpoints_completing_parent[0].task
self.logger.info(
"Tasks before join point [%s] are able to complete the parent structure. Checking "
"if all [%d] clients have finished yet.",
current_join_point,
len(current_join_point.clients_executing_completing_task),
)
pending_client_ids = []
for client_id in current_join_point.clients_executing_completing_task:
# We assume that all clients have finished if their corresponding worker has finished
worker_id = self.clients_per_worker[client_id]
if worker_id not in self.workers_completed_current_step:
pending_client_ids.append(client_id)
# are all clients executing said task already done? if so we need to notify the remaining clients
if len(pending_client_ids) == 0:
# As we are waiting for other clients to finish, we would send this message over and over again.
# Hence we need to memorize whether we have already sent it for the current step.
self.complete_current_task_sent = True
self.logger.info("All affected clients have finished. Notifying all clients to complete their current tasks.")
for worker in self.workers:
self.driver_actor.complete_current_task(worker)
else:
if len(pending_client_ids) > 32:
self.logger.info("[%d] clients did not yet finish.", len(pending_client_ids))
else:
self.logger.info("Client id(s) [%s] did not yet finish.", ",".join(map(str, pending_client_ids)))
def reset_relative_time(self):
self.logger.debug("Resetting relative time of request metrics store.")
self.metrics_store.reset_relative_time()
def finished(self):
return self.current_step == self.number_of_steps
def close(self):
self.progress_reporter.finish()
if self.metrics_store and self.metrics_store.opened:
self.metrics_store.close()
def update_samples(self, samples):
if len(samples) > 0:
self.raw_samples += samples
# We need to check all samples, they will be from different clients
for s in samples:
self.most_recent_sample_per_client[s.client_id] = s
def update_progress_message(self, task_finished=False):
if not self.quiet and self.current_step >= 0:
tasks = ",".join([t.name for t in self.tasks_per_join_point[self.current_step]])
if task_finished:
total_progress = 1.0
else:
# we only count clients which actually contribute to progress. If clients are executing tasks eternally in a parallel
# structure, we should not count them. The reason is that progress depends entirely on the client(s) that execute the
# task that is completing the parallel structure.
progress_per_client = [
s.percent_completed for s in self.most_recent_sample_per_client.values() if s.percent_completed is not None
]
num_clients = max(len(progress_per_client), 1)
total_progress = sum(progress_per_client) / num_clients
self.progress_reporter.print("Running %s" % tasks, "[%3d%% done]" % (round(total_progress * 100)))
if task_finished:
self.progress_reporter.finish()
def post_process_samples(self):
        # we make a copy here not to avoid concurrent updates (actors are single-threaded) but rather to make it clear that we use
        # only a snapshot and that new data will go to a new sample set.
raw_samples = self.raw_samples
self.raw_samples = []
self.sample_post_processor(raw_samples)
class SamplePostprocessor:
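    """
    Turns raw samples into metrics store records. Latency, service time and processing time are stored only for
    every ``downsample_factor``-th sample, whereas throughput is always calculated over all raw samples.
    """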
def __init__(self, metrics_store, downsample_factor, track_meta_data, challenge_meta_data):
self.logger = logging.getLogger(__name__)
self.metrics_store = metrics_store
self.track_meta_data = track_meta_data
self.challenge_meta_data = challenge_meta_data
self.throughput_calculator = ThroughputCalculator()
self.downsample_factor = downsample_factor
def __call__(self, raw_samples):
if len(raw_samples) == 0:
return
total_start = time.perf_counter()
start = total_start
final_sample_count = 0
for idx, sample in enumerate(raw_samples):
if idx % self.downsample_factor == 0:
final_sample_count += 1
client_id_meta_data = {"client_id": sample.client_id}
meta_data = self.merge(
self.track_meta_data,
self.challenge_meta_data,
sample.operation_meta_data,
sample.task.meta_data,
sample.request_meta_data,
client_id_meta_data,
)
self.metrics_store.put_value_cluster_level(
name="latency",
value=convert.seconds_to_ms(sample.latency),
unit="ms",
task=sample.task.name,
operation=sample.operation_name,
operation_type=sample.operation_type,
sample_type=sample.sample_type,
absolute_time=sample.absolute_time,
relative_time=sample.relative_time,
meta_data=meta_data,
)
self.metrics_store.put_value_cluster_level(
name="service_time",
value=convert.seconds_to_ms(sample.service_time),
unit="ms",
task=sample.task.name,
operation=sample.operation_name,
operation_type=sample.operation_type,
sample_type=sample.sample_type,
absolute_time=sample.absolute_time,
relative_time=sample.relative_time,
meta_data=meta_data,
)
self.metrics_store.put_value_cluster_level(
name="processing_time",
value=convert.seconds_to_ms(sample.processing_time),
unit="ms",
task=sample.task.name,
operation=sample.operation_name,
operation_type=sample.operation_type,
sample_type=sample.sample_type,
absolute_time=sample.absolute_time,
relative_time=sample.relative_time,
meta_data=meta_data,
)
for timing in sample.dependent_timings:
self.metrics_store.put_value_cluster_level(
name="service_time",
value=convert.seconds_to_ms(timing.service_time),
unit="ms",
task=timing.task.name,
operation=timing.operation_name,
operation_type=timing.operation_type,
sample_type=timing.sample_type,
absolute_time=timing.absolute_time,
relative_time=timing.relative_time,
meta_data=self.merge(timing.request_meta_data, client_id_meta_data),
)
end = time.perf_counter()
self.logger.debug("Storing latency and service time took [%f] seconds.", (end - start))
start = end
aggregates = self.throughput_calculator.calculate(raw_samples)
end = time.perf_counter()
self.logger.debug("Calculating throughput took [%f] seconds.", (end - start))
start = end
for task, samples in aggregates.items():
meta_data = self.merge(self.track_meta_data, self.challenge_meta_data, task.operation.meta_data, task.meta_data)
for absolute_time, relative_time, sample_type, throughput, throughput_unit in samples:
self.metrics_store.put_value_cluster_level(
name="throughput",
value=throughput,
unit=throughput_unit,
task=task.name,
operation=task.operation.name,
operation_type=task.operation.type,
sample_type=sample_type,
absolute_time=absolute_time,
relative_time=relative_time,
meta_data=meta_data,
)
end = time.perf_counter()
self.logger.debug("Storing throughput took [%f] seconds.", (end - start))
start = end
# this will be a noop for the in-memory metrics store. If we use an ES metrics store however, this will ensure that we already send
# the data and also clear the in-memory buffer. This allows users to see data already while running the benchmark. In cases where
# it does not matter (i.e. in-memory) we will still defer this step until the end.
#
# Don't force refresh here in the interest of short processing times. We don't need to query immediately afterwards so there is
# no need for frequent refreshes.
self.metrics_store.flush(refresh=False)
end = time.perf_counter()
self.logger.debug("Flushing the metrics store took [%f] seconds.", (end - start))
self.logger.debug(
"Postprocessing [%d] raw samples (downsampled to [%d] samples) took [%f] seconds in total.",
len(raw_samples),
final_sample_count,
(end - total_start),
)
def merge(self, *args):
result = {}
for arg in args:
if arg is not None:
result.update(arg)
return result
def calculate_worker_assignments(host_configs, client_count):
"""
Assigns clients to workers on the provided hosts.
:param host_configs: A list of dicts where each dict contains the host name (key: ``host``) and the number of
available CPU cores (key: ``cores``).
:param client_count: The number of clients that should be used at most.
:return: A list of dicts containing the host (key: ``host``) and a list of workers (key ``workers``). Each entry
in that list contains another list with the clients that should be assigned to these workers.
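    For example (illustrative host names), two hosts with 4 cores each and 7 clients yield:
        [
            {"host": "h1", "workers": [[0], [1], [2], [3]]},
            {"host": "h2", "workers": [[4], [5], [6], []]},
        ]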
"""
assignments = []
client_idx = 0
host_count = len(host_configs)
clients_per_host = math.ceil(client_count / host_count)
remaining_clients = client_count
for host_config in host_configs:
# the last host might not need to simulate as many clients as the rest of the hosts as we eagerly
# assign clients to hosts.
clients_on_this_host = min(clients_per_host, remaining_clients)
assignment = {
"host": host_config["host"],
"workers": [],
}
assignments.append(assignment)
workers_on_this_host = host_config["cores"]
clients_per_worker = [0] * workers_on_this_host
# determine how many clients each worker should simulate
for c in range(clients_on_this_host):
clients_per_worker[c % workers_on_this_host] += 1
# assign client ids to workers
for client_count_for_worker in clients_per_worker:
worker_assignment = []
assignment["workers"].append(worker_assignment)
for c in range(client_idx, client_idx + client_count_for_worker):
worker_assignment.append(c)
client_idx += client_count_for_worker
remaining_clients -= clients_on_this_host
assert remaining_clients == 0
return assignments
ClientAllocation = collections.namedtuple("ClientAllocation", ["client_id", "task"])
class ClientAllocations:
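    """
    Holds, per client, the list of allocations (tasks or join points) for each step. ``tasks(i)`` returns all
    clients' allocations at step ``i``; ``is_joinpoint(i)`` is True when every client with an allocation at that
    step is at a ``JoinPoint``.
    """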
def __init__(self):
self.allocations = []
def add(self, client_id, tasks):
self.allocations.append({"client_id": client_id, "tasks": tasks})
def is_joinpoint(self, task_index):
return all(isinstance(t.task, JoinPoint) for t in self.tasks(task_index))
def tasks(self, task_index, remove_empty=True):
current_tasks = []
for allocation in self.allocations:
tasks_at_index = allocation["tasks"][task_index]
if remove_empty and tasks_at_index is not None:
current_tasks.append(ClientAllocation(allocation["client_id"], tasks_at_index))
return current_tasks
class Worker(actor.RallyActor):
"""
The actual worker that applies load against the cluster(s).
It will also regularly send measurements to the master node so it can consolidate them.
"""
WAKEUP_INTERVAL_SECONDS = 5
def __init__(self):
super().__init__()
self.driver_actor = None
self.worker_id = None
self.config: Optional[types.Config] = None
self.track = None
self.client_allocations = None
self.client_contexts = None
self.current_task_index = 0
self.next_task_index = 0
self.on_error = None
self.pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
# cancellation via future does not work, hence we use our own mechanism with a shared variable and polling
self.cancel = threading.Event()
# used to indicate that we want to prematurely consider this completed. This is *not* due to cancellation
# but a regular event in a benchmark and used to model task dependency of parallel tasks.
self.complete = threading.Event()
self.executor_future = None
self.sampler = None
self.start_driving = False
self.wakeup_interval = Worker.WAKEUP_INTERVAL_SECONDS
self.sample_queue_size = None
@actor.no_retry("worker") # pylint: disable=no-value-for-parameter
def receiveMsg_Bootstrap(self, msg, sender):
self.driver_actor = sender
self.worker_id = msg.worker_id
# load node-specific config to have correct paths available
self.config = load_local_config(msg.config)
load_track(self.config, install_dependencies=False)
self.logger.debug("Worker[%d] has Python load path %s after bootstrap.", self.worker_id, sys.path)
@actor.no_retry("worker") # pylint: disable=no-value-for-parameter
def receiveMsg_StartWorker(self, msg, sender):
assert self.config is not None
self.logger.info("Worker[%d] is about to start.", msg.worker_id)
self.on_error = self.config.opts("driver", "on.error")
self.sample_queue_size = int(self.config.opts("reporting", "sample.queue.size", mandatory=False, default_value=1 << 20))
self.track = msg.track
track.set_absolute_data_path(self.config, self.track)
self.client_allocations = msg.client_allocations
self.client_contexts = msg.client_contexts
self.current_task_index = 0
self.cancel.clear()
# we need to wake up more often in test mode
if self.config.opts("track", "test.mode.enabled"):
self.wakeup_interval = 0.5
runner.register_default_runners(self.config)
if self.track.has_plugins:
track.load_track_plugins(self.config, self.track.name, runner.register_runner, scheduler.register_scheduler)
self.drive()
@actor.no_retry("worker") # pylint: disable=no-value-for-parameter
def receiveMsg_Drive(self, msg, sender):
sleep_time = datetime.timedelta(seconds=msg.client_start_timestamp - time.perf_counter())
self.logger.debug(
"Worker[%d] is continuing its work at task index [%d] on [%f], that is in [%s].",
self.worker_id,
self.current_task_index,
msg.client_start_timestamp,
sleep_time,
)
self.start_driving = True
self.wakeupAfter(sleep_time)
@actor.no_retry("worker") # pylint: disable=no-value-for-parameter
def receiveMsg_CompleteCurrentTask(self, msg, sender):
        # finish as soon as possible. Remaining samples will be sent with the next WakeupMessage. We will also need to skip to the next
        # JoinPoint. But if we are already at a JoinPoint, there is nothing to do.
if self.at_joinpoint():
self.logger.info(
"Worker[%s] has received CompleteCurrentTask but is currently at join point at index [%d]. Ignoring.",
str(self.worker_id),
self.current_task_index,
)
else:
self.logger.info(
"Worker[%s] has received CompleteCurrentTask. Completing tasks at index [%d].", str(self.worker_id), self.current_task_index
)
self.complete.set()
@actor.no_retry("worker") # pylint: disable=no-value-for-parameter
def receiveMsg_WakeupMessage(self, msg, sender):
        # it would be better if we could send ourselves a message at a specific time; we simulate this with a boolean flag instead
if self.start_driving:
self.start_driving = False
self.drive()
else:
current_samples = self.send_samples()
if self.cancel.is_set():
self.logger.info("Worker[%s] has detected that benchmark has been cancelled. Notifying master...", str(self.worker_id))
self.send(self.driver_actor, actor.BenchmarkCancelled())
elif self.executor_future is not None and self.executor_future.done():
e = self.executor_future.exception(timeout=0)
if e:
self.logger.exception(
"Worker[%s] has detected a benchmark failure. Notifying master...", str(self.worker_id), exc_info=e
)
# the exception might be user-defined and not be on the load path of the master driver. Hence, it cannot be
# deserialized on the receiver so we convert it here to a plain string.
self.send(self.driver_actor, actor.BenchmarkFailure(f"Error in load generator [{self.worker_id}]", str(e)))
else:
self.logger.debug("Worker[%s] is ready for the next task.", str(self.worker_id))
self.executor_future = None
self.drive()
else:
if current_samples and len(current_samples) > 0:
most_recent_sample = current_samples[-1]
if most_recent_sample.percent_completed is not None:
self.logger.debug(
"Worker[%s] is executing [%s] (%.2f%% complete).",
str(self.worker_id),
most_recent_sample.task,
most_recent_sample.percent_completed * 100.0,
)
else:
# TODO: This could be misleading given that one worker could execute more than one task...
self.logger.debug(
"Worker[%s] is executing [%s] (dependent eternal task).", str(self.worker_id), most_recent_sample.task
)
else:
self.logger.debug("Worker[%s] is executing (no samples).", str(self.worker_id))
self.wakeupAfter(datetime.timedelta(seconds=self.wakeup_interval))
def receiveMsg_ActorExitRequest(self, msg, sender):
self.logger.debug("Worker[%s] has received ActorExitRequest.", str(self.worker_id))
if self.executor_future is not None and self.executor_future.running():
self.cancel.set()
self.pool.shutdown()
self.logger.debug("Worker[%s] is exiting due to ActorExitRequest.", str(self.worker_id))
def receiveMsg_BenchmarkFailure(self, msg, sender):
# sent by our no_retry infrastructure; forward to master
self.send(self.driver_actor, msg)
def receiveUnrecognizedMessage(self, msg, sender):
self.logger.debug("Worker[%d] received unknown message [%s] (ignoring).", self.worker_id, str(msg))
def drive(self):
assert self.config is not None
task_allocations = self.current_tasks_and_advance()
# skip non-tasks in the task list
while len(task_allocations) == 0:
task_allocations = self.current_tasks_and_advance()
if self.at_joinpoint():
self.logger.debug("Worker[%d] reached join point at index [%d].", self.worker_id, self.current_task_index)
# clients that don't execute tasks don't need to care about waiting
if self.executor_future is not None:
self.executor_future.result()
self.send_samples()
self.cancel.clear()
self.complete.clear()
self.executor_future = None
self.sampler = None
self.send(self.driver_actor, JoinPointReached(self.worker_id, task_allocations))
else:
            # There may be a situation where there are more (parallel) tasks than workers. If we were asked to complete all tasks, we
            # not only need to complete the actively running tasks but also all scheduled tasks until we reach the next join point.
if self.complete.is_set():
self.logger.info(
"Worker[%d] skips tasks at index [%d] because it has been asked to complete all tasks until next join point.",
self.worker_id,
self.current_task_index,
)
else:
self.logger.debug("Worker[%d] is executing tasks at index [%d].", self.worker_id, self.current_task_index)
self.sampler = Sampler(start_timestamp=time.perf_counter(), buffer_size=self.sample_queue_size)
executor = AsyncIoAdapter(
self.config,
self.track,
task_allocations,
self.sampler,
self.cancel,
self.complete,
self.on_error,
self.client_contexts,
self.worker_id,
)
self.executor_future = self.pool.submit(executor)
self.wakeupAfter(datetime.timedelta(seconds=self.wakeup_interval))
def at_joinpoint(self):
return self.client_allocations.is_joinpoint(self.current_task_index)
def current_tasks_and_advance(self):
self.current_task_index = self.next_task_index
current = self.client_allocations.tasks(self.current_task_index)
self.next_task_index += 1
self.logger.debug("Worker[%d] is at task index [%d].", self.worker_id, self.current_task_index)
return current
def send_samples(self):
if self.sampler:
samples = self.sampler.samples
if len(samples) > 0:
self.send(self.driver_actor, UpdateSamples(self.worker_id, samples))
return samples
return None
class Sampler:
"""
Encapsulates management of gathered samples.
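    Samples are buffered in a bounded queue: if the queue is full, new samples are dropped with a warning instead
    of blocking the load generator.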
"""
def __init__(self, start_timestamp, buffer_size=16384):
self.start_timestamp = start_timestamp
self.q = queue.Queue(maxsize=buffer_size)
self.logger = logging.getLogger(__name__)
def add(
self,
task,
client_id,
sample_type,
meta_data,
absolute_time,
request_start,
latency,
service_time,
processing_time,
throughput,
ops,
ops_unit,
time_period,
percent_completed,
dependent_timing=None,
):
try:
self.q.put_nowait(
Sample(
client_id,
absolute_time,
request_start,
self.start_timestamp,
task,
sample_type,
meta_data,
latency,
service_time,
processing_time,
throughput,
ops,
ops_unit,
time_period,
percent_completed,
dependent_timing,
)
)
except queue.Full:
self.logger.warning("Dropping sample for [%s] due to a full sampling queue.", task.operation.name)
@property
def samples(self):
samples = []
try:
while True:
samples.append(self.q.get_nowait())
except queue.Empty:
pass
return samples
class Sample:
def __init__(
self,
client_id,
absolute_time,
request_start,
task_start,
task,
sample_type,
request_meta_data,
latency,
service_time,
processing_time,
throughput,
total_ops,
total_ops_unit,
time_period,
percent_completed,
dependent_timing=None,
operation_name=None,
operation_type=None,
):
self.client_id = client_id
self.absolute_time = absolute_time
self.request_start = request_start
self.task_start = task_start
self.task = task
self.sample_type = sample_type
self.request_meta_data = request_meta_data
self.latency = latency
self.service_time = service_time
self.processing_time = processing_time
self.throughput = throughput
self.total_ops = total_ops
self.total_ops_unit = total_ops_unit
self.time_period = time_period
self._dependent_timing = dependent_timing
self._operation_name = operation_name
self._operation_type = operation_type
# may be None for eternal tasks!
self.percent_completed = percent_completed
@property
def operation_name(self):
return self._operation_name if self._operation_name else self.task.operation.name
@property
def operation_type(self):
return self._operation_type if self._operation_type else self.task.operation.type
@property
def operation_meta_data(self):
return self.task.operation.meta_data
@property
def relative_time(self):
return self.request_start - self.task_start
@property
def dependent_timings(self):
if self._dependent_timing:
for t in self._dependent_timing:
timing = t.pop("dependent_timing")
meta_data = self._merge(self.request_meta_data, t)
yield Sample(
self.client_id,
timing["absolute_time"],
timing["request_start"],
self.task_start,
self.task,
self.sample_type,
meta_data,
0,
timing["service_time"],
0,
0,
self.total_ops,
self.total_ops_unit,
self.time_period,
self.percent_completed,
None,
timing["operation"],
timing["operation-type"],
)
def __repr__(self, *args, **kwargs):
return (
f"[{self.absolute_time}; {self.relative_time}] [client [{self.client_id}]] [{self.task}] "
f"[{self.sample_type}]: [{self.latency}s] request latency, [{self.service_time}s] service time, "
f"[{self.total_ops} {self.total_ops_unit}]"
)
def _merge(self, *args):
result = {}
for arg in args:
if arg is not None:
result.update(arg)
return result
def select_challenge(config: types.Config, t):
challenge_name = config.opts("track", "challenge.name")
selected_challenge = t.find_challenge_or_default(challenge_name)
if not selected_challenge:
raise exceptions.SystemSetupError(
"Unknown challenge [%s] for track [%s]. You can list the available tracks and their "
"challenges with %s list tracks." % (challenge_name, t.name, PROGRAM_NAME)
)
return selected_challenge
class ThroughputCalculator:
class TaskStats:
"""
Stores per task numbers needed for throughput calculation in between multiple calculations.
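        A worked example with ``bucket_interval=1``: the first throughput value can be emitted once at least one
        second worth of samples has been observed (``interval >= bucket``); ``finish_bucket`` then advances the
        bucket boundary to ``int(interval) + bucket_interval`` so that roughly one value per second is produced.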
"""
def __init__(self, bucket_interval, sample_type, start_time):
self.unprocessed = []
self.total_count = 0
self.interval = 0
self.bucket_interval = bucket_interval
# the first bucket is complete after one bucket interval is over
self.bucket = bucket_interval
self.sample_type = sample_type
self.has_samples_in_sample_type = False
# start relative to the beginning of our (calculation) time slice.
self.start_time = start_time
@property
def throughput(self):
return self.total_count / self.interval
def maybe_update_sample_type(self, current_sample_type):
if self.sample_type < current_sample_type:
self.sample_type = current_sample_type
self.has_samples_in_sample_type = False
def update_interval(self, absolute_sample_time):
self.interval = max(absolute_sample_time - self.start_time, self.interval)
def can_calculate_throughput(self):
return self.interval > 0 and self.interval >= self.bucket
def can_add_final_throughput_sample(self):
return self.interval > 0 and not self.has_samples_in_sample_type
def finish_bucket(self, new_total):
self.unprocessed = []
self.total_count = new_total
self.has_samples_in_sample_type = True
self.bucket = int(self.interval) + self.bucket_interval
def __init__(self):
self.task_stats = {}
def calculate(self, samples, bucket_interval_secs=1):
"""
Calculates global throughput based on samples gathered from multiple load generators.
:param samples: A list containing all samples from all load generators.
:param bucket_interval_secs: The bucket interval for aggregations.
:return: A global view of throughput samples.
"""
samples_per_task = {}
# first we group all samples by task (operation).
for sample in samples:
k = sample.task
if k not in samples_per_task:
samples_per_task[k] = []
samples_per_task[k].append(sample)
global_throughput = {}
# with open("raw_samples_new.csv", "a") as sample_log:
# print("client_id,absolute_time,relative_time,operation,sample_type,total_ops,time_period", file=sample_log)
for k, v in samples_per_task.items():
task = k
if task not in global_throughput:
global_throughput[task] = []
# sort all samples by time
if task in self.task_stats:
samples = itertools.chain(v, self.task_stats[task].unprocessed)
else:
samples = v
current_samples = sorted(samples, key=lambda s: s.absolute_time)
# If the runner does not provide a throughput value, calculate it based on service time; otherwise use the provided value
# as is and only transform it into the expected structure.
first_sample = current_samples[0]
if first_sample.throughput is None:
task_throughput = self.calculate_task_throughput(task, current_samples, bucket_interval_secs)
else:
task_throughput = self.map_task_throughput(current_samples)
global_throughput[task].extend(task_throughput)
return global_throughput
def calculate_task_throughput(self, task, current_samples, bucket_interval_secs):
task_throughput = []
if task not in self.task_stats:
first_sample = current_samples[0]
self.task_stats[task] = ThroughputCalculator.TaskStats(
bucket_interval=bucket_interval_secs,
sample_type=first_sample.sample_type,
start_time=first_sample.absolute_time - first_sample.time_period,
)
current = self.task_stats[task]
count = current.total_count
last_sample = None
for sample in current_samples:
last_sample = sample
# print("%d,%f,%f,%s,%s,%d,%f" %
# (sample.client_id, sample.absolute_time, sample.relative_time, sample.operation, sample.sample_type,
# sample.total_ops, sample.time_period), file=sample_log)
# once we have seen a new sample type, we stick to it.
current.maybe_update_sample_type(sample.sample_type)
# we need to store the total count separately and cannot update `current.total_count` immediately here
# because we would count all raw samples in `unprocessed` twice. Hence, we'll only update
# `current.total_count` when we have calculated a new throughput sample.
count += sample.total_ops
current.update_interval(sample.absolute_time)
if current.can_calculate_throughput():
current.finish_bucket(count)
task_throughput.append(
(
sample.absolute_time,
sample.relative_time,
current.sample_type,
current.throughput,
# we calculate throughput per second
f"{sample.total_ops_unit}/s",
)
)
else:
current.unprocessed.append(sample)
# also include the last sample if we don't have one for the current sample type, even if it is below the bucket
# interval (mainly needed to ensure we show throughput data in test mode)
if last_sample is not None and current.can_add_final_throughput_sample():
current.finish_bucket(count)
task_throughput.append(
(
last_sample.absolute_time,
last_sample.relative_time,
current.sample_type,
current.throughput,
f"{last_sample.total_ops_unit}/s",
)
)
return task_throughput
def map_task_throughput(self, current_samples):
throughput = []
for sample in current_samples:
throughput.append(
(
sample.absolute_time,
sample.relative_time,
sample.sample_type,
sample.throughput,
f"{sample.total_ops_unit}/s",
)
)
return throughput
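# Illustrative sketch of the bucket-based calculation above (assumed numbers): with bucket_interval_secs=1,
# raw samples for a task are buffered in TaskStats.unprocessed until the observed interval reaches the
# current bucket boundary. If two samples report total_ops of 300 and 700 within the first 1.2 seconds of
# a task, the first emitted throughput sample is (300 + 700) / 1.2 ≈ 833.3 ops/s and the next bucket
# boundary moves to int(1.2) + 1 = 2 seconds.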
class AsyncIoAdapter:
def __init__(self, cfg: types.Config, track, task_allocations, sampler, cancel, complete, abort_on_error, client_contexts, worker_id):
self.cfg = cfg
self.track = track
self.task_allocations = task_allocations
self.sampler = sampler
self.cancel = cancel
self.complete = complete
self.abort_on_error = abort_on_error
self.client_contexts = client_contexts
self.parent_worker_id = worker_id
self.profiling_enabled = self.cfg.opts("driver", "profiling")
self.assertions_enabled = self.cfg.opts("driver", "assertions")
self.debug_event_loop = self.cfg.opts("system", "async.debug", mandatory=False, default_value=False)
self.logger = logging.getLogger(__name__)
def __call__(self, *args, **kwargs):
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
loop.set_debug(self.debug_event_loop)
loop.set_exception_handler(self._logging_exception_handler)
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(self.run())
finally:
loop.close()
def _logging_exception_handler(self, loop, context):
self.logger.error("Uncaught exception in event loop: %s", context)
async def run(self):
def es_clients(client_id, all_hosts, all_client_options, distribution_version, distribution_flavor):
es = {}
context = self.client_contexts.get(client_id)
api_key = context.api_key
for cluster_name, cluster_hosts in all_hosts.items():
es[cluster_name] = client.EsClientFactory(
cluster_hosts,
all_client_options[cluster_name],
distribution_version=distribution_version,
distribution_flavor=distribution_flavor,
).create_async(api_key=api_key, client_id=client_id)
return es
if self.assertions_enabled:
self.logger.info("Task assertions enabled")
runner.enable_assertions(self.assertions_enabled)
clients = []
awaitables = []
# A parameter source should only be created once per task - it is partitioned later on per client.
params_per_task = {}
for client_id, task_allocation in self.task_allocations:
task = task_allocation.task
if task not in params_per_task:
param_source = track.operation_parameters(self.track, task)
params_per_task[task] = param_source
schedule = schedule_for(task_allocation, params_per_task[task])
es = es_clients(
client_id,
self.cfg.opts("client", "hosts").all_hosts,
self.cfg.opts("client", "options"),
self.cfg.opts("mechanic", "distribution.version", mandatory=False),
self.cfg.opts("mechanic", "distribution.flavor", mandatory=False),
)
clients.append(es)
async_executor = AsyncExecutor(
client_id, task, schedule, es, self.sampler, self.cancel, self.complete, task.error_behavior(self.abort_on_error)
)
final_executor = AsyncProfiler(async_executor) if self.profiling_enabled else async_executor
awaitables.append(final_executor())
task_names = [t.task.task.name for t in self.task_allocations]
self.logger.info("Worker[%s] executing tasks: %s", self.parent_worker_id, task_names)
run_start = time.perf_counter()
try:
_ = await asyncio.gather(*awaitables)
finally:
run_end = time.perf_counter()
self.logger.info(
"Worker[%s] finished executing tasks %s in %f seconds", self.parent_worker_id, task_names, (run_end - run_start)
)
await asyncio.get_event_loop().shutdown_asyncgens()
shutdown_asyncgens_end = time.perf_counter()
self.logger.debug("Total time to shutdown asyncgens: %f seconds.", (shutdown_asyncgens_end - run_end))
for c in clients:
for es in c.values():
await es.close()
transport_close_end = time.perf_counter()
self.logger.debug("Total time to close transports: %f seconds.", (transport_close_end - shutdown_asyncgens_end))
class AsyncProfiler:
def __init__(self, target):
"""
:param target: The actual executor which should be profiled.
"""
self.target = target
self.profile_logger = logging.getLogger("rally.profile")
async def __call__(self, *args, **kwargs):
# initialize lazily; we don't need it in the majority of cases
# pylint: disable=import-outside-toplevel
import io as python_io
import yappi
yappi.start()
try:
return await self.target(*args, **kwargs)
finally:
yappi.stop()
s = python_io.StringIO()
yappi.get_func_stats().print_all(
out=s, columns={0: ("name", 140), 1: ("ncall", 8), 2: ("tsub", 8), 3: ("ttot", 8), 4: ("tavg", 8)}
)
profile = "\n=== Profile START ===\n"
profile += s.getvalue()
profile += "=== Profile END ==="
self.profile_logger.info(profile)
class AsyncExecutor:
def __init__(self, client_id, task, schedule, es, sampler, cancel, complete, on_error):
"""
Executes tasks according to the schedule for a given operation.
:param client_id: The id of the client executing this schedule.
:param task: The task that is executed.
:param schedule: The schedule for this task.
:param es: A dict of Elasticsearch clients, keyed by cluster name, that will be used to execute the operation.
:param sampler: A container to store raw samples.
:param cancel: A shared boolean that indicates we need to cancel execution.
:param complete: A shared boolean that indicates we need to prematurely complete execution.
:param on_error: A string specifying how the load generator should behave on errors.
"""
self.client_id = client_id
self.task = task
self.op = task.operation
self.schedule_handle = schedule
self.es = es
self.sampler = sampler
self.cancel = cancel
self.complete = complete
self.on_error = on_error
self.logger = logging.getLogger(__name__)
async def __call__(self, *args, **kwargs):
any_task_completes_parent = self.task.any_completes_parent
task_completes_parent = self.task.completes_parent
total_start = time.perf_counter()
# lazily initialize the schedule
self.logger.debug("Initializing schedule for client id [%s].", self.client_id)
schedule = self.schedule_handle()
# Start the schedule's timer early so the warmup period is independent of any deferred start due to ramp-up
self.schedule_handle.start()
rampup_wait_time = self.schedule_handle.ramp_up_wait_time
if rampup_wait_time:
self.logger.debug("client id [%s] waiting [%.2f]s for ramp-up.", self.client_id, rampup_wait_time)
await asyncio.sleep(rampup_wait_time)
self.logger.debug("Entering main loop for client id [%s].", self.client_id)
# noinspection PyBroadException
try:
async for expected_scheduled_time, sample_type, percent_completed, runner, params in schedule:
if self.cancel.is_set():
self.logger.info("User cancelled execution.")
break
absolute_expected_schedule_time = total_start + expected_scheduled_time
throughput_throttled = expected_scheduled_time > 0
if throughput_throttled:
rest = absolute_expected_schedule_time - time.perf_counter()
if rest > 0:
await asyncio.sleep(rest)
absolute_processing_start = time.time()
processing_start = time.perf_counter()
self.schedule_handle.before_request(processing_start)
with self.es["default"].new_request_context() as request_context:
total_ops, total_ops_unit, request_meta_data = await execute_single(runner, self.es, params, self.on_error)
request_start = request_context.request_start
request_end = request_context.request_end
processing_end = time.perf_counter()
service_time = request_end - request_start
processing_time = processing_end - processing_start
time_period = request_end - total_start
self.schedule_handle.after_request(processing_end, total_ops, total_ops_unit, request_meta_data)
# Allow runners to override the throughput calculation in very specific circumstances. Usually, Rally
# assumes that throughput is the "amount of work" (determined by the "weight") per unit of time
# (determined by the elapsed time period). However, in certain cases (e.g. shard recovery or other
# long running operations where there is a dedicated stats API to determine progress), it is
# advantageous if the runner calculates throughput directly. The following restrictions apply:
#
# * Only one client must call that runner (when throughput is calculated, it is aggregated across
# all clients but if the runner provides it, we take the value as is).
# * The runner should be rate-limited as each runner call will result in one throughput sample.
#
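# Illustrative sketch (hypothetical runner return value, not taken from this module):
#
#     return {"weight": 1, "unit": "ops", "throughput": 1250.0}
#
# In that case the value popped below is reported as-is instead of being derived from service
# time by the ThroughputCalculator.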
throughput = request_meta_data.pop("throughput", None)
# Do not calculate latency separately when we run unthrottled. This metric is just confusing then.
latency = request_end - absolute_expected_schedule_time if throughput_throttled else service_time
# If this task completes the parent task we should *not* check for completion by another client but
# instead continue until our own runner has completed. We need to do this because the current
# worker (process) could run multiple clients that execute the same task. We do not want all clients to
# finish this task as soon as the first of these clients has finished but rather continue until the last
# client has finished that task.
if task_completes_parent:
completed = runner.completed
else:
completed = self.complete.is_set() or runner.completed
# last sample should bump progress to 100% if externally completed.
if completed:
progress = 1.0
elif runner.percent_completed:
progress = runner.percent_completed
else:
progress = percent_completed
self.sampler.add(
self.task,
self.client_id,
sample_type,
request_meta_data,
absolute_processing_start,
request_start,
latency,
service_time,
processing_time,
throughput,
total_ops,
total_ops_unit,
time_period,
progress,
request_meta_data.pop("dependent_timing", None),
)
if completed:
self.logger.info("Task [%s] is considered completed due to external event.", self.task)
break
except BaseException as e:
self.logger.exception("Could not execute schedule")
raise exceptions.RallyError(f"Cannot run task [{self.task}]: {e}") from None
finally:
# Actively set it if this task completes its parent
if task_completes_parent:
self.logger.info(
"Task [%s] completes parent. Client id [%s] is finished executing it and signals completion.",
self.task,
self.client_id,
)
self.complete.set()
elif any_task_completes_parent:
self.logger.info(
"Task [%s] completes parent. Client id [%s] is finished executing it and signals completion of all "
"remaining clients, immediately.",
self.task,
self.client_id,
)
self.complete.set()
async def execute_single(runner, es, params, on_error):
"""
Invokes the given runner once and provides the runner's return value in a uniform structure.
:return: a triple of: total number of operations, unit of operations, a dict of request meta data (may be None).
"""
# pylint: disable=import-outside-toplevel
import elasticsearch
fatal_error = False
try:
async with runner:
return_value = await runner(es, params)
if isinstance(return_value, tuple) and len(return_value) == 2:
total_ops, total_ops_unit = return_value
request_meta_data = {"success": True}
elif isinstance(return_value, dict):
total_ops = return_value.pop("weight", 1)
total_ops_unit = return_value.pop("unit", "ops")
request_meta_data = return_value
if "success" not in request_meta_data:
request_meta_data["success"] = True
else:
total_ops = 1
total_ops_unit = "ops"
request_meta_data = {"success": True}
except elasticsearch.TransportError as e:
# we *specifically* want to distinguish connection refused (a node died?) from connection timeouts
# pylint: disable=unidiomatic-typecheck
if type(e) is elasticsearch.ConnectionError:
fatal_error = True
total_ops = 0
total_ops_unit = "ops"
request_meta_data = {"success": False, "error-type": "transport"}
# For the 'errors' attribute, errors are ordered from
# most recently raised (index=0) to least recently raised (index=N)
#
# If an HTTP status code is available with the error it
# will be stored under 'status'. If HTTP headers are available
# they are stored under 'headers'.
if e.errors:
if hasattr(e.errors[0], "status"):
request_meta_data["http-status"] = e.errors[0].status
# connection timeout errors don't provide a helpful description
if isinstance(e, elasticsearch.ConnectionTimeout):
request_meta_data["error-description"] = "network connection timed out"
else:
error_description = e.message
request_meta_data["error-description"] = error_description
except elasticsearch.ApiError as e:
total_ops = 0
total_ops_unit = "ops"
request_meta_data = {"success": False, "error-type": "api"}
error_message = ""
# Some runners return a raw response, causing the 'error' property to be a string literal of the bytes/BytesIO object;
# we should avoid bubbling that up
# e.g. ApiError(413, '<_io.BytesIO object at 0xffffaf146a70>')
if isinstance(e.body, bytes):
# could be an empty body
if error_body := e.body.decode("utf-8"):
error_message = error_body
else:
# to be consistent with an empty 'e.error'
error_message = str(None)
elif isinstance(e.body, BytesIO):
# could be an empty body
if error_body := e.body.read().decode("utf-8"):
error_message = error_body
else:
# to be consistent with an empty 'e.error'
error_message = str(None)
# fallback to 'error' property if the body isn't bytes/BytesIO
else:
if isinstance(e.error, bytes):
error_message = e.error.decode("utf-8")
elif isinstance(e.error, BytesIO):
error_message = e.error.read().decode("utf-8")
else:
# if the 'error' is empty, we get back str(None)
error_message = e.error
if isinstance(e.info, bytes):
error_info = e.info.decode("utf-8")
elif isinstance(e.info, BytesIO):
error_info = e.info.read().decode("utf-8")
else:
error_info = e.info
if error_info:
error_message += f" ({error_info})"
error_description = error_message
request_meta_data["error-description"] = error_description
if e.status_code:
request_meta_data["http-status"] = e.status_code
except KeyError as e:
logging.getLogger(__name__).exception("Cannot execute runner [%s]; most likely due to missing parameters.", str(runner))
msg = "Cannot execute [%s]. Provided parameters are: %s. Error: [%s]." % (str(runner), list(params.keys()), str(e))
raise exceptions.SystemSetupError(msg)
if not request_meta_data["success"]:
if on_error == "abort" or fatal_error:
msg = "Request returned an error. Error type: %s" % request_meta_data.get("error-type", "Unknown")
if description := request_meta_data.get("error-description"):
msg += f", Description: {description}"
if http_status := request_meta_data.get("http-status"):
msg += f", HTTP Status: {http_status}"
raise exceptions.RallyAssertionError(msg)
return total_ops, total_ops_unit, request_meta_data
class JoinPoint:
def __init__(self, id, clients_executing_completing_task=None, any_task_completes_parent=None):
"""
:param id: The join point's id.
:param clients_executing_completing_task: An array of client indices which execute a task that can prematurely complete its parent
element. Provide 'None' or an empty array if no task satisfies this predicate.
:param any_task_completes_parent: An array of client indices which execute a task whose completion by any single client prematurely
completes the parent element. Provide 'None' or an empty array if no task satisfies this predicate.
"""
if clients_executing_completing_task is None:
clients_executing_completing_task = []
if any_task_completes_parent is None:
any_task_completes_parent = []
self.id = id
self.any_task_completes_parent = any_task_completes_parent
self.clients_executing_completing_task = clients_executing_completing_task
self.num_clients_executing_completing_task = len(clients_executing_completing_task)
self.preceding_task_completes_parent = self.num_clients_executing_completing_task > 0
def __hash__(self):
return hash(self.id)
def __eq__(self, other):
return isinstance(other, type(self)) and self.id == other.id
def __repr__(self, *args, **kwargs):
return "JoinPoint(%s)" % self.id
class TaskAllocation:
def __init__(self, task, client_index_in_task, global_client_index, total_clients):
"""
:param task: The current task which is always a leaf task.
:param client_index_in_task: The task-specific index for the allocated client.
:param global_client_index: The globally unique index for the allocated client across
all concurrently executed tasks.
:param total_clients: The total number of clients executing tasks concurrently.
"""
self.task = task
self.client_index_in_task = client_index_in_task
self.global_client_index = global_client_index
self.total_clients = total_clients
def __hash__(self):
return hash(self.task) ^ hash(self.global_client_index)
def __eq__(self, other):
return isinstance(other, type(self)) and self.task == other.task and self.global_client_index == other.global_client_index
def __repr__(self, *args, **kwargs):
return (
f"TaskAllocation [{self.client_index_in_task}/{self.task.clients}] for {self.task} "
f"and [{self.global_client_index}/{self.total_clients}] in total"
)
class Allocator:
"""
Decides which operations run on which client and how to partition them.
"""
def __init__(self, schedule):
self.schedule = schedule
@property
def allocations(self):
"""
Calculates an allocation matrix consisting of two dimensions. The first dimension is the client. The second dimension is the list of
tasks this client needs to run. The matrix shape is rectangular (i.e. it is not ragged). There are three types of entries in the matrix:
1. Normal tasks: They need to be executed by a client.
2. Join points: They are used as global coordination points which all clients need to reach before the benchmark can go on. They
indicate that a client has to wait until the master signals it can go on.
3. `None`: These are inserted by the allocator to keep the allocation matrix rectangular. Clients have to skip `None` entries
until one of the other entry types is encountered.
:return: An allocation matrix with the structure described above.
"""
max_clients = self.clients
allocations = [None] * max_clients
for client_index in range(max_clients):
allocations[client_index] = []
join_point_id = 0
# start with an artificial join point to allow master to coordinate that all clients start at the same time
next_join_point = JoinPoint(join_point_id)
for client_index in range(max_clients):
allocations[client_index].append(next_join_point)
join_point_id += 1
for task in self.schedule:
start_client_index = 0
clients_executing_completing_task = []
any_task_completes_parent = []
for sub_task in task:
for client_index in range(start_client_index, start_client_index + sub_task.clients):
# this is the actual client that will execute the task. It may differ from the logical one in case we over-commit (i.e.
# more tasks than actually available clients)
physical_client_index = client_index % max_clients
if sub_task.completes_parent:
clients_executing_completing_task.append(physical_client_index)
elif sub_task.any_completes_parent:
any_task_completes_parent.append(physical_client_index)
ta = TaskAllocation(
task=sub_task,
client_index_in_task=client_index - start_client_index,
global_client_index=client_index,
# if task represents a parallel structure this is the total number of clients
# executing sub-tasks concurrently.
total_clients=task.clients,
)
allocations[physical_client_index].append(ta)
start_client_index += sub_task.clients
# uneven distribution between tasks and clients, e.g. there are 5 (parallel) tasks but only 2 clients. Then, one of them
# executes three tasks, the other one only two. So we need to fill in a `None` for the second one.
if start_client_index % max_clients > 0:
# pin the index range to [0, max_clients). This simplifies the code below.
start_client_index = start_client_index % max_clients
for client_index in range(start_client_index, max_clients):
allocations[client_index].append(None)
# let all clients join after each task, then we go on
next_join_point = JoinPoint(join_point_id, clients_executing_completing_task, any_task_completes_parent)
for client_index in range(max_clients):
allocations[client_index].append(next_join_point)
join_point_id += 1
return allocations
@property
def join_points(self):
"""
:return: A list of all join points for these allocations.
"""
return [allocation for allocation in self.allocations[0] if isinstance(allocation, JoinPoint)]
@property
def tasks_per_joinpoint(self):
"""
Calculates a flat list of all tasks that are run in between join points.
Consider the following schedule (2 clients):
1. task1 and task2 run by both clients in parallel
2. join point
3. task3 run by client 1
4. join point
This results in: [{task1, task2}, {task3}]
:return: A list of sets containing all tasks.
"""
tasks = []
current_tasks = set()
allocs = self.allocations
# assumption: the shape of allocs is rectangular (i.e. each client contains the same number of elements)
for idx in range(0, len(allocs[0])):
for client in range(0, self.clients):
allocation = allocs[client][idx]
if isinstance(allocation, TaskAllocation):
current_tasks.add(allocation.task)
elif isinstance(allocation, JoinPoint) and len(current_tasks) > 0:
tasks.append(current_tasks)
current_tasks = set()
return tasks
@property
def clients(self):
"""
:return: The maximum number of clients involved in executing the given schedule.
"""
max_clients = 1
for task in self.schedule:
max_clients = max(max_clients, task.clients)
return max_clients
#######################################
#
# Scheduler related stuff
#
#######################################
# Runs a concrete schedule on one worker client
# Needs to determine the runners and concrete iterations per client.
def schedule_for(task_allocation, parameter_source):
"""
Calculates a client's schedule for a given task.
:param task_allocation: The task allocation that should be executed by this schedule.
:param parameter_source: The parameter source that should be used for this task.
:return: A generator for the operations the given client needs to perform for this task.
"""
logger = logging.getLogger(__name__)
task = task_allocation.task
op = task.operation
sched = scheduler.scheduler_for(task)
# We cannot use the global client index here because we need to support parallel execution of tasks
# with multiple clients. Consider the following scenario:
#
# * Clients 0-3 bulk index into indexA
# * Clients 4-7 bulk index into indexB
#
# Now we need to ensure that we start partitioning parameters correctly in both cases. And that means we
# need to start from (client) index 0 in both cases instead of 0 for indexA and 4 for indexB.
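# For example, in the scenario above the client with global index 6 (bulk indexing into indexB) has
# client_index_in_task 2, so the parameter sources for indexA and indexB are both partitioned over
# indices 0-3.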
client_index = task_allocation.client_index_in_task
# guard all logging statements with the client index and only emit them for the first client. This information is
# repetitive and may cause issues in thespian with many clients (an excessive number of actor messages is sent).
if client_index == 0:
logger.debug("Choosing [%s] for [%s].", sched, task)
runner_for_op = runner.runner_for(op.type)
params_for_op = parameter_source.partition(client_index, task.clients)
if hasattr(sched, "parameter_source"):
if client_index == 0:
logger.debug("Setting parameter source [%s] for scheduler [%s]", params_for_op, sched)
sched.parameter_source = params_for_op
if requires_time_period_schedule(task, runner_for_op, params_for_op):
warmup_time_period = task.warmup_time_period if task.warmup_time_period else 0
if client_index == 0:
logger.info(
"Creating time-period based schedule with [%s] distribution for [%s] with a warmup period of [%s] "
"seconds and a time period of [%s] seconds.",
task.schedule,
task.name,
str(warmup_time_period),
str(task.time_period),
)
loop_control = TimePeriodBased(warmup_time_period, task.time_period)
else:
warmup_iterations = task.warmup_iterations if task.warmup_iterations else 0
if task.iterations:
iterations = task.iterations
elif params_for_op.infinite:
# this is usually the case if the parameter source provides a constant
iterations = 1
else:
iterations = None
if client_index == 0:
logger.debug(
"Creating iteration-count based schedule with [%s] distribution for [%s] with [%s] warmup "
"iterations and [%s] iterations.",
task.schedule,
task.name,
str(warmup_iterations),
str(iterations),
)
loop_control = IterationBased(warmup_iterations, iterations)
if client_index == 0:
if loop_control.infinite:
logger.debug("Parameter source will determine when the schedule for [%s] terminates.", task.name)
else:
logger.debug("%s schedule will determine when the schedule for [%s] terminates.", str(loop_control), task.name)
return ScheduleHandle(task_allocation, sched, loop_control, runner_for_op, params_for_op)
def requires_time_period_schedule(task, task_runner, params):
if task.warmup_time_period is not None or task.time_period is not None:
return True
# user has explicitly requested iterations
if task.warmup_iterations is not None or task.iterations is not None:
return False
# the runner determines completion
if task_runner.completed is not None:
return True
# If the parameter source ends after a finite amount of iterations, we will run with a time-based schedule
return not params.infinite
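# Illustrative summary of the decision above (assuming only the mentioned properties are set):
#
#     task.time_period = 600                          -> time-period based
#     task.iterations = 1000                          -> iteration based
#     neither set, runner reports its own completion  -> time-period based
#     neither set, finite parameter source            -> time-period based (runs until the source is exhausted)
#     neither set, infinite parameter source          -> iteration based (a single iteration, see schedule_for)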
class ScheduleHandle:
def __init__(self, task_allocation, sched, task_progress_control, runner, params):
"""
Creates a schedule handle that, when called, yields individual task invocations for the provided schedule.
:param task_allocation: The task allocation for which the schedule is generated.
:param sched: The scheduler for this task.
:param task_progress_control: Controls how and how often this generator will loop.
:param runner: The runner for a given operation.
:param params: The parameter source for a given operation.
:return: A generator for the corresponding parameters.
"""
self.task_allocation = task_allocation
self.operation_type = task_allocation.task.operation.type
self.sched = sched
self.task_progress_control = task_progress_control
self.runner = runner
self.params = params
# TODO: Can we offload the parameter source execution to a different thread / process? Is this too heavy-weight?
# from concurrent.futures import ThreadPoolExecutor
# import asyncio
# self.io_pool_exc = ThreadPoolExecutor(max_workers=1)
# self.loop = asyncio.get_event_loop()
@property
def ramp_up_wait_time(self):
"""
:return: the number of seconds to wait until this client should start so load can gradually ramp up.
"""
ramp_up_time_period = self.task_allocation.task.ramp_up_time_period
if ramp_up_time_period:
return ramp_up_time_period * (self.task_allocation.global_client_index / self.task_allocation.total_clients)
else:
return 0
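# Worked example: with ramp_up_time_period=60 and total_clients=8, the client with global_client_index 0
# starts immediately, index 4 waits 60 * (4 / 8) = 30 seconds, and index 7 waits 60 * (7 / 8) = 52.5 seconds.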
def start(self):
self.task_progress_control.start()
def before_request(self, now):
self.sched.before_request(now)
def after_request(self, now, weight, unit, request_meta_data):
self.sched.after_request(now, weight, unit, request_meta_data)
def params_with_operation_type(self):
p = self.params.params()
p.update({"operation-type": self.operation_type})
return p
async def __call__(self):
next_scheduled = 0
if self.task_progress_control.infinite:
param_source_knows_progress = hasattr(self.params, "percent_completed")
while True:
try:
next_scheduled = self.sched.next(next_scheduled)
# an infinite schedule does not contribute to completion; progress is only known if the parameter source exposes it.
percent_completed = self.params.percent_completed if param_source_knows_progress else None
# current_params = await self.loop.run_in_executor(self.io_pool_exc, self.params.params)
yield (
next_scheduled,
self.task_progress_control.sample_type,
percent_completed,
self.runner,
self.params_with_operation_type(),
)
self.task_progress_control.next()
except StopIteration:
return
else:
while not self.task_progress_control.completed:
try:
next_scheduled = self.sched.next(next_scheduled)
# current_params = await self.loop.run_in_executor(self.io_pool_exc, self.params.params)
yield (
next_scheduled,
self.task_progress_control.sample_type,
self.task_progress_control.percent_completed,
self.runner,
self.params_with_operation_type(),
)
self.task_progress_control.next()
except StopIteration:
return
class TimePeriodBased:
def __init__(self, warmup_time_period, time_period):
self._warmup_time_period = warmup_time_period
self._time_period = time_period
if warmup_time_period is not None and time_period is not None:
self._duration = self._warmup_time_period + self._time_period
else:
self._duration = None
self._start = None
self._now = None
def start(self):
self._now = time.perf_counter()
self._start = self._now
@property
def _elapsed(self):
return self._now - self._start
@property
def sample_type(self):
return metrics.SampleType.Warmup if self._elapsed < self._warmup_time_period else metrics.SampleType.Normal
@property
def infinite(self):
return self._time_period is None
@property
def percent_completed(self):
return self._elapsed / self._duration
@property
def completed(self):
return self._now >= (self._start + self._duration)
def next(self):
self._now = time.perf_counter()
def __str__(self):
return "time-period-based"
class IterationBased:
def __init__(self, warmup_iterations, iterations):
self._warmup_iterations = warmup_iterations
self._iterations = iterations
if warmup_iterations is not None and iterations is not None:
self._total_iterations = self._warmup_iterations + self._iterations
if self._total_iterations == 0:
raise exceptions.RallyAssertionError("Operation must run at least for one iteration.")
else:
self._total_iterations = None
self._it = None
def start(self):
self._it = 0
@property
def sample_type(self):
return metrics.SampleType.Warmup if self._it < self._warmup_iterations else metrics.SampleType.Normal
@property
def infinite(self):
return self._iterations is None
@property
def percent_completed(self):
return (self._it + 1) / self._total_iterations
@property
def completed(self):
return self._it >= self._total_iterations
def next(self):
self._it += 1
def __str__(self):
return "iteration-count-based"