benchmarking/run_bench.py (233 lines of code) (raw):
#!/usr/bin/env python
##############################################################################
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
##############################################################################
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import argparse
import copy
import json
import os
import six
from lab_driver import LabDriver
from utils.custom_logger import getLogger, setLoggerLevel
from utils.utilities import getString, getRunStatus, setRunStatus, unpackAdhocFile
# The current user's home directory; the default config root lives under it.
HOME_DIR = os.path.expanduser("~")

# Options understood by RunBench itself. Any other flag is left unparsed by
# parse_known_args() and forwarded to the lab driver.
parser = argparse.ArgumentParser(description="Perform one benchmark run")
parser.add_argument(
    "--config_dir",
    help="Specify the config root directory.",
    default=os.path.join(HOME_DIR, ".aibench", "git"),
)
parser.add_argument(
    "--logger_level",
    help="Specify the logger level",
    choices=["info", "warning", "error"],
    default="info",
)
parser.add_argument(
    "--reset_options",
    help="Reset all the options that is saved by default.",
    action="store_true",
)
class RunBench(object):
    """Assemble the benchmark command line and hand it to the lab driver.

    Options are layered, later layers winning: the saved config files under
    ``--config_dir`` (``config.txt`` then ``config_overrides.txt``), then the
    ``default_args`` section of the benchmark file, then the flags given on
    the command line.
    """

    # Answers to the screen-reporter prompt that mean "do not enable it".
    _DECLINE_ANSWERS = ("n", "no", "false", "0")

    def __init__(self, raw_args=None):
        """Parse the options RunBench knows; keep the rest for the driver.

        Args:
            raw_args: argument list to parse, or None to let argparse read
                the process command line.
        """
        self.args, self.unknowns = parser.parse_known_args(raw_args)
        self.root_dir = self.args.config_dir
        self.repoCls = LabDriver
        setLoggerLevel(self.args.logger_level)

    def run(self):
        """Build the final argument list, run the driver and report status.

        Returns:
            The driver's raw return value when one of ``--query_num_devices``,
            ``--fetch_status``, ``--fetch_result`` or ``--list_devices`` was
            passed. Otherwise, if the driver returned something, the value
            shifted right by 8 is recorded via ``setRunStatus``. In that case
            the method returns ``getRunStatus()``.
            NOTE(review): the ``>> 8`` suggests an os.system-style wait
            status; confirm against LabDriver.run().
        """
        raw_args = self._getRawArgs()
        if "--remote" in raw_args or "--lab" in raw_args:
            # The server address must start with http (an empty one is allowed).
            assert "--server_addr" in raw_args, (
                "--server_addr is required with --remote/--lab"
            )
            idx = raw_args.index("--server_addr")
            server_addr = raw_args[idx + 1]
            assert server_addr.startswith("http") or len(server_addr) == 0, (
                "--server_addr must start with http"
            )
            if "--lab" in raw_args and "--remote_reporter" not in raw_args:
                # endswith() is safe for an empty address, unlike indexing [-1].
                base = server_addr if server_addr.endswith("/") else server_addr + "/"
                raw_args.extend(
                    ["--remote_reporter", base + "benchmark/store-result|oss"]
                )
        app = self.repoCls(raw_args=raw_args)
        ret = app.run()
        # Query-style commands return the driver's result unchanged.
        query_flags = (
            "--query_num_devices",
            "--fetch_status",
            "--fetch_result",
            "--list_devices",
        )
        if any(flag in self.unknowns for flag in query_flags):
            return ret
        if ret is not None:
            setRunStatus(ret >> 8)
        return getRunStatus()

    def _getUnknownArgs(self):
        """Parse the unrecognized CLI tokens into a ``{flag: value}`` dict.

        Flags and values:
            A token longer than one character that starts with ``-`` is a
            flag. If the following token does not start with ``-``, it is
            consumed as the flag's value; otherwise the flag maps to ``None``.

        Other tokens:
            Tokens that are neither flags nor consumed values are ignored.
            When a flag repeats, its last occurrence wins.
        """
        unknowns = self.unknowns
        args = {}
        i = 0
        while i < len(unknowns):
            token = unknowns[i]
            if len(token) > 1 and token.startswith("-"):
                has_value = i + 1 < len(unknowns) and not unknowns[i + 1].startswith(
                    "-"
                )
                if has_value:
                    args[token] = unknowns[i + 1]
                    i += 1  # the value token is consumed as well
                else:
                    args[token] = None
            # Any other token is a stray positional argument; skip it.
            i += 1
        return args

    def _saveDefaultArgs(self, new_args):
        """Interactively build the default options and persist them.

        The built-in defaults, any existing ``config.txt``, ``new_args`` and
        the user's answers are merged. If the status file is missing, it is
        created containing ``1``.

        Persistence:
            The result is written to ``config.txt``, excluding the per-run
            keys ``--benchmark_file``, ``-b`` and ``--devices``.

        Args:
            new_args: flags given on the command line (from _getUnknownArgs).

        Returns:
            The full merged dict, including the per-run keys.
        """
        if not os.path.isdir(self.root_dir):
            os.makedirs(self.root_dir)
        config_file = os.path.join(self.root_dir, "config.txt")
        print("Setting the default arguments...")
        print("The default arguments are saved under {}".format(config_file))
        print("Alternatively, you can edit the config.txt file directly\n")
        args = self._loadDefaultArgs()
        if os.path.isfile(config_file):
            with open(config_file, "r") as f:
                args.update(json.load(f))
        args = self._askArgsFromUser(args, new_args)
        if not os.path.isfile(args["--status_file"]):
            with open(args["--status_file"], "w") as f:
                f.write("1")
        if "--screen_reporter" in args:
            # The reporter is a bare flag, so a "yes" answer becomes None.
            # An explicit "no" answer removes the flag instead of enabling it.
            answer = args["--screen_reporter"]
            if (
                hasattr(answer, "strip")
                and answer.strip().lower() in self._DECLINE_ANSWERS
            ):
                del args["--screen_reporter"]
            else:
                args["--screen_reporter"] = None
        all_args = copy.deepcopy(args)
        # Per-run selections are not persisted as defaults.
        for key in ("--benchmark_file", "-b", "--devices"):
            args.pop(key, None)
        with open(config_file, "w") as f:
            f.write(json.dumps(args, indent=2, sort_keys=True))
        return all_args

    def _askArgsFromUser(self, args, new_args):
        """Merge ``new_args`` into ``args`` and then prompt for key options.

        ``--repo_dir`` is required. The other prompts may be skipped by
        pressing Enter, which keeps the current value. ``args`` is modified in
        place and also returned.
        """
        args.update(new_args)
        self._inputOneRequiredArg(
            "Please enter the directory the framework repo resides", "--repo_dir", args
        )
        self._inputOneArg("Please enter the remote reporter", "--remote_reporter", args)
        self._inputOneArg(
            "Please enter the remote access token", "--remote_access_token", args
        )
        self._inputOneArg(
            "Please enter the root model dir if needed", "--root_model_dir", args
        )
        self._inputOneArg(
            "Do you want to print report to screen?", "--screen_reporter", args
        )
        return args

    def _loadDefaultArgs(self):
        """Return the built-in default options; paths live under root_dir."""
        return {
            "--benchmark_table": "benchmark_benchmarkinfo",
            "--cache_config": os.path.join(self.root_dir, "cache_config.txt"),
            "--remote_repository": "origin",
            "--commit": "master",
            "--commit_file": os.path.join(self.root_dir, "processed_commit"),
            "--exec_dir": os.path.join(self.root_dir, "exec"),
            "--framework": "caffe2",
            "--local_reporter": os.path.join(self.root_dir, "reporter"),
            "--repo": "git",
            "--root_model_dir": os.path.join(self.root_dir, "root_model_dir"),
            "--status_file": os.path.join(self.root_dir, "status"),
            "--model_cache": os.path.join(self.root_dir, "model_cache"),
            "--platform": "android",
            "--file_storage": "django",
            "--timeout": 300,
            "--logger_level": "warning",
            "--server_addr": "http://127.0.0.1:8000",
            "--result_db": "django",
        }

    def _inputOneArg(self, text, key, args):
        """Prompt for ``key``, showing its current value as the default.

        An empty answer keeps the current value. The resulting value is stored
        in ``args`` unless it is None, and it is returned.
        """
        arg = args.get(key)
        v = six.moves.input(text + " [" + str(arg) + "]: ")
        if v == "":
            v = arg
        if v is not None:
            args[key] = v
        return v

    def _inputOneRequiredArg(self, text, key, args):
        """Prompt for ``key`` repeatedly until a non-None value is obtained."""
        v = None
        while v is None:
            v = self._inputOneArg(text, key, args)
        return v

    def _getSavedArgs(self):
        """Load the saved options, with command-line flags removed.

        If ``--reset_options`` was given or no ``config.txt`` exists yet, the
        user is prompted interactively (see _saveDefaultArgs). Otherwise
        ``config.txt`` and then ``config_overrides.txt`` are loaded.

        Filtering:
            Keys given on the command line are dropped, because those flags
            are appended separately by _getRawArgs. With ``--lab``, a saved
            ``--remote`` is also dropped.
        """
        new_args = self._getUnknownArgs()
        if (
            self.args.reset_options
            or not os.path.isdir(self.root_dir)
            or not os.path.isfile(os.path.join(self.root_dir, "config.txt"))
        ):
            args = self._saveDefaultArgs(new_args)
        else:
            args = {}
            for config in ("config.txt", "config_overrides.txt"):
                config_file = os.path.join(self.root_dir, config)
                if os.path.isfile(config_file):
                    with open(config_file, "r") as f:
                        args.update(json.load(f))
        for key in new_args:
            args.pop(key, None)
        if "--lab" in new_args:
            args.pop("--remote", None)
        return args

    def _updateArgsWithBenchmarkOverrides(self, args):
        """Merge the benchmark file's ``default_args`` into ``args`` in place.

        Where the benchmark file comes from:
            It is taken from ``--benchmark_file`` or ``-b``. For ``--adhoc``
            runs it is the unpacked adhoc config instead.

        Merge rules:
            Keys that were also passed on the command line are skipped, so the
            command line keeps precedence. Nothing happens if no benchmark
            file is known, the file does not exist, or it has no
            ``default_args``.
        """
        unknowns = self._getUnknownArgs()
        # The benchmark file flag may be absent, depending on the type of run.
        benchmark_file = unknowns.get("--benchmark_file") or unknowns.get("-b")
        # TODO: remove once adhoc runs move to a separate infrastructure.
        if "--adhoc" in unknowns:
            config_name = unknowns["--adhoc"]
            if config_name is None:
                config_name = "generic"
            adhoc_path, success = unpackAdhocFile(config_name)
            if success:
                benchmark_file = adhoc_path
            else:
                getLogger().error(
                    "Could not find specified adhoc config: {}".format(config_name)
                )
        if not benchmark_file or not os.path.isfile(benchmark_file):
            return
        with open(benchmark_file, "r") as f:
            benchmark = json.load(f)
        defaults = benchmark["default_args"] if "default_args" in benchmark else {}
        # Flags given explicitly on the command line take precedence.
        args.update({k: v for k, v in defaults.items() if k not in unknowns})

    def _getRawArgs(self):
        """Flatten all option layers into the argument list for the driver.

        Order of the list:
            Saved and benchmark-default options come first, as
            ``[flag, value]`` pairs; a None value becomes ``""``. The original
            command-line tokens follow, and ``--logger_level`` comes last.
        """
        args = self._getSavedArgs()
        self._updateArgsWithBenchmarkOverrides(args)
        raw_args = []
        for key in args:
            value = args[key]
            raw_args.extend([getString(key), getString(value) if value is not None else ""])
        raw_args.extend([getString(u) for u in self.unknowns])
        raw_args.extend(["--logger_level", self.args.logger_level])
        return raw_args
if __name__ == "__main__":
    # Script entry point. raw_args=None makes argparse read the process
    # command line, and a single benchmark run is performed.
    # NOTE(review): the status returned by RunBench.run() is discarded, so
    # the process exit code does not reflect it; confirm this is intended.
    RunBench(raw_args=None).run()