benchmarking/download_benchmarks/download_benchmarks.py (130 lines of code) (raw):
#!/usr/bin/env python
##############################################################################
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
##############################################################################
from __future__ import absolute_import, division, print_function, unicode_literals
import gc
import hashlib
import os
from utils.custom_logger import getLogger
from utils.utilities import getBenchmarks, getFilename
from .download_file import DownloadFile
class DownloadBenchmarks(object):
    """Fetch every file referenced by a benchmark specification.

    Files are stored underneath ``args.root_model_dir``. A file that is
    already present locally is reused instead of being downloaded again
    (optionally after verifying its md5 checksum).
    """

    def __init__(self, args, logger):
        """Store the configuration used for all subsequent downloads.

        Args:
            args: parsed arguments; must expose a non-empty
                ``root_model_dir`` attribute.
            logger: logger handed to the ``DownloadFile`` helper.

        Raises:
            AssertionError: if ``args.root_model_dir`` is empty.
        """
        self.args = args
        self.root_model_dir = self.args.root_model_dir
        self.logger = logger
        self.everstore = None
        assert self.args.root_model_dir, "root_model_dir is not set"

    def run(self, benchmark_file):
        """Download all files needed by the benchmarks in ``benchmark_file``.

        Args:
            benchmark_file: benchmark specification passed to
                ``getBenchmarks``.

        Returns:
            list: local paths of the files that were downloaded or found
            in the cache. Locations that were skipped are omitted.

        Raises:
            AssertionError: if ``benchmark_file`` is empty.
        """
        assert benchmark_file, "benchmark_file is not set"
        benchmarks = getBenchmarks(benchmark_file)
        if not os.path.isdir(self.root_model_dir):
            os.makedirs(self.root_model_dir)
        locations = []
        for benchmark in benchmarks:
            locations.extend(self._processOneBenchmark(benchmark))
        # downloadFile returns None for locations it skips; drop those.
        return [path for path in locations if path]

    def _processOneBenchmark(self, benchmark):
        """Download the model, library and test files of one benchmark.

        Args:
            benchmark: dict with a ``"filename"`` entry (used in error
                messages) and a ``"content"`` entry holding the benchmark
                description.

        Returns:
            list: one entry per referenced location, in the order they were
            encountered. An entry is ``None`` when the location was skipped.

        Raises:
            AssertionError: if a model file/library entry has no
                ``"location"`` or the benchmark has no ``"tests"`` field.
        """
        filename = benchmark["filename"]
        one_benchmark = benchmark["content"]
        locations = []

        if "model" in one_benchmark:
            model = one_benchmark["model"]
            if "files" in model:
                # "files" maps a field name to its file description.
                locations.extend(
                    self._downloadModelEntries(model["files"].values(), filename)
                )
            if "libraries" in model:
                # "libraries" is iterated directly as a sequence of
                # file descriptions.
                locations.extend(
                    self._downloadModelEntries(model["libraries"], filename)
                )

        assert (
            "tests" in one_benchmark
        ), "tests field is missing in benchmark {}".format(filename)

        for test in one_benchmark["tests"]:
            if "input_files" in test:
                locations.extend(self._downloadTestFiles(test["input_files"]))
            if "output_files" in test:
                locations.extend(self._downloadTestFiles(test["output_files"]))
            if "preprocess" in test and "files" in test["preprocess"]:
                locations.extend(
                    self._downloadTestFiles(test["preprocess"]["files"])
                )
            if "postprocess" in test and "files" in test["postprocess"]:
                locations.extend(
                    self._downloadTestFiles(test["postprocess"]["files"])
                )

        return locations

    def _downloadModelEntries(self, entries, filename):
        """Download each model file description in ``entries``.

        The md5 is optional for every entry: a missing ``"md5"`` key is
        treated the same as no checksum, so an existing local copy is
        trusted as-is.

        Args:
            entries: iterable of dicts, each with a ``"location"`` key and
                an optional ``"md5"`` key.
            filename: benchmark file name, used only in error messages.

        Returns:
            list: the result of ``downloadFile`` for each entry, in order.

        Raises:
            AssertionError: if an entry has no ``"location"`` key.
        """
        paths = []
        for value in entries:
            assert (
                "location" in value
            ), "location field is missing in benchmark {}".format(filename)
            paths.append(self.downloadFile(value["location"], value.get("md5")))
        return paths

    def downloadFile(self, location, md5):
        """Make the file at ``location`` available locally.

        Only locations starting with ``http`` or ``//`` are handled. The
        download is skipped when the target file already exists and either
        no md5 was given or the md5 of the existing file matches.

        Args:
            location: remote location of the file.
            md5: expected md5 hex digest, or a falsy value to skip the
                checksum verification of a cached file.

        Returns:
            The local path of the file, or ``None`` when the location is
            not handled (no ``http``/``//`` prefix, or a ``//`` location
            with two or fewer path components).
        """
        if location.startswith("http"):
            dirs = location.split(":/")
            replace_pattern = {
                " ": "-",
                "\\": "-",
                ":": "/",
            }
            path = os.path.join(
                self.root_model_dir,
                getFilename(location, replace_pattern=replace_pattern),
            )
        elif not location.startswith("//"):
            return
        else:
            dirs = location[2:].split("/")
            if len(dirs) <= 2:
                return
            # Drop one of the two leading slashes so the remainder is
            # appended to the root directory as a sub-path.
            path = self.root_model_dir + location[1:]

        if os.path.isfile(path):
            if md5:
                getLogger().info("Calculate md5 of {}".format(path))
                file_hash = None
                with open(path, "rb") as f:
                    file_hash = hashlib.md5()
                    # Hash in fixed-size chunks to avoid loading the whole
                    # file into memory.
                    for chunk in iter(lambda: f.read(8192), b""):
                        file_hash.update(chunk)
                new_md5 = file_hash.hexdigest()
                del file_hash
                gc.collect()
                if md5 == new_md5:
                    getLogger().info(
                        "File {}".format(os.path.basename(path))
                        + " is cached, skip downloading"
                    )
                    return path
                # Checksum mismatch: fall through and download again.
            else:
                # assume the file is the same
                return path

        downloader_controller = DownloadFile(
            dirs=dirs, logger=self.logger, args=self.args
        )
        downloader_controller.download_file(location, path)
        return path

    def _downloadTestFiles(self, files):
        """Download the files referenced by one test file section.

        Args:
            files: either a list of file descriptions, or a dict whose
                values are a file description or a list of them. Any other
                type yields an empty result. Descriptions without a
                ``"location"`` key are ignored.

        Returns:
            list: the result of ``downloadFile`` for each description that
            has a ``"location"``, in order.
        """
        locations = []
        if isinstance(files, list):
            for f in files:
                if "location" in f:
                    locations.append(self.downloadFile(f["location"], None))
        elif isinstance(files, dict):
            for f in files:
                value = files[f]
                if isinstance(value, list):
                    for v in value:
                        if "location" in v:
                            locations.append(
                                self.downloadFile(v["location"], None)
                            )
                elif "location" in value:
                    locations.append(self.downloadFile(value["location"], None))
        return locations