perfkitbenchmarker/scripts/throughput_load_driver.py (222 lines of code) (raw):

# Copyright 2024 PerfKitBenchmarker Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Load driver which runs commands in parallel.

Intentionally has no dependencies on PKB so that it can run on a client VM, but
is also imported by PKB to share flags, testing, & convenience functionality /
constants.
"""

import dataclasses
import json
import logging
import math
import multiprocessing
import os
import subprocess
import time

from absl import app
from absl import flags

# Flags whose names start with an underscore are only read by the script when
# it runs on the client VM.
_PARALLEL_REQUESTS = flags.DEFINE_integer(
    name='_ai_throughput_parallel_requests',
    default=5,
    help=(
        'Number of requests to send in parallel at beginning of test. Only used by'
        ' the client VM script.'
    ),
)
_REQUEST_COMMAND = flags.DEFINE_string(
    name='_ai_throughput_command',
    default='',
    help='Command to run for each request. Only used by the client VM script.',
)

# Flags shared between the client VM script and overall PKB.
TEST_DURATION = flags.DEFINE_integer(
    name='ai_test_duration',
    default=60,
    help=(
        'Number of seconds over which requests are sent. Used for both client VM &'
        ' overall PKB.'
    ),
)
BURST_TIME = flags.DEFINE_float(
    name='ai_burst_time',
    default=1.0,
    help=(
        'Number of seconds between each burst of requests. Used for both client VM'
        ' & overall PKB.'
    ),
)
THROW_ON_CLIENT_ERRORS = flags.DEFINE_bool(
    name='ai_throw_on_client_errors',
    default=False,
    help=(
        'Whether to throw an exception if the client is not powerful enough to'
        ' send the desired QPS. Used for both client VM & overall PKB.'
    ),
)

# Sagemaker times out requests if they take longer than 95 seconds.
_FAIL_LATENCY = 95 _QUEUE_WAIT_TIME = _FAIL_LATENCY * 2 @dataclasses.dataclass class CommandResponse: """A response from the command + how long it took.""" start_time: float end_time: float response: str | None = None error: str | None = None status: int = 0 class ClientError(Exception): """An error with the client sending requests.""" def main(argv): """Sends the load with command line flags & writes results to a file.""" del argv start_time = time.time() Run() logging.info( 'Took %s seconds to Run & write responses. throughput_load_driver is' ' done.', time.time() - start_time, ) # Exit even if some processes are still running. os._exit(0) def Run() -> list[CommandResponse]: """Sends the load with command line flags & writes results to a file.""" responses = BurstRequestsOverTime( _REQUEST_COMMAND.value, _PARALLEL_REQUESTS.value, TEST_DURATION.value, BURST_TIME.value, ) file_path = GetOutputFilePath(_PARALLEL_REQUESTS.value) logging.info('Writing %s responses to %s', len(responses), file_path) responses_dicts = [dataclasses.asdict(response) for response in responses] with open(file_path, 'w') as f: json.dump({'responses': responses_dicts}, f) return responses def GetOutputFilePath(burst_requests: int) -> str: """Returns the output file path for the given burst requests.""" return f'/tmp/throughput_results_{burst_requests}.json' def ReadJsonResponses(burst_requests: int) -> list[CommandResponse]: """Reads the json responses from the file.""" with open(GetOutputFilePath(burst_requests), 'r') as f: loaded_json = json.load(f) responses_dicts = loaded_json['responses'] responses = [ CommandResponse(**response_dict) for response_dict in responses_dicts ] return responses def GetOverallTimeout() -> float: """Returns an overall timeout for the throughput operation.""" return TEST_DURATION.value + _QUEUE_WAIT_TIME * 2 def GetExpectedNumberResponses( burst_requests: int, total_duration: int, time_between_bursts: float, ) -> int: """Returns the expected number of responses 
for the given parameters.""" return math.floor(total_duration / time_between_bursts) * burst_requests def BurstRequestsOverTime( command: str, burst_requests: int, total_duration: int, time_between_bursts: float = 1.0, ) -> list[CommandResponse]: """Sends X requests to the model in parallel over total_duration seconds.""" start_time = time.time() goal_bursts = math.floor(total_duration / time_between_bursts) logging.info( 'Starting to send %s requests every %s seconds over %s duration %s times', burst_requests, time_between_bursts, total_duration, goal_bursts, ) output_queue = multiprocessing.Queue() processes = [] for _ in range(goal_bursts): process_start_time = time.time() processes += _SendParallelRequests(command, burst_requests, output_queue) process_startup_duration = time.time() - process_start_time if process_startup_duration > time_between_bursts: elapsed_time = time.time() - start_time _EncounterClientError( f'After running for {elapsed_time} seconds, the client took' f' {process_startup_duration} seconds to send' f' {burst_requests} requests, which is more than the' f' {time_between_bursts} seconds needed to meet QPS. This means the' ' client is not powerful enough & client with more CPUs should be' ' used.' ) # Wait to send next burst. while time.time() - process_start_time < time_between_bursts: time.sleep(0.1) results = _EmptyQueue(output_queue) _WaitForProcesses(processes) results = results + _EmptyQueue(output_queue) logging.info('Dumping all %s response results: %s', len(results), results) if results: logging.info('Logging one full response: %s', results[0]) expected_results = goal_bursts * burst_requests if len(results) < expected_results: logging.info( 'Theoretically started %s results but only got %s responses.' 
' Exact reason is unknown, but this is not entirely unexpected.', expected_results, len(results), ) return results def _EmptyQueue(output_queue: multiprocessing.Queue) -> list[CommandResponse]: """Empties the queue, with a timeout & returns the results.""" logging.info('Waiting for all queued results') results = [] queue_start_time = time.time() queue_duration = 0 while not output_queue.empty(): results.append(output_queue.get()) queue_duration = time.time() - queue_start_time if queue_duration > _QUEUE_WAIT_TIME: _EncounterClientError( f'Waited more than {_QUEUE_WAIT_TIME} seconds for the queue to' ' empty. Exiting, but some data may have been dropped. Collected' f' {len(results)} results in the meantime', ) break logging.info( 'All %s queue results collected in: %s.', len(results), queue_duration, ) return results def _WaitForProcesses(processes: list[multiprocessing.Process]): """Waits for processes to finish, terminating any after waiting too long.""" process_start_time = time.time() process_duration = 0 original_process_count = len(processes) num_joined = 0 while processes: process = processes.pop() if process_duration > _QUEUE_WAIT_TIME: process.terminate() else: process.join(_FAIL_LATENCY) num_joined += 1 process_duration = time.time() - process_start_time if process_duration > _QUEUE_WAIT_TIME: _EncounterClientError( f'Waited more than {_QUEUE_WAIT_TIME} seconds for processes to join.' ' Exiting, but some data may have been dropped. 
Collected' f' {num_joined} out of' f' {original_process_count} total processes with join' ) logging.info( 'All %s processes finished joining or terimnated in %s seconds.', original_process_count, process_duration, ) def _SendParallelRequests( command: str, requests: int, output_queue: multiprocessing.Queue, ) -> list[multiprocessing.Process]: """Sends X requests to the model in parallel.""" logging.info('Sending %s requests in parallel', requests) processes = [] for _ in range(requests): p = multiprocessing.Process( target=_TimeCommand, args=(command, output_queue) ) processes.append(p) p.start() _UnitTestIdleTime() return processes def _UnitTestIdleTime(): """Sleeps in unit test.""" pass def _EncounterClientError(error_msg): """Throws or logs a client error.""" if THROW_ON_CLIENT_ERRORS.value: raise ClientError(error_msg) logging.warning(error_msg) def _TimeCommand( command: str, output_queue: multiprocessing.Queue, ): """Times the command & stores length + output in the queue.""" start_time = time.time() response, err, status = _RunCommand(command) end_time = time.time() output_queue.put(CommandResponse(start_time, end_time, response, err, status)) def _RunCommand( command: str, ) -> tuple[str, str, int]: """Runs a command and returns stdout, stderr, and return code.""" result = subprocess.run( command.split(' '), check=False, capture_output=True, text=True ) return result.stdout, result.stderr, result.returncode if __name__ == '__main__': app.run(main)