import argparse
import subprocess
import os
import sys
import tarfile
from install_utils import TORCH_DEPS, proxy_suggestion, get_pkg_versions, _test_https

def git_lfs_checkout():
    """Run ``git lfs install --force``, ``git lfs fetch`` and ``git lfs checkout .``.

    All commands run in the directory containing this script, in that order,
    stopping at the first failure.

    Returns:
        ``(True, None)`` if every command succeeds; ``(False, output)`` with the
        failing command's combined stdout/stderr (bytes) if a git command exits
        non-zero; ``(False, exc)`` for any other exception (e.g. ``git`` not on PATH).
    """
    tb_dir = os.path.dirname(os.path.realpath(__file__))
    commands = (
        # forcefully install git-lfs to the repo
        ['git', 'lfs', 'install', '--force'],
        ['git', 'lfs', 'fetch'],
        ['git', 'lfs', 'checkout', '.'],
    )
    try:
        for cmd in commands:
            # subprocess.run (unlike check_call) drains the pipe and attaches the
            # captured output to CalledProcessError, so the caller can show why git failed.
            subprocess.run(cmd, check=True, stdout=subprocess.PIPE,
                           stderr=subprocess.STDOUT, cwd=tb_dir)
    except subprocess.CalledProcessError as e:
        return (False, e.output)
    except Exception as e:
        return (False, e)
    return True, None

def decompress_input():
    """Extract every ``*.tar.gz`` in ``torchbenchmark/data`` into ``torchbenchmark/data/.data``.

    Paths are resolved relative to this script's directory, so the current
    working directory does not matter. The destination directory is created
    if missing. Progress is printed per tarball.
    """
    tb_dir = os.path.dirname(os.path.realpath(__file__))
    data_dir = os.path.join(tb_dir, "torchbenchmark", "data")
    # Hide decompressed files in a .data directory so that they won't be checked in
    decompress_dir = os.path.join(data_dir, ".data")
    os.makedirs(decompress_dir, exist_ok=True)
    # Sorted so the progress output is deterministic across filesystems.
    tarballs = sorted(name for name in os.listdir(data_dir) if name.endswith(".tar.gz"))
    for tarball in tarballs:
        tarball_path = os.path.join(data_dir, tarball)
        print(f"decompressing input tarball: {tarball}...", end="", flush=True)
        # The context manager closes the archive even if extraction raises.
        # NOTE(review): these tarballs ship with the repo and are treated as trusted;
        # on Python 3.12+ consider extractall(..., filter="data") to reject path traversal.
        with tarfile.open(tarball_path) as tar:
            tar.extractall(path=decompress_dir)
        print("OK")

def pip_install_requirements():
    """Install the packages listed in ``requirements.txt`` with the current interpreter's pip.

    If ``_test_https()`` returns a falsy value (presumably: HTTPS is unreachable),
    prints ``proxy_suggestion`` and exits the process with status -1.

    Returns:
        ``(True, None)`` on success; ``(False, output)`` with pip's combined
        stdout/stderr (bytes) if pip exits non-zero; ``(False, exc)`` for any
        other exception raised while launching pip.
    """
    if not _test_https():
        print(proxy_suggestion)
        sys.exit(-1)
    # Run from this script's directory so the relative 'requirements.txt' resolves
    # the same way regardless of the caller's working directory (as git_lfs_checkout does).
    tb_dir = os.path.dirname(os.path.realpath(__file__))
    try:
        subprocess.run([sys.executable, '-m', 'pip', 'install', '-q', '-r', 'requirements.txt'],
                       check=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=tb_dir)
    except subprocess.CalledProcessError as e:
        return (False, e.output)
    except Exception as e:
        return (False, e)
    return True, None

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--continue_on_fail", action="store_true")
    parser.add_argument("--models", nargs='+', default=[],
                        help="Specify one or more models to install. If not set, install all models.")
    parser.add_argument("--verbose", "-v", action="store_true")
    args = parser.parse_args()

    # Work from the directory containing this script so relative paths resolve consistently.
    os.chdir(os.path.realpath(os.path.dirname(__file__)))

    # Record the torch package versions up front so we can detect below whether
    # installing the benchmark requirements replaced them.
    print(f"checking packages {', '.join(TORCH_DEPS)} are installed...", end="", flush=True)
    try:
        versions = get_pkg_versions(TORCH_DEPS)
    except ModuleNotFoundError:
        print("FAIL")
        print(f"Error: Users must first manually install packages {TORCH_DEPS} before installing the benchmark.")
        sys.exit(-1)
    print("OK")

    print("checking out Git LFS files...", end="", flush=True)
    success, errmsg = git_lfs_checkout()
    if success:
        print("OK")
    else:
        print("FAIL")
        print("Failed to checkout git lfs files. Please make sure you have installed git lfs.")
        # errmsg may be raw subprocess output (bytes); decode it so it prints
        # readably instead of as an escaped b'...' repr.
        print(errmsg.decode(errors="replace") if isinstance(errmsg, bytes) else errmsg)
        sys.exit(-1)
    decompress_input()

    success, errmsg = pip_install_requirements()
    if not success:
        print("Failed to install torchbenchmark requirements:")
        print(errmsg.decode(errors="replace") if isinstance(errmsg, bytes) else errmsg)
        if not args.continue_on_fail:
            sys.exit(-1)
    new_versions = get_pkg_versions(TORCH_DEPS)
    if versions != new_versions:
        # Implicit concatenation: a backslash continuation inside the literal would
        # embed the next line's source indentation into the message.
        print(f"The torch packages are re-installed after installing the benchmark deps. "
              f"Before: {versions}, after: {new_versions}")
        sys.exit(-1)
    from torchbenchmark import setup
    # AND with the pip result so a pip failure tolerated via --continue_on_fail
    # is still reported by the warning below.
    success &= setup(models=args.models, verbose=args.verbose, continue_on_fail=args.continue_on_fail)
    if not success:
        if args.continue_on_fail:
            print("Warning: some benchmarks were not installed due to failure")
        else:
            raise RuntimeError("Failed to complete setup")
