benchmarking/download_benchmarks/download_benchmarks.py (130 lines of code) (raw):

#!/usr/bin/env python

##############################################################################
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
##############################################################################

from __future__ import absolute_import, division, print_function, unicode_literals

import gc
import hashlib
import os

from utils.custom_logger import getLogger
from utils.utilities import getBenchmarks, getFilename

from .download_file import DownloadFile


class DownloadBenchmarks(object):
    """Download the files referenced by benchmark descriptions.

    Files are placed under ``args.root_model_dir``. A file that is already
    present locally is reused when no md5 is supplied or when its md5
    matches the supplied one.
    """

    def __init__(self, args, logger):
        """Store the configuration used for later downloads.

        Args:
            args: Object exposing at least ``root_model_dir``; it is also
                passed through unchanged to ``DownloadFile``.
            logger: Logger passed through to ``DownloadFile``.

        Raises:
            AssertionError: If ``args.root_model_dir`` is empty or unset.
        """
        self.args = args
        self.root_model_dir = self.args.root_model_dir
        self.logger = logger
        self.everstore = None
        assert self.args.root_model_dir, "root_model_dir is not set"

    def run(self, benchmark_file):
        """Download every file referenced by the benchmarks in a file.

        Args:
            benchmark_file: Value handed to ``getBenchmarks``. Each returned
                benchmark must provide the ``"filename"`` and ``"content"``
                keys.

        Returns:
            A list with the local path of every referenced file whose
            location was handled; skipped locations are left out.

        Raises:
            AssertionError: If ``benchmark_file`` is empty, or if a benchmark
                is missing a required field.
        """
        assert benchmark_file, "benchmark_file is not set"
        benchmarks = getBenchmarks(benchmark_file)
        locations = []
        if not os.path.isdir(self.root_model_dir):
            os.makedirs(self.root_model_dir)
        for benchmark in benchmarks:
            locations.extend(self._processOneBenchmark(benchmark))
        # downloadFile() returns None for locations it does not handle.
        return [path for path in locations if path]

    def _processOneBenchmark(self, benchmark):
        """Download the model, library and test files of one benchmark.

        Args:
            benchmark: Mapping with ``"filename"`` (used in error messages)
                and ``"content"`` (the benchmark description).

        Returns:
            A list of local paths in download order. It may contain ``None``
            for locations that ``downloadFile`` skipped.

        Raises:
            AssertionError: If a model file or library entry has no
                ``"location"``, or if the benchmark has no ``"tests"`` field.
        """
        filename = benchmark["filename"]
        one_benchmark = benchmark["content"]
        locations = []

        # TODO refactor the code to collect files
        if "model" in one_benchmark:
            model = one_benchmark["model"]
            entries = []
            if "files" in model:
                # "files" maps a field name to its file description.
                entries.extend(model["files"][field] for field in model["files"])
            if "libraries" in model:
                # "libraries" is iterated directly as a sequence of entries.
                entries.extend(model["libraries"])
            for value in entries:
                assert (
                    "location" in value
                ), "location field is missing in benchmark {}".format(filename)
                # md5 is optional for both model files and libraries;
                # downloadFile() accepts None.
                locations.append(
                    self.downloadFile(value["location"], value.get("md5"))
                )

        assert (
            "tests" in one_benchmark
        ), "tests field is missing in benchmark {}".format(filename)
        for test in one_benchmark["tests"]:
            for key in ("input_files", "output_files"):
                if key in test:
                    locations.extend(self._downloadTestFiles(test[key]))
            for key in ("preprocess", "postprocess"):
                if key in test and "files" in test[key]:
                    locations.extend(self._downloadTestFiles(test[key]["files"]))

        return locations

    def downloadFile(self, location, md5):
        """Download one file unless a usable local copy already exists.

        Args:
            location: Source of the file. Locations starting with ``"http"``
                are mapped to a path built with ``getFilename``. Locations
                starting with ``"//"`` are mapped to ``root_model_dir`` plus
                the location without its first slash. Anything else is
                skipped.
            md5: Expected md5 hex digest, or a falsy value to accept any
                existing local copy.

        Returns:
            The local path of the file, or ``None`` when the location is
            skipped (unsupported prefix, or a ``"//"`` location with two or
            fewer path components).
        """
        if location.startswith("http"):
            dirs = location.split(":/")
            replace_pattern = {
                " ": "-",
                "\\": "-",
                ":": "/",
            }
            path = os.path.join(
                self.root_model_dir,
                getFilename(location, replace_pattern=replace_pattern),
            )
        elif not location.startswith("//"):
            return None
        else:
            dirs = location[2:].split("/")
            if len(dirs) <= 2:
                return None
            path = self.root_model_dir + location[1:]

        if os.path.isfile(path):
            if not md5:
                # assume the file is the same
                return path
            getLogger().info("Calculate md5 of {}".format(path))
            file_hash = hashlib.md5()
            with open(path, "rb") as f:
                # Hash in fixed-size chunks instead of reading the whole file.
                for chunk in iter(lambda: f.read(8192), b""):
                    file_hash.update(chunk)
            new_md5 = file_hash.hexdigest()
            # Release the hash object and run a collection pass.
            del file_hash
            gc.collect()
            if md5 == new_md5:
                getLogger().info(
                    "File {}".format(os.path.basename(path))
                    + " is cached, skip downloading"
                )
                return path
            # md5 mismatch: fall through and download the file again.

        downloader_controller = DownloadFile(
            dirs=dirs, logger=self.logger, args=self.args
        )
        downloader_controller.download_file(location, path)
        return path

    def _downloadTestFiles(self, files):
        """Download every entry of a test file collection that has a location.

        Args:
            files: Either a list of file entries, or a dict whose values are
                a single file entry or a list of file entries. Any other
                type yields no downloads.

        Returns:
            A list with one ``downloadFile`` result per entry that contains a
            ``"location"`` key; it may contain ``None`` for skipped
            locations.
        """
        return [
            self.downloadFile(entry["location"], None)
            for entry in self._iterFileEntries(files)
            if "location" in entry
        ]

    @staticmethod
    def _iterFileEntries(files):
        """Yield the individual file entries of ``files`` in iteration order."""
        if isinstance(files, list):
            for entry in files:
                yield entry
        elif isinstance(files, dict):
            for key in files:
                value = files[key]
                if isinstance(value, list):
                    for entry in value:
                        yield entry
                else:
                    yield value