in build/install_frameworks.py [0:0]
def install_pytorch(version):
    """
    Install the requested torch release with pip.

    Builds the pip argument list (a ``-f`` wheel index plus either a ``torch==...`` requirement or
    a direct wheel URL) from ``version`` and the module-level ``IS_GPU``, ``IS_MAC`` and
    ``CUDA_VERSION`` settings, then passes it to ``pip_install``.

    :param version: The version of torch. This can be something like "1.2.0" or
                    "1.1.0.dev20190601"
    :raises ValueError: if ``version`` contains "dev" but does not split around ".dev" into
                        exactly a base version and an integer build date
    :raises KeyError: if a direct wheel URL has to be built and the running Python version has no
                      known wheel tag
    """
    pip_args = []

    # Get the torch cuda string (e.g. cpu, cu90, cu92, cu100)
    torch_cuda_string = "cu{}".format(CUDA_VERSION.replace(".", "")) if IS_GPU else "cpu"
    # TODO(vip): Fix this once we have a better way of dealing with CUDA 11.2
    # NOTE(review): this matches only when CUDA_VERSION without its dots is exactly "1121"
    # (e.g. "11.2.1"); confirm that is the format CUDA_VERSION actually uses.
    if torch_cuda_string == "cu1121":
        torch_cuda_string = "cpu"
    cuda_suffix = "+" + torch_cuda_string

    # Split a nightly version into its base (e.g. 1.2.0) and its build date (e.g. 20190809).
    # Stable versions have no date.
    if "dev" in version:
        version_base, version_date = version.split(".dev")
        # Parse the date once; every nightly decision below compares against this number
        nightly_date = int(version_date)
    else:
        version_base = version
        nightly_date = None
    is_nightly = nightly_date is not None

    # Point pip at the wheel index that matches the build type
    if is_nightly:
        pip_args += ["-f", "https://download.pytorch.org/whl/nightly/" + torch_cuda_string + "/torch_nightly.html"]
    else:
        pip_args += ["-f", "https://download.pytorch.org/whl/torch_stable.html"]

    # Stable releases whose wheels omit the cuda string for the CUDA version that was the default
    # at the time (10.0 for 1.2.0, 10.1 for 1.3.0/1.4.0, 10.2 for 1.5.0-1.7.0)
    unsuffixed_cuda_by_release = {
        "1.2.0": "cu100",
        "1.3.0": "cu101",
        "1.4.0": "cu101",
        "1.5.0": "cu102",
        "1.6.0": "cu102",
        "1.7.0": "cu102",
    }
    # For 1.8.1, 1.9.0 and 1.10.2, they always include the cuda version in the version string
    always_suffixed_releases = ("1.8.1", "1.9.0", "1.10.2")

    # Mac builds do not have the cuda string as part of the version
    if not IS_MAC:
        if is_nightly:
            # Nightly builds after 20190723 carry the cuda string, except for CUDA 10 builds,
            # which don't add `cu100` to the version string
            if nightly_date > 20190723 and torch_cuda_string != "cu100":
                version += cuda_suffix
        elif version_base in always_suffixed_releases:
            version += cuda_suffix
        elif version_base in unsuffixed_cuda_by_release:
            if torch_cuda_string != unsuffixed_cuda_by_release[version_base]:
                version += cuda_suffix

    # The Mac 1.3.0 stable release doesn't exist in `torch_stable.html`
    # Use 1.3.0.post2 instead
    if IS_MAC and version_base == "1.3.0" and not is_nightly:
        version = "1.3.0.post2"

    if is_nightly:
        # Nightlies dated before 20190802 are requested under the `torch_nightly` package name
        package_name = "torch" if nightly_date >= 20190802 else "torch_nightly"
        pip_args += [package_name + "==" + version]
    elif IS_GPU and version_base in ("1.1.0", "1.4.0", "1.5.0"):
        # See https://github.com/pytorch/pytorch/issues/37113
        # Manually figure out the correct whl URL
        package_version_map = {
            (2, 7): "cp27-cp27mu",
            (3, 5): "cp35-cp35m",
            (3, 6): "cp36-cp36m",
            (3, 7): "cp37-cp37m",
            (3, 8): "cp38-cp38",
        }
        python_version = (sys.version_info.major, sys.version_info.minor)
        platform_version = package_version_map.get(python_version)
        if platform_version is None:
            # Still a KeyError, as a plain dict lookup would raise, but with a message that says
            # what went wrong instead of just echoing the missing key
            raise KeyError(
                "No known torch wheel tag for Python {}.{}; cannot build a wheel URL for torch {}".format(
                    python_version[0], python_version[1], version
                )
            )
        # A literal "+" in the wheel file name has to be percent-encoded in the URL
        pip_args += ["https://download.pytorch.org/whl/" + torch_cuda_string + "/torch-" + version.replace("+", "%2B") + "-" + platform_version + "-linux_x86_64.whl"]
    else:
        pip_args += ["torch==" + version]

    pip_install(pip_args)