# Copyright (c) 2020 The Neuropod Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Installs the appropriate pip packages depending on the following env variables
# NEUROPOD_IS_GPU
# NEUROPOD_CUDA_VERSION
# NEUROPOD_TORCH_VERSION
# NEUROPOD_TENSORFLOW_VERSION
import os
import platform
import subprocess
import sys

# Every setting is read as `<lookup> or <fallback>` so that a variable which is set to an
# empty string is treated exactly like one that is not set at all. Passing the fallback as
# the default of the lookup would only cover the "not set" case.
REQUESTED_TF_VERSION = os.environ.get("NEUROPOD_TENSORFLOW_VERSION") or "1.12.0"
REQUESTED_TORCH_VERSION = os.environ.get("NEUROPOD_TORCH_VERSION") or "1.1.0"
CUDA_VERSION = os.environ.get("NEUROPOD_CUDA_VERSION") or "10.0"

# GPU packages are selected by any non-empty value (this includes values like "0" or "false")
IS_GPU = bool(os.environ.get("NEUROPOD_IS_GPU"))

# True when running on macOS
IS_MAC = "Darwin" == platform.system()

def pip_install(args):
    """
    Run `pip install` through the python interpreter that is executing this script

    :param  args:   A list of extra command line arguments to pass to `pip install`
                    (e.g. requirement specifiers, wheel URLs, or `-f <url>`)
    """
    # Invoking pip as a module of `sys.executable` installs into the environment of the
    # interpreter that is running this script
    base_command = [sys.executable, "-m", "pip", "install"]
    full_command = base_command + args

    print("Running pip command: {}".format(full_command))

    # Raises `subprocess.CalledProcessError` if pip exits with a non-zero status
    subprocess.check_call(full_command)

def install_pytorch(version):
    """
    Install the requested version of torch using pip

    Builds the arguments for `pip install` (the wheel index to search plus either a
    `torch==<version>` style requirement or a direct wheel URL) from `version` and from
    the module level `IS_GPU`, `CUDA_VERSION` and `IS_MAC` settings, and then runs pip.

    :param  version:    The version of torch. This can be something like "1.2.0" or
                        "1.1.0.dev20190601"
    :raises KeyError:   If a direct wheel URL has to be built and there is no known wheel
                        tag for the running python version
    :raises ValueError: If `version` contains "dev" but is not of the form
                        "<base>.dev<YYYYMMDD>"
    """
    pip_args = []

    # Get the torch cuda string (e.g. cpu, cu90, cu92, cu100)
    torch_cuda_string = "cu{}".format(CUDA_VERSION.replace(".", "")) if IS_GPU else "cpu"

    # TODO(vip): Fix this once we have a better way of dealing with CUDA 11.2
    if torch_cuda_string == "cu1121":
        torch_cuda_string = "cpu"

    # Split the version into the base version of torch (e.g. 1.2.0) and, if this is a
    # nightly build, the date of the build (e.g. 20190809)
    if "dev" in version:
        version_base, version_date = version.split(".dev")
    else:
        version_base, version_date = version, None

    is_nightly = version_date is not None
    is_stable = not is_nightly

    # Parse the nightly date once so it can be compared against the cutoff dates below
    nightly_date = int(version_date) if is_nightly else None

    if is_nightly:
        # This is a nightly build
        pip_args += ["-f", "https://download.pytorch.org/whl/nightly/" + torch_cuda_string + "/torch_nightly.html"]
    else:
        # This is a stable build
        pip_args += ["-f", "https://download.pytorch.org/whl/torch_stable.html"]

    # Mac builds do not have the cuda string as part of the version
    if not IS_MAC:
        # If this is the 1.2.0 stable release or it's a nightly build after they started adding the cuda string to the packages
        if (is_stable and version_base == "1.2.0") or (is_nightly and nightly_date > 20190723):
            # For CUDA 10 builds, they don't add `cu100` to the version string
            if torch_cuda_string != "cu100":
                version += "+" + torch_cuda_string

        # If this is the 1.3.0 or 1.4.0 stable release
        if is_stable and version_base in ("1.3.0", "1.4.0"):
            # They changed the default from cuda 10.0 to cuda 10.1
            # For CUDA 10.1 builds, they don't add `cu101` to the version string
            if torch_cuda_string != "cu101":
                version += "+" + torch_cuda_string

        # If this is the 1.5.0, 1.6.0, or 1.7.0 stable release
        if is_stable and version_base in ("1.5.0", "1.6.0", "1.7.0"):
            # They changed the default from cuda 10.1 to cuda 10.2
            # For CUDA 10.2 builds, they don't add `cu102` to the version string
            if torch_cuda_string != "cu102":
                version += "+" + torch_cuda_string

        # For 1.8.1, 1.9.0, and 1.10.2, they always include the cuda version in the version string
        if is_stable and version_base in ("1.8.1", "1.9.0", "1.10.2"):
            version += "+" + torch_cuda_string

    # The Mac 1.3.0 stable release doesn't exist in `torch_stable.html`
    # Use 1.3.0.post2 instead
    if IS_MAC and is_stable and version_base == "1.3.0":
        version = "1.3.0.post2"

    if is_nightly:
        # Nightly builds dated before 20190802 are requested as `torch_nightly` and
        # later ones as `torch`
        package = "torch" if nightly_date >= 20190802 else "torch_nightly"
        pip_args += [package + "==" + version]
    elif IS_GPU and version_base in ("1.1.0", "1.4.0", "1.5.0"):
        # See https://github.com/pytorch/pytorch/issues/37113
        # Manually figure out the correct whl URL
        package_version_map = {
            (2, 7): "cp27-cp27mu",
            (3, 5): "cp35-cp35m",
            (3, 6): "cp36-cp36m",
            (3, 7): "cp37-cp37m",
            (3, 8): "cp38-cp38",
        }

        python_version = (sys.version_info.major, sys.version_info.minor)
        if python_version not in package_version_map:
            # Still a KeyError (as a plain dict lookup would raise), but with a message
            # that explains what is missing
            raise KeyError(
                "No known torch wheel tag for python {}.{} (known versions: {})".format(
                    python_version[0],
                    python_version[1],
                    ", ".join("{}.{}".format(major, minor) for major, minor in sorted(package_version_map)),
                )
            )

        platform_version = package_version_map[python_version]

        # The `+` in the version has to be percent-encoded in the URL
        pip_args += ["https://download.pytorch.org/whl/" + torch_cuda_string + "/torch-" + version.replace("+", "%2B") + "-" + platform_version + "-linux_x86_64.whl"]
    else:
        pip_args += ["torch==" + version]

    pip_install(pip_args)


def install_tensorflow(version):
    """
    Install the requested version of tensorflow using pip

    :param  version:    The version of tensorflow. Versions that contain "dev" are
                        installed from the `tf-nightly` package instead of `tensorflow`
    """
    # Nightly builds are requested under a different package name
    package_name = "tf-nightly" if "dev" in version else "tensorflow"

    # GPU builds are requested from the package with a `-gpu` suffix
    gpu_suffix = "-gpu" if IS_GPU else ""

    requirement = package_name + gpu_suffix + "==" + version
    pip_install([requirement])

if __name__ == "__main__":
    # Script entry point: install the requested frameworks (tensorflow first, then torch)
    print("Installing tensorflow", REQUESTED_TF_VERSION, "and torch", REQUESTED_TORCH_VERSION)

    installers = (
        (install_tensorflow, REQUESTED_TF_VERSION),
        (install_pytorch, REQUESTED_TORCH_VERSION),
    )
    for run_install, requested_version in installers:
        run_install(requested_version)
