benchmarking/benchmarks/benchmarks.py (258 lines of code) (raw):

#!/usr/bin/env python ############################################################################## # Copyright 2017-present, Facebook, Inc. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. ############################################################################## from __future__ import absolute_import from __future__ import division from __future__ import print_function from __future__ import unicode_literals import copy import gc import hashlib import json import os import shutil import tempfile import requests from utils.custom_logger import getLogger from utils.utilities import deepMerge, deepReplace COPY_THRESHOLD = 6442450944 # 6 GB class BenchmarkCollector(object): def __init__(self, framework, model_cache, **kwargs): self.args = kwargs.get("args", None) if not os.path.isdir(model_cache): os.makedirs(model_cache) self.model_cache = model_cache self.framework = framework def collectBenchmarks(self, info, source, user_identifier): assert os.path.isfile(source), "Source {} is not a file".format(source) with open(source, "r") as f: content = json.load(f) meta = content["meta"] if "meta" in content else {} if "meta" in info: deepMerge(meta, info["meta"]) if hasattr(self.args, "timeout"): meta["timeout"] = self.args.timeout benchmarks = [] if "benchmarks" in content: path = os.path.abspath(os.path.dirname(source)) assert "meta" in content, "Meta field is missing in benchmarks" for benchmark_file in content["benchmarks"]: benchmark_file = os.path.join(path, benchmark_file) self._collectOneBenchmark( benchmark_file, meta, benchmarks, info, user_identifier ) else: self._collectOneBenchmark(source, meta, benchmarks, info, user_identifier) for b in benchmarks: self._verifyBenchmark(b, b["path"], True) return benchmarks def _verifyBenchmark(self, benchmark, filename, is_post): self.framework.verifyBenchmarkFile(benchmark, filename, is_post) def _collectOneBenchmark(self, source, meta, benchmarks, info, user_identifier): assert os.path.isfile(source), "Benchmark {} does not exist".format(source) with open(source, "r") as b: one_benchmark = json.load(b) string_map = json.loads(self.args.string_map) if self.args.string_map else {} for name in string_map: value = string_map[name] deepReplace(one_benchmark, "{" + name + "}", value) self._verifyBenchmark(one_benchmark, source, False) self._updateFiles(one_benchmark, source, user_identifier) # following change should not appear in updated_json file if meta: deepMerge(one_benchmark["model"], meta) self._updateTests(one_benchmark, source) # Add fields that should not appear in the saved benchmark file # Adding path to benchmark file one_benchmark["path"] = os.path.abspath(source) # One test per benchmark if len(one_benchmark["tests"]) == 1: benchmarks.append(one_benchmark) else: tests = copy.deepcopy(one_benchmark["tests"]) one_benchmark["tests"] = [] for test in tests: new_benchmark = copy.deepcopy(one_benchmark) new_benchmark["tests"].append(test) benchmarks.append(new_benchmark) # Update all files in the benchmark to absolute path # download the files if needed def _updateFiles(self, one_benchmark, filename, user_identifier): model = one_benchmark["model"] model_dir = os.path.join(self.model_cache, model["format"], model["name"]) if not os.path.isdir(model_dir): os.makedirs(model_dir) collected_files, collected_tmp_files = self._collectFiles(one_benchmark) update_json = False for file in collected_files: update_json |= self._updateOneFile(file, model_dir, filename) if update_json: s = json.dumps(one_benchmark, indent=2, sort_keys=True) with open(filename, "w") as f: f.write(s) getLogger().info( "Model {} is changed. ".format(model["name"]) + "Please update the meta json file." ) # update the file field with the absolute path # needs to be after the file is updated # only update files with md5, which means those files are not # temporary files for file in collected_files: if "md5" in file: cached_filename = self._getDestFilename(file, model_dir) file["location"] = cached_filename tmp_dir = tempfile.mkdtemp( prefix="_".join(["aibench", str(user_identifier), ""]) ) for tmp_file in collected_tmp_files: tmp_file["location"] = tmp_file["location"].replace("{TEMPDIR}", tmp_dir) def _collectFiles(self, benchmark): files = [] tmp_files = [] if "model" in benchmark: if "files" in benchmark["model"]: self._collectOneGroupFiles(benchmark["model"]["files"], files) if "libraries" in benchmark["model"]: self._collectOneGroupFiles(benchmark["model"]["libraries"], files) for test in benchmark["tests"]: if "files" in test: self._collectOneGroupFiles(test["files"], files, tmp_files) if "input_files" in test: self._collectOneGroupFiles(test["input_files"], files) if "output_files" in test: self._collectOneGroupFiles(test["output_files"], files, tmp_files) if "preprocess" in test and "files" in test["preprocess"]: self._collectOneGroupFiles( test["preprocess"]["files"], files, tmp_files ) if "postprocess" in test and "files" in test["postprocess"]: self._collectOneGroupFiles( test["postprocess"]["files"], files, tmp_files ) return files, tmp_files def _collectOneGroupFiles(self, group, files, tmp_files=None): if isinstance(group, list): for f in group: self._collectOneFile(f, files, tmp_files) elif isinstance(group, dict): for name in group: f_or_list = group[name] if isinstance(f_or_list, list): for f in f_or_list: self._collectOneFile(f, files, tmp_files) else: self._collectOneFile(f_or_list, files, tmp_files) def _collectOneFile(self, item, files, tmp_files): if "location" in item and "{TEMPDIR}" in item["location"]: assert tmp_files is not None, "tmp file can only exist for output" tmp_files.append(item) return assert "filename" in item, "field filename must exist" if "location" not in item: return files.append(item) def _updateOneFile(self, field, model_dir, filename): cached_filename = self._getDestFilename(field, model_dir) if ( "md5" in field and field["md5"] is not None and ( not os.path.isfile(cached_filename) or self._calculateMD5(cached_filename, field["md5"], filename) != field["md5"] ) ): return self._copyFile(field, cached_filename, filename) return False def _calculateMD5(self, model_name, old_md5, filename): if os.stat(filename).st_size >= COPY_THRESHOLD or os.path.islink(model_name): if not os.path.isfile(model_name): getLogger().info( "Create symlink between {} and {}".format(filename, model_name) ) os.symlink(filename, model_name) return old_md5 getLogger().info("Calculate md5 of {}".format(model_name)) with open(model_name, "rb") as f: file_hash = hashlib.md5() for chunk in iter(lambda: f.read(8192), b""): file_hash.update(chunk) md5 = file_hash.hexdigest() del file_hash gc.collect() return md5 def _copyFile(self, field, destination_name, source): if "location" not in field: return False location = field["location"] if location[0:4] == "http": abs_name = destination_name getLogger().info("Downloading {}".format(location)) r = requests.get(location) if r.status_code == 200: with open(destination_name, "wb") as f: f.write(r.content) else: abs_name = self._getAbsFilename(field, source, None) if os.path.isfile(abs_name): if os.stat(abs_name).st_size < COPY_THRESHOLD: shutil.copyfile(abs_name, destination_name) else: if not os.path.isfile(destination_name): getLogger().info( "Create symlink between {} and {}".format( abs_name, destination_name ) ) os.symlink(abs_name, destination_name) else: import distutils.dir_util distutils.dir_util.copy_tree(abs_name, destination_name) if os.path.isdir(destination_name) and field["md5"] == "directory": return False assert os.path.isfile(destination_name), "File {} cannot be retrieved".format( destination_name ) # verify the md5 matches the file downloaded md5 = self._calculateMD5(destination_name, field["md5"], abs_name) if md5 != field["md5"]: getLogger().info( "Source file {} is changed, ".format(location) + " updating MD5. " + "Please commit the updated json file." ) field["md5"] = md5 return True return False def _getDestFilename(self, field, dir): fn = os.path.splitext(field["filename"]) cached_name = os.path.join(dir, fn[0] + fn[1]) return cached_name def _updateTests(self, one_benchmark, source): if one_benchmark["tests"][0]["metric"] == "generic": return # framework specific updates self.framework.rewriteBenchmarkTests(one_benchmark, source) # rewrite test fields for compatibility reasons for test in one_benchmark["tests"]: self._rewriteTestFields(test) # Update identifiers, the last update self._updateNewTestFields(one_benchmark["tests"], one_benchmark) def _rewriteTestFields(self, test): if "arguments" in test: assert ( "commands" not in test ), "Commands and arguments cannot co-exist in test" test["commands"] = ["{program} " + test["arguments"]] del test["arguments"] if "command" in test: assert ( "commands" not in test ), "Commands and command cannot co-exist in test" test["commands"] = [test["command"]] # do not delete for now # del test["command"] def _updateNewTestFields(self, tests, one_benchmark): idx = 0 for test in tests: identifier = test["identifier"].replace("{ID}", str(idx)) test["identifier"] = identifier idx += 1 def _getAbsFilename(self, file, source, cache_dir): location = file["location"] filename = file["filename"] if location[0:4] == "http": # Need to download, return the destination filename return os.path.join(cache_dir, filename) elif location[0:2] == "//": assert self.args.root_model_dir is not None, ( "When specifying relative directory, the " "--root_model_dir must be specified." ) return self.args.root_model_dir + location[1:] elif location[0] != "/": abs_dir = os.path.dirname(os.path.abspath(source)) return os.path.join(abs_dir, location) else: return location