benchmarking/lab_driver.py (180 lines of code) (raw):
#!/usr/bin/env python
##############################################################################
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
##############################################################################
from __future__ import absolute_import, division, print_function, unicode_literals
import argparse
import json
import os
from download_benchmarks.download_benchmarks import DownloadBenchmarks
from harness import BenchmarkDriver
from repo_driver import RepoDriver as OSS_RepoDriver
from run_lab import RunLab
from run_remote import RunRemote
from utils.custom_logger import getLogger, setLoggerLevel
from utils.log_utils import DEFAULT_INTERVAL as default_interval
from utils.log_utils import valid_interval
def _build_parser():
    """Create the command-line parser for the lab driver.

    Only the options handled by this front-end are declared here; any other
    option is expected to be collected separately by ``parse_known_args``.
    """
    cli = argparse.ArgumentParser(description="Download models from dewey")

    # Credentials.
    cli.add_argument(
        "--app_id",
        help="The app id you use to upload/download your file for everstore",
    )

    # Benchmark selection and run mode.
    cli.add_argument(
        "-b",
        "--benchmark_file",
        help="Specify the json file for the benchmark or a number of benchmarks",
    )
    cli.add_argument(
        "--lab",
        help="Indicate whether the run is lab run.",
        action="store_true",
    )

    # Logging.
    cli.add_argument(
        "--logger_level",
        choices=["info", "warning", "error"],
        default="info",
        help="Specify the logger level",
    )
    cli.add_argument(
        "--rt_logging",
        help="Enable realtime logging to database.",
        action="store_true",
    )
    # The default is given as a string, so argparse converts it with
    # ``valid_interval`` exactly like a user-supplied value.
    cli.add_argument(
        "--rt_logging_interval",
        default=str(default_interval),
        type=valid_interval,
        help="Realtime logging Update interval in seconds. Minimum 5 seconds.",
    )

    cli.add_argument(
        "--remote",
        help="Submit the job to remote devices to run the benchmark.",
        action="store_true",
    )
    cli.add_argument(
        "--root_model_dir",
        help="The root model directory if the meta data of the model uses "
        "relative directory, i.e. the location field starts with //",
        required=True,
    )
    cli.add_argument(
        "--token",
        help="The token you use to upload/download your file for everstore",
    )

    # Binary selection.
    cli.add_argument(
        "-c",
        "--custom_binary",
        help="Specify the custom binary that you want to run.",
    )
    cli.add_argument(
        "--pre_built_binary",
        help="Specify the pre_built_binary to bypass the building process.",
    )

    cli.add_argument(
        "--user_string",
        help="If set, use this instead of the $USER env variable as the user string.",
    )
    cli.add_argument(
        "--buck_target",
        help="The buck command to build the custom binary",
        default="",
    )
    return cli


parser = _build_parser()
class LabDriver(object):
    """Front-end that selects a benchmark runner from the command-line flags.

    Options declared on the module-level ``parser`` are stored in
    ``self.args``; every other argument is kept in ``self.unknowns`` and is
    forwarded unchanged to the selected runner (``RunRemote``, ``RunLab``,
    ``BenchmarkDriver`` or ``OSS_RepoDriver``).
    """

    def __init__(self, raw_args=None):
        """Parse the arguments and configure the logger level.

        Args:
            raw_args: list of argument strings; when None, argparse reads
                them from ``sys.argv``.
        """
        self.args, self.unknowns = parser.parse_known_args(raw_args)
        setLoggerLevel(self.args.logger_level)

    def run(self):
        """Build the argument list for the selected runner and execute it.

        Returns:
            The value returned by the runner's ``run()`` when it is truthy,
            otherwise None.

        Raises:
            AssertionError: if none of --lab, --remote or --adhoc is given
                and --benchmark_file is missing.
        """
        # --adhoc is not declared on this parser, so when it is passed it
        # lands in self.unknowns.  (Testing membership on the Namespace, as
        # was done before, can never match a dashed option string.)
        needs_benchmark_file = (
            not self.args.lab
            and not self.args.remote
            and "--adhoc" not in self.unknowns
        )
        if needs_benchmark_file and not self.args.benchmark_file:
            # Raised explicitly so the check is not stripped under ``python -O``.
            raise AssertionError("--benchmark_file (-b) must be specified")

        if self.args.benchmark_file and not self.args.remote:
            getLogger().info("Checking benchmark files to download")
            dbench = DownloadBenchmarks(self.args, getLogger())
            dbench.run(self.args.benchmark_file)

        if self.args.remote:
            unique_args = self._remote_args()
            app_class = RunRemote
        elif self.args.lab:
            unique_args = self._lab_args()
            app_class = RunLab
        elif self.args.custom_binary or self.args.pre_built_binary:
            unique_args = self._binary_args()
            app_class = BenchmarkDriver
        else:
            unique_args = self._repo_args()
            app_class = OSS_RepoDriver

        raw_args = list(unique_args)
        raw_args.extend(["--root_model_dir", self.args.root_model_dir])
        raw_args.extend(["--logger_level", self.args.logger_level])
        raw_args.extend(self.unknowns)
        # NOTE(review): for remote/lab runs raw_args contains the --token
        # value, so it is written to the log here -- confirm this is intended.
        getLogger().info("Running {} with raw_args {}".format(app_class, raw_args))
        app = app_class(raw_args=raw_args)
        res = app.run()
        if res:
            return res

    def _remote_args(self):
        """Return the arguments specific to RunRemote and drop --repo from the unknowns."""
        unique_args = [
            "--app_id",
            self.args.app_id,
            "--token",
            self.args.token,
        ]
        # Optional settings are forwarded only when they have a value.
        for name in ("benchmark_file", "pre_built_binary", "user_string", "buck_target"):
            value = getattr(self.args, name)
            if value:
                unique_args.extend(["--" + name, value])
        # hack to remove --repo from the argument list since python2
        # argparse doesn't support allow_abbrev to be False, and it is
        # the prefix of --repo_dir
        if "--repo" in self.unknowns:
            index = self.unknowns.index("--repo")
            # Skip both the option and the value that follows it.
            self.unknowns = self.unknowns[:index] + self.unknowns[index + 2 :]
        return unique_args

    def _lab_args(self):
        """Return the arguments specific to RunLab."""
        unique_args = [
            "--app_id",
            self.args.app_id,
            "--token",
            self.args.token,
        ]
        if self.args.rt_logging:
            unique_args.extend(
                [
                    "--rt_logging",
                    "--rt_logging_interval",
                    str(self.args.rt_logging_interval),
                ]
            )
        return unique_args

    def _binary_args(self):
        """Return the arguments for BenchmarkDriver when a binary is supplied."""
        # --custom_binary takes precedence over --pre_built_binary.
        binary = self.args.custom_binary or self.args.pre_built_binary
        repo_info = {
            "treatment": {"program": binary, "commit": "-1", "commit_time": 0}
        }
        # NOTE(review): the single quotes are part of the argument strings
        # themselves ("--info '" and the JSON followed by "'"); no shell is
        # involved to strip them.  Confirm BenchmarkDriver expects this form.
        return [
            "--info '",
            json.dumps(repo_info) + "'",
            "--benchmark_file",
            self.args.benchmark_file,
        ]

    def _repo_args(self):
        """Return the arguments for OSS_RepoDriver (the default runner)."""
        if self.args.user_string:
            usr_string = self.args.user_string
        else:
            # Raises KeyError when the USER environment variable is not set.
            usr_string = os.environ["USER"]
        return [
            "--benchmark_file",
            self.args.benchmark_file,
            "--user_string",
            usr_string,
        ]
if __name__ == "__main__":
    # Script entry point: raw_args=None is passed straight to the driver's
    # constructor, exactly as before.
    LabDriver(raw_args=None).run()