benchmarking/lab_driver.py (180 lines of code) (raw):

#!/usr/bin/env python ############################################################################## # Copyright 2017-present, Facebook, Inc. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. ############################################################################## from __future__ import absolute_import, division, print_function, unicode_literals import argparse import json import os from download_benchmarks.download_benchmarks import DownloadBenchmarks from harness import BenchmarkDriver from repo_driver import RepoDriver as OSS_RepoDriver from run_lab import RunLab from run_remote import RunRemote from utils.custom_logger import getLogger, setLoggerLevel from utils.log_utils import DEFAULT_INTERVAL as default_interval from utils.log_utils import valid_interval parser = argparse.ArgumentParser(description="Download models from dewey") parser.add_argument( "--app_id", help="The app id you use to upload/download your file for everstore" ) parser.add_argument( "-b", "--benchmark_file", help="Specify the json file for the benchmark or a number of benchmarks", ) parser.add_argument( "--lab", action="store_true", help="Indicate whether the run is lab run." ) parser.add_argument( "--logger_level", default="info", choices=["info", "warning", "error"], help="Specify the logger level", ) parser.add_argument( "--rt_logging", action="store_true", help="Enable realtime logging to database." ) parser.add_argument( "--rt_logging_interval", type=valid_interval, default=str(default_interval), help="Realtime logging Update interval in seconds. Minimum 5 seconds.", ) parser.add_argument( "--remote", action="store_true", help="Submit the job to remote devices to run the benchmark.", ) parser.add_argument( "--root_model_dir", required=True, help="The root model directory if the meta data of the model uses " "relative directory, i.e. the location field starts with //", ) parser.add_argument( "--token", help="The token you use to upload/download your file for everstore" ) parser.add_argument( "-c", "--custom_binary", help="Specify the custom binary that you want to run." ) parser.add_argument( "--pre_built_binary", help="Specify the pre_built_binary to bypass the building process.", ) parser.add_argument( "--user_string", help="If set, use this instead of the $USER env variable as the user string.", ) parser.add_argument( "--buck_target", default="", help="The buck command to build the custom binary" ) class LabDriver(object): def __init__(self, raw_args=None): self.args, self.unknowns = parser.parse_known_args(raw_args) setLoggerLevel(self.args.logger_level) def run(self): if ( not self.args.lab and not self.args.remote and not "--adhoc" not in self.args ): assert self.args.benchmark_file, "--benchmark_file (-b) must be specified" if self.args.benchmark_file and not self.args.remote: getLogger().info("Checking benchmark files to download") dbench = DownloadBenchmarks(self.args, getLogger()) dbench.run(self.args.benchmark_file) if self.args.remote: unique_args = [ "--app_id", self.args.app_id, "--token", self.args.token, ] if self.args.benchmark_file: unique_args.extend( [ "--benchmark_file", self.args.benchmark_file, ] ) if self.args.pre_built_binary: unique_args.extend( [ "--pre_built_binary", self.args.pre_built_binary, ] ) if self.args.user_string: unique_args.extend( [ "--user_string", self.args.user_string, ] ) if self.args.buck_target: unique_args.extend( [ "--buck_target", self.args.buck_target, ] ) # hack to remove --repo from the argument list since python2 # argparse doesn't support allow_abbrev to be False, and it is # the prefix of --repo_dir if "--repo" in self.unknowns: index = self.unknowns.index("--repo") new_unknowns = self.unknowns[:index] new_unknowns.extend(self.unknowns[index + 2 :]) self.unknowns = new_unknowns app_class = RunRemote elif self.args.lab: unique_args = [ "--app_id", self.args.app_id, "--token", self.args.token, ] if self.args.rt_logging: unique_args.extend( [ "--rt_logging", "--rt_logging_interval", str(self.args.rt_logging_interval), ] ) app_class = RunLab elif self.args.custom_binary or self.args.pre_built_binary: if self.args.custom_binary: binary = self.args.custom_binary else: binary = self.args.pre_built_binary repo_info = { "treatment": {"program": binary, "commit": "-1", "commit_time": 0} } unique_args = [ "--info '", json.dumps(repo_info) + "'", "--benchmark_file", self.args.benchmark_file, ] app_class = BenchmarkDriver else: if self.args.user_string: usr_string = self.args.user_string else: usr_string = os.environ["USER"] unique_args = [ "--benchmark_file", self.args.benchmark_file, "--user_string", usr_string, ] app_class = OSS_RepoDriver raw_args = [] raw_args.extend(unique_args) raw_args.extend(["--root_model_dir", self.args.root_model_dir]) raw_args.extend(["--logger_level", self.args.logger_level]) raw_args.extend(self.unknowns) getLogger().info("Running {} with raw_args {}".format(app_class, raw_args)) app = app_class(raw_args=raw_args) res = app.run() if res: return res if __name__ == "__main__": raw_args = None app = LabDriver(raw_args=raw_args) app.run()