install.py (88 lines of code) (raw):
import argparse
import subprocess
import os
import sys
import tarfile
from install_utils import TORCH_DEPS, proxy_suggestion, get_pkg_versions, _test_https
def git_lfs_checkout():
    """Run the Git LFS install/fetch/checkout sequence next to this file.

    Executes ``git lfs install --force``, ``git lfs fetch`` and
    ``git lfs checkout .`` in order, using the directory that contains this
    file as the working directory, and stops at the first failure.

    Returns:
        ``(True, None)`` when every command exits successfully. Otherwise
        ``(False, error)``, where ``error`` is the combined stdout/stderr text
        of the failing command, or the exception object when the command
        could not be run at all.
    """
    tb_dir = os.path.dirname(os.path.realpath(__file__))
    lfs_commands = (
        # forcefully install git-lfs to the repo
        ['git', 'lfs', 'install', '--force'],
        ['git', 'lfs', 'fetch'],
        ['git', 'lfs', 'checkout', '.'],
    )
    try:
        for command in lfs_commands:
            # subprocess.run() drains the pipe and attaches the captured
            # output to CalledProcessError. check_call() leaves ``output``
            # as None and can block once an unread pipe fills up.
            subprocess.run(command, check=True, stdout=subprocess.PIPE,
                           stderr=subprocess.STDOUT, cwd=tb_dir)
    except subprocess.CalledProcessError as e:
        # Decode so the caller can print readable text rather than a bytes repr.
        return (False, e.output.decode(errors="replace"))
    except Exception as e:
        return (False, e)
    return True, None
def decompress_input():
    """Extract every ``*.tar.gz`` archive found in ``torchbenchmark/data``.

    The data directory is resolved relative to the directory containing this
    file. Archives are unpacked into its ``.data`` subdirectory, which is
    created if it does not exist yet. Progress is printed per archive.
    """
    tb_dir = os.path.dirname(os.path.realpath(__file__))
    data_dir = os.path.join(tb_dir, "torchbenchmark", "data")
    # Hide decompressed file in .data directory so that they won't be checked in
    decompress_dir = os.path.join(data_dir, ".data")
    os.makedirs(decompress_dir, exist_ok=True)
    # Decompress every tar.gz file
    for tarball in os.listdir(data_dir):
        if not tarball.endswith(".tar.gz"):
            continue
        tarball_path = os.path.join(data_dir, tarball)
        print(f"decompressing input tarball: {tarball}...", end="", flush=True)
        # The context manager closes the archive even if extraction raises.
        # NOTE(review): extractall() runs without a member filter, which
        # assumes the archives in the data directory are trusted -- confirm.
        with tarfile.open(tarball_path) as tar:
            tar.extractall(path=decompress_dir)
        print("OK")
def pip_install_requirements():
    """Install the packages listed in ``requirements.txt`` using pip.

    ``requirements.txt`` is given as a relative path, so it is resolved
    against the current working directory. pip is invoked through the
    running interpreter (``sys.executable -m pip``).

    If ``_test_https()`` returns a falsy value, ``proxy_suggestion`` is
    printed and the process exits with status -1 before pip is run.

    Returns:
        ``(True, None)`` when pip exits successfully. Otherwise
        ``(False, error)``, where ``error`` is pip's combined stdout/stderr
        text, or the exception object when pip could not be run at all.
    """
    if not _test_https():
        print(proxy_suggestion)
        sys.exit(-1)
    try:
        subprocess.run([sys.executable, '-m', 'pip', 'install', '-q', '-r', 'requirements.txt'],
                       check=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    except subprocess.CalledProcessError as e:
        # The captured output is bytes. Decode it so the caller prints
        # readable multi-line text instead of an escaped b'...' repr.
        return (False, e.output.decode(errors="replace"))
    except Exception as e:
        return (False, e)
    return True, None
if __name__ == '__main__':
    # Command-line interface.
    cli = argparse.ArgumentParser()
    cli.add_argument("--continue_on_fail", action="store_true")
    cli.add_argument("--models", nargs='+', default=[],
                     help="Specify one or more models to install. If not set, install all models.")
    cli.add_argument("--verbose", "-v", action="store_true")
    args = cli.parse_args()

    # Run everything from the directory that holds this script.
    script_dir = os.path.dirname(__file__)
    os.chdir(os.path.realpath(script_dir))

    # Step 1: record the versions of the torch packages before installing
    # anything, so they can be compared again after the pip step.
    dep_names = ', '.join(TORCH_DEPS)
    print(f"checking packages {dep_names} are installed...", end="", flush=True)
    try:
        versions_before = get_pkg_versions(TORCH_DEPS)
    except ModuleNotFoundError:
        print("FAIL")
        print(f"Error: Users must first manually install packages {TORCH_DEPS} before installing the benchmark.")
        sys.exit(-1)
    else:
        print("OK")

    # Step 2: fetch the Git LFS payloads; a failure here is always fatal.
    print("checking out Git LFS files...", end="", flush=True)
    lfs_ok, lfs_error = git_lfs_checkout()
    if not lfs_ok:
        print("FAIL")
        print("Failed to checkout git lfs files. Please make sure you have installed git lfs.")
        print(lfs_error)
        sys.exit(-1)
    print("OK")

    # Step 3: unpack the bundled input tarballs.
    decompress_input()

    # Step 4: install the requirements; fatal unless --continue_on_fail.
    success, pip_error = pip_install_requirements()
    if not success:
        print("Failed to install torchbenchmark requirements:")
        print(pip_error)
        if not args.continue_on_fail:
            sys.exit(-1)

    # Step 5: abort if the recorded package versions changed during step 4.
    versions_after = get_pkg_versions(TORCH_DEPS)
    if versions_before != versions_after:
        print("The torch packages are re-installed after installing the benchmark deps. "
              f"Before: {versions_before}, after: {versions_after}")
        sys.exit(-1)

    # Step 6: run the per-model setup. Imported here, after the steps above.
    from torchbenchmark import setup
    success = success & setup(models=args.models, verbose=args.verbose,
                              continue_on_fail=args.continue_on_fail)
    if not success:
        if not args.continue_on_fail:
            raise RuntimeError("Failed to complete setup")
        print("Warning: some benchmarks were not installed due to failure")