benchmarking/run_lab.py:

#!/usr/bin/env python
##############################################################################
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
##############################################################################

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import argparse
import gc
import json
import logging
import multiprocessing
import os
import shutil
import signal
import stat
import tempfile
import threading
import time
from concurrent.futures import ProcessPoolExecutor as Pool
from io import StringIO

from bridge.db import DBDriver
from bridge.file_storages import UploadDownloadFiles
from download_benchmarks.download_benchmarks import DownloadBenchmarks
from harness import BenchmarkDriver
from platforms.android.adb import ADB
from platforms.device_manager import CoolDownDevice, DeviceManager
from platforms.device_manager import DEFAULT_DM_INTERVAL as default_dm_interval
from platforms.device_manager import MINIMUM_DM_INTERVAL as minimum_dm_interval
from platforms.device_manager import getDevicesString
from platforms.device_manager import valid_dm_interval
from utils.check_argparse import claimer_id_type
from utils.custom_logger import getLogger, setLoggerLevel
from utils.log_update_handler import DBLogUpdateHandler
from utils.log_utils import DEFAULT_INTERVAL as default_interval
from utils.log_utils import trimLog, collectLogData, valid_interval
from utils.utilities import DownloadException, BenchmarkArgParseException
from utils.utilities import HARNESS_ERROR_FLAG as HARNESS_ERROR
from utils.utilities import KILLED_FLAG as RUN_KILLED
from utils.utilities import SUCCESS_FLAG as SUCCESS
from utils.utilities import TIMEOUT_FLAG as RUN_TIMEOUT
from utils.utilities import USER_ERROR_FLAG as USER_ERROR
from utils.utilities import getFilename, getMachineId, setRunKilled
from utils.watchdog import WatchDog

parser = argparse.ArgumentParser(description="Run the benchmark remotely")
parser.add_argument(
    "--android_dir",
    default="/data/local/tmp/",
    help="The directory in the android device all files are pushed to.",
)
parser.add_argument(
    "--app_id",
    help="The app id you use to upload/download your file for everstore "
    "and access the job queue",
)
parser.add_argument(
    "--claimer_id",
    default=getMachineId(),
    type=claimer_id_type,
    help="A unique claimer id identifying this machine. "
    "Must talk to the Caffe2 team to set it up.",
)
parser.add_argument(
    "--cooldown",
    default=0,
    type=float,
    help="Specify the time interval between two test runs.",
)
parser.add_argument(
    "-d",
    "--devices",
    help="Specify the devices to run the benchmark, in a comma separated "
    "list. The value is the device or device_hash field of the meta info.",
)
parser.add_argument(
    "--job_queue",
    default="aibench_interactive",
    help="Specify the db job queue that the benchmark is sent to",
)
parser.add_argument(
    "--logger_level",
    default="info",
    choices=["info", "warning", "error"],
    help="Specify the logger level",
)
parser.add_argument(
    "--model_cache",
    required=True,
    help="The local directory containing the cached models. It should not "
    "be part of a git directory.",
)
parser.add_argument(
    "--monsoon_map", help="Map the phone hash to the monsoon serial number."
)
parser.add_argument(
    "-p",
    "--platform",
    required=True,
    help="Specify the platform to benchmark on. Use this flag if the "
    "framework needs special compilation scripts. The scripts are called "
    "build.sh and are saved in the "
    "specifications/frameworks/<framework>/<platform> directory.",
)
parser.add_argument("--platform_sig", help="Specify the platform signature")
parser.add_argument(
    "--reboot",
    action="store_true",
    help="Tries to reboot the devices before launching benchmarks for one "
    "commit.",
)
parser.add_argument(
    "--remote_reporter",
    required=True,
    help="Save the result to a remote server. "
    "The style is <domain_name>/<endpoint>|<category>",
)
parser.add_argument(
    "--remote_access_token",
    default="",
    help="The access token to access the remote server",
)
parser.add_argument(
    "--root_model_dir",
    help="The root model directory if the meta data of the model uses "
    "relative directory, i.e. the location field starts with //",
)
parser.add_argument(
    "--rt_logging", action="store_true", help="Enable realtime logging to database."
)
parser.add_argument(
    "--rt_logging_interval",
    type=valid_interval,
    default=str(default_interval),
    help="Realtime logging update interval in seconds. Minimum 5 seconds.",
)
parser.add_argument(
    "--device_monitor_interval",
    type=valid_dm_interval,
    default=str(default_dm_interval),
    help="Device monitoring interval in seconds. Minimum {}s, default {}s.".format(
        minimum_dm_interval, default_dm_interval
    ),
)
parser.add_argument(
    "--shared_libs",
    help="Pass the shared libs that the framework depends on, "
    "in a comma separated list.",
)
parser.add_argument(
    "--status_file",
    help="A file used to tell the driver to stop running when the content "
    "of the file is 0.",
)
parser.add_argument(
    "--test",
    action="store_true",
    help="Indicate whether this is a test run. Test runs use a different database.",
)
parser.add_argument(
    "--timeout",
    default=300,
    type=float,
    help="Specify a timeout running the test on the platforms. "
    "The timeout value needs to be large enough so that the low end devices "
    "can safely finish the execution in normal conditions. Note, in A/B "
    "testing mode, the test runs twice.",
)
parser.add_argument(
    "--token",
    help="The token you use to upload/download your file for everstore "
    "and access the job queue",
)
parser.add_argument(
    "--hash_platform_mapping",
    default=None,
    help="Specify the devices hash platform mapping json file.",
)
parser.add_argument(
    "--device_name_mapping",
    default=None,
    help="Specify the device to product name mapping json file.",
)
parser.add_argument(
    "--usb_hub_device_mapping",
    default=None,
    help="Specify the usb hub hash and port mapping to devices",
)
parser.add_argument(
    "--file_storage", help="The storage engine for uploading and downloading files"
)
parser.add_argument(
    "--benchmark_db_entry", help="The entry point of the server's database"
)
parser.add_argument("--server_addr", help="The lab's server address")
parser.add_argument(
    "--benchmark_db", help="The database that will store benchmark info"
)
parser.add_argument(
    "--benchmark_table", help="The table that will store benchmark info"
)

LOCK = multiprocessing.Lock()
DRAIN = False
RUNNING_JOBS = 0


def drainHandler(signum, frame):
    global DRAIN
    DRAIN = True


def hookSignals():
    signal.signal(signal.SIGUSR1, drainHandler)
    signal.signal(signal.SIGTERM, drainHandler)
    signal.signal(signal.SIGINT, drainHandler)


def stopRun(args):
    global DRAIN
    global RUNNING_JOBS
    if DRAIN and RUNNING_JOBS == 0:
        getLogger().info("Finished draining. Exiting...")
        return True
    if args.status_file and os.path.isfile(args.status_file):
        with open(args.status_file, "r") as file:
            content = file.read().strip()
            if content == "0":
                return True
    return False


class runAsync(object):
    """Run one claimed benchmark job in a worker process."""

    def __init__(
        self, args, device, db, job, benchmark_downloader, file_storage, usb_controller
    ):
        self.args = args
        self.device = device
        self.db = db
        self.job = job
        self.tempdir = tempfile.mkdtemp(
            prefix="_".join(["aibench", str(job.get("identifier")), ""])
        )
        self.benchmark_downloader = benchmark_downloader
        self.file_storage = file_storage
        self.usb_controller = usb_controller

    def __call__(self):
        return self.run()

    def run(self):
        # set env vars of this process and any subprocess for logging.
        os.environ["JOB_IDENTIFIER"] = str(self.job["identifier"])
        os.environ["JOB_ID"] = str(self.job["id"])
        handlers = []
        log_capture_string = StringIO()
        ch = logging.StreamHandler(log_capture_string)
        ch.setLevel(logging.DEBUG)
        getLogger().addHandler(ch)
        handlers.append(ch)
        # if enabled, the realtime logger will also update the log entry at
        # regular intervals.
        if self.args.rt_logging:
            dbh = DBLogUpdateHandler(
                self.db, self.job["id"], self.args.rt_logging_interval
            )
            dbh.setLevel(logging.DEBUG)
            getLogger().addHandler(dbh)
            handlers.append(dbh)
            getLogger().info(
                "Realtime logging enabled with {}s updates.".format(
                    self.args.rt_logging_interval
                )
            )
        try:
            self._setFramework()
            with LOCK:
                getLogger().info(
                    f"Lock acquired by {os.getpid()} before _downloadFiles() for benchmark {self.job['identifier']} id ({self.job['id']})"
                )
                self._downloadFiles()
            raw_args = self._getRawArgs()
            app = BenchmarkDriver(raw_args=raw_args, usb_controller=self.usb_controller)
            getLogger().debug(
                f"Running BenchmarkDriver for benchmark {self.job['identifier']} id ({self.job['id']})"
            )
            status = app.run()
        except DownloadException:
            getLogger().critical(
                f"An error occurred while downloading files for benchmark {self.job['identifier']} id ({self.job['id']})",
                exc_info=True,
            )
            status = HARNESS_ERROR
        except BenchmarkArgParseException:
            getLogger().exception(
                f"An error occurred while parsing arguments for benchmark {self.job['identifier']} id ({self.job['id']})"
            )
            status = USER_ERROR
        except Exception:
            getLogger().critical(
                f"An error occurred while running benchmark {self.job['identifier']} id ({self.job['id']})",
                exc_info=True,
            )
            status = HARNESS_ERROR
        finally:
            output = log_capture_string.getvalue()
            log_capture_string.close()
            for handler in handlers:
                handler.close()
                getLogger().handlers.remove(handler)
            del handlers
            self._setStatusOutput(status, output)
            self._submitDone()
            self._removeBenchmarkFiles()
            time.sleep(1)
        return {"device": self.device, "job": self.job}

    def _setFramework(self):
        """Set the framework of the job based on the benchmark config.

        Default to caffe2."""
        try:
            content = self.job["benchmarks"]["benchmark"]["content"]
            if content["tests"][0]["metric"] == "generic":
                self.job["framework"] = "generic"
            elif "model" in content and "framework" in content["model"]:
                self.job["framework"] = content["model"]["framework"]
            else:
                getLogger().warning(
                    "Framework is not specified, using Caffe2 as default"
                )
                self.job["framework"] = "caffe2"
        except Exception:
            raise BenchmarkArgParseException("Error parsing raw args from job.")

    def _getRawArgs(self):
        """Create raw args for BenchmarkDriver."""
        try:
            if "info" in self.job["benchmarks"]:
                info = self.job["benchmarks"]["info"]
            # pass the device hash as well as type
            device = {"kind": self.job["device"], "hash": self.job["hash"]}
            device_str = json.dumps(device)
            raw_args = []
            raw_args.extend(
                [
                    "--benchmark_file",
                    self.job["benchmarks"]["benchmark"]["content"],
                    "--cooldown",
                    str(self.args.cooldown),
                    "--device",
                    device_str,
                    "--framework",
                    self.job["framework"],
                    "--info",
                    json.dumps(info),
                    "--model_cache",
                    self.args.model_cache,
                    "--platform",
                    self.args.platform,
                    "--remote_access_token",
                    self.args.remote_access_token,
                    "--root_model_dir",
                    self.args.root_model_dir,
                    "--simple_local_reporter",
                    self.tempdir,
                    "--user_identifier",
                    str(self.job["identifier"]),
                    "--user_string",
                    self.job.get("user"),
                ]
            )
            if self.job["framework"] != "generic":
                raw_args.extend(["--remote_reporter", self.args.remote_reporter])
            if self.args.shared_libs:
                raw_args.extend(["--shared_libs", "'" + self.args.shared_libs + "'"])
            if self.args.timeout:
                raw_args.extend(["--timeout", str(self.args.timeout)])
            if self.args.platform_sig:
                raw_args.append("--platform_sig")
                raw_args.append(self.args.platform_sig)
            if self.args.monsoon_map:
                raw_args.extend(["--monsoon_map", str(self.args.monsoon_map)])
            if self.args.hash_platform_mapping:
                # if the user provides a filename, we will load it.
                raw_args.append("--hash_platform_mapping")
                raw_args.append(self.args.hash_platform_mapping)
            if self.args.device_name_mapping:
                raw_args.append("--device_name_mapping")
                raw_args.append(self.args.device_name_mapping)
            return raw_args
        except Exception:
            raise BenchmarkArgParseException("Error parsing raw args from job.")

    def _saveBenchmarks(self):
        """Save the benchmark config to a file in the job's temp directory."""
        content = self.job["benchmarks"]["benchmark"]["content"]
        benchmark_str = json.dumps(content)
        path = os.path.join(self.tempdir, "benchmark")
        with open(path, "w") as f:
            f.write(benchmark_str)
        self.job["benchmarks"]["benchmark"]["content"] = path
        return path

    def _downloadBinaries(self, info_dict):
        """Download benchmark binaries and return their local locations."""
        programs = info_dict["programs"]
        program_locations = []
        for bin_name in programs:
            program_location = programs[bin_name]["location"]
            self.benchmark_downloader.downloadFile(program_location, None)
            if program_location.startswith("//"):
                program_location = self.args.root_model_dir + program_location[1:]
            elif program_location.startswith("http"):
                replace_pattern = {
                    " ": "-",
                    "\\": "-",
                    ":": "/",
                }
                program_location = os.path.join(
                    self.args.root_model_dir,
                    getFilename(program_location, replace_pattern=replace_pattern),
                )
            elif program_location.startswith("/"):
                program_location = self.args.root_model_dir + program_location
            if (
                self.args.platform.startswith("ios")
                and bin_name == "program"
                and not program_location.endswith(".ipa")
            ):
                new_location = program_location + ".ipa"
                os.rename(program_location, new_location)
                program_location = new_location
            os.chmod(program_location, stat.S_IXUSR | stat.S_IRUSR | stat.S_IWUSR)
            programs[bin_name]["location"] = program_location
            program_locations.append(program_location)
        return program_locations

    def _downloadFiles(self):
        """Download benchmark files."""
        try:
            # Download models
            self.job["models_location"] = []
            getLogger().info(
                f"Downloading models for benchmark {self.job['identifier']} id {self.job['id']}"
            )
            path = self._saveBenchmarks()
            location = self.benchmark_downloader.run(path)
            self.job["models_location"].extend(location)
            # Download programs
            if "info" in self.job["benchmarks"]:
                getLogger().info(
                    f"Downloading programs for benchmark {self.job['identifier']} id {self.job['id']}"
                )
                if "treatment" not in self.job["benchmarks"]["info"]:
                    getLogger().error(
                        'Field "treatment" must exist in '
                        'job["benchmarks"]["info"]'
                    )
                elif "programs" not in self.job["benchmarks"]["info"]["treatment"]:
                    getLogger().error(
                        'Field "programs" must exist in '
                        'job["benchmarks"]["info"]["treatment"]'
                    )
                else:
                    treatment_info = self.job["benchmarks"]["info"]["treatment"]
                    getLogger().info("Downloading treatment binary")
                    treatment_locations = self._downloadBinaries(treatment_info)
                    self.job["programs_location"] = treatment_locations
                    if "control" in self.job["benchmarks"]["info"]:
                        if "programs" not in self.job["benchmarks"]["info"]["control"]:
                            getLogger().error(
                                'Field "programs" must exist in '
                                'job["benchmarks"]["info"]["control"]'
                            )
                        else:
                            control_info = self.job["benchmarks"]["info"]["control"]
                            getLogger().info("Downloading control binary")
                            control_locations = self._downloadBinaries(control_info)
                            self.job["programs_location"].extend(control_locations)
        except Exception:
            raise DownloadException(
                f"Failed to download files for benchmark {self.job['identifier']} id {self.job['id']}"
            )
        finally:
            gc.collect()

    def _setStatusOutput(self, status, output):
        """Set job status and log, to be submitted to the db and returned to
        the main process."""
        if status == RUN_KILLED:
            self.job["status"] = "KILLED"
        elif status == RUN_TIMEOUT:
            self.job["status"] = "TIMEOUT"
        elif status == SUCCESS:
            self.job["status"] = "DONE"
        elif status == USER_ERROR:
            self.job["status"] = "USER_ERROR"
        else:
            self.job["status"] = "FAILED"
        if not output:
            getLogger().critical("No output was captured.")
        else:
            output = trimLog(output)
        self.job["log"] = output

    def _submitDone(self):
        """Collect benchmark data and log, submit to the db."""
        data = self._collectBenchmarkData(self.tempdir)
        log = collectLogData(self.job)
        self.db.doneBenchmarks(str(self.job["id"]), self.job["status"], data, log)

    def _removeBenchmarkFiles(self):
        """Attempt to remove temporary files, ignoring errors."""
        shutil.rmtree(self.tempdir, True)
        # Do not remove programs or model files here. Files can be removed
        # by access time with a service.
        # programs_location = self.job.get("programs_location", [])
        # for program_location in programs_location:
        #     shutil.rmtree(os.path.dirname(program_location), True)
        # We don't delete files from models_location after each run, to save
        # model re-download time. The downside is that many files can
        # accumulate on disk and consume a lot of disk space.
        # models_location = self.job.get("models_location", [])
        # for model_location in models_location:
        #     shutil.rmtree(os.path.dirname(model_location), True)

    def _collectBenchmarkData(self, output_dir):
        data = {}
        dirs = self._listdirs(output_dir)
        for d in dirs:
            f = os.path.join(output_dir, d, "data.txt")
            if not os.path.isfile(f):
                getLogger().critical(
                    f"The output file {f} doesn't exist for benchmark."
                )
                continue
            with open(f, "r") as file:
                content = json.load(file)
                # special case for power metrics
                if "power_data" in content:
                    content["power_data"] = self._handlePowerData(
                        content["power_data"]
                    )
                data[d] = content
        return json.dumps(data)

    def _listdirs(self, path):
        return [x for x in os.listdir(path) if os.path.isdir(os.path.join(path, x))]

    def _handlePowerData(self, filename):
        if not os.path.isfile(filename):
            getLogger().critical(f"Power data file '{filename}' doesn't exist.")
            return
        file_link = self.file_storage.upload(file=filename, permanent=False)
        # remove the temporary file
        os.remove(filename)
        return file_link

    def didUserRequestJobKill(self):
        jobs = self.db.statusBenchmarks(self.job["identifier"])
        for job in jobs:
            if job["status"] == "KILLED":
                return True
        return False

    def killJob(self):
        setRunKilled(True)


class RunLab(object):
    """Main lab driver: claims jobs from the db and dispatches them to devices."""

    def __init__(self, raw_args=None):
        self.args, self.unknowns = parser.parse_known_args(raw_args)
        os.environ["CLAIMER"] = self.args.claimer_id
        self.benchmark_downloader = DownloadBenchmarks(self.args, getLogger())
        self.file_storage = UploadDownloadFiles(self.args)
        self.adb = ADB(None, self.args.android_dir)
        setLoggerLevel(self.args.logger_level)
        if not self.args.benchmark_db_entry:
            assert (
                self.args.server_addr is not None
            ), "Either server_addr or benchmark_db_entry must be specified"
            while self.args.server_addr[-1] == "/":
                self.args.server_addr = self.args.server_addr[:-1]
            self.args.benchmark_db_entry = self.args.server_addr + "/benchmark/"
        self.db = DBDriver(
            self.args.benchmark_db,
            self.args.app_id,
            self.args.token,
            self.args.benchmark_table,
            self.args.job_queue,
            self.args.test,
            self.args.benchmark_db_entry,
        )
        self.device_manager = DeviceManager(self.args, self.db)
        self.devices = self.device_manager.getLabDevices()
        if self.args.platform.startswith("host"):
            numProcesses = 2
        else:
            numProcesses = multiprocessing.cpu_count() - 1
        self.pool = Pool(max_workers=numProcesses, initializer=hookSignals)

    def run(self):
        hookSignals()
        while not stopRun(self.args):
            self._runOnce()
            time.sleep(1)
        self.db.updateDevices(self.args.claimer_id, "", True)
        self.device_manager.shutdown()

    def _runOnce(self):
        jobs = self._claimBenchmarks()
        jobs_queue, remaining_jobs = self._selectBenchmarks(jobs)
        if len(remaining_jobs) != 0:
            self._releaseBenchmarks(remaining_jobs)
        if len(jobs_queue) == 0:
            return
        self._runBenchmarks(jobs_queue)

    def _claimBenchmarks(self):
        claimer_id = self.args.claimer_id
        # get available devices with their hashes
        devices = []
        hashes = []
        for k in self.devices:
            for hash in self.devices[k]:
                if self.devices[k][hash]["available"]:
                    devices.append(k)
                    hashes.append(hash)
        hashes = ",".join(hashes)
        devices = ",".join(devices)
        jobs = []
        if len(devices) > 0:
            jobs = self.db.claimBenchmarks(claimer_id, devices, hashes)
        return jobs

    def _selectBenchmarks(self, jobs):
        remaining_jobs = []
        jobs_queue = []
        for job in jobs:
            device_kind = job["device"]
            if device_kind not in self.devices:
                getLogger().error(
                    "Retrieved job for device "
                    f"{device_kind} cannot be run on server "
                    f"{self.args.claimer_id}."
                )
                remaining_jobs.append(job)
            elif (
                job.get("hash") is not None
                and job.get("hash") not in self.devices[device_kind]
            ):
                getLogger().error(
                    "Retrieved job for device with hash "
                    f"{job.get('hash')} cannot be run on server "
                    f"{self.args.claimer_id}."
                )
                remaining_jobs.append(job)
            else:
                if job.get("hash") is not None:
                    getLogger().info(
                        f"Job {job['id']} requests device {job['device']} with hash {job['hash']}."
                    )
                    hashes = [job["hash"]]
                else:
                    hashes = self.devices[device_kind]
                for hash in hashes:
                    device = self.devices[device_kind][hash]
                    if device["available"] is True:
                        getLogger().info(
                            f"Device {job['device']} with hash {hash} available for job {job['id']}"
                        )
                        job["hash"] = hash
                        jobs_queue.append(job)
                        device["available"] = False
                        break
                else:
                    getLogger().critical(
                        f"The requested device {job['device']} with hash {job.get('hash', '<unspecified>')} is not available on this server."
                    )
                    remaining_jobs.append(job)
        return jobs_queue, remaining_jobs

    def _releaseBenchmarks(self, remaining_jobs):
        # releasing unmatched jobs
        releasing_ids = ",".join([str(job["id"]) for job in remaining_jobs])
        self.db.releaseBenchmarks(self.args.claimer_id, releasing_ids)

    def _runBenchmarks(self, jobs_queue):
        """Given a queue of jobs, update run statuses and device statuses in
        the db, and spawn job processes."""
        run_ids = ",".join([str(job["id"]) for job in jobs_queue])
        self.db.runBenchmarks(self.args.claimer_id, run_ids)
        run_devices = [self.devices[job["device"]][job["hash"]] for job in jobs_queue]
        getLogger().info("Updating devices status")
        self.db.updateDevices(
            self.args.claimer_id, getDevicesString(run_devices), False
        )
        # run the benchmarks
        for job in jobs_queue:
            getLogger().info(
                f"Running job with identifier {job['identifier']} and id {job['id']}"
            )
            device = self.devices[job["device"]][job["hash"]]
            device["start_time"] = time.ctime()
            async_runner = runAsync(
                self.args,
                device,
                self.db,
                job,
                self.benchmark_downloader,
                self.file_storage,
                self.device_manager.usb_controller,
            )
            # Watchdog will be used to kill currently running jobs
            # based on user requests
            app = WatchDog(
                async_runner, async_runner.didUserRequestJobKill, async_runner.killJob
            )
            global RUNNING_JOBS
            RUNNING_JOBS += 1
            # Python's multiprocessing needs to pickle objects in order to run
            # them in another process, and bound methods are not picklable, so
            # a bound method cannot be submitted to the pool directly. Instead,
            # runAsync implements __call__ and the whole instance is submitted.
            # Ref: https://stackoverflow.com/a/6975654
            future = self.pool.submit(app)
            future.add_done_callback(self.callback)

    def callback(self, future_result_dict):
        """Decrement running jobs count, output job log, and start device cooldown."""
        global RUNNING_JOBS
        RUNNING_JOBS -= 1
        try:
            result = future_result_dict.result()
            job = result["job"]
            device = result["device"]
            device = self.devices[device["kind"]][device["hash"]]
            # output benchmark log in main thread.
            getLogger().info(
                "\n{}\n\nBenchmark:\t\t{}\nJob:\t\t\t{}\nDevice Kind:\t\t{}\nDevice Hash:\t\t{}\n{}\n\n{}".format(
                    "#" * 80,
                    job["identifier"],
                    job["id"],
                    device["kind"],
                    device["hash"],
                    job["log"],
                    "#" * 80,
                )
            )
        except Exception:
            getLogger().critical(
                "Couldn't complete async job. Exception from subprocess:", exc_info=True
            )
        self._coolDown(device, force_reboot=job["status"] != "DONE")

    def _coolDown(self, device, force_reboot=False):
        t = CoolDownDevice(device, self.args, self.db, force_reboot)
        t.start()


if __name__ == "__main__":
    raw_args = None
    app = RunLab(raw_args=raw_args)
    app.run()
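
# Example invocation (a minimal sketch; the host name and paths below are
# illustrative placeholders, not defaults shipped with this file). The flags
# shown are the ones declared required above, plus --server_addr so that
# benchmark_db_entry can be derived:
#
#   python benchmarking/run_lab.py \
#       --platform android \
#       --model_cache /tmp/aibench_model_cache \
#       --remote_reporter "aibench.example.com/benchmark|aibench" \
#       --server_addr http://aibench.example.com \
#       --logger_level info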