def find_cuda_config()

in build_deps/toolchains/gpu/find_cuda_config.py [0:0]


def find_cuda_config():
  """Returns a dictionary of CUDA library and header file paths."""
  libraries = [argv.lower() for argv in sys.argv[1:]]
  cuda_version = os.environ.get("TF_CUDA_VERSION", "")
  base_paths = _list_from_env("TF_CUDA_PATHS",
                              _get_default_cuda_paths(cuda_version))
  base_paths = [path for path in base_paths if os.path.exists(path)]

  result = {}
  if "cuda" in libraries:
    cuda_paths = _list_from_env("CUDA_TOOLKIT_PATH", base_paths)
    result.update(_find_cuda_config(cuda_paths, cuda_version))

    cuda_version = result["cuda_version"]
    cublas_paths = base_paths
    if tuple(int(v) for v in cuda_version.split(".")) < (10, 1):
      # Before CUDA 10.1, cuBLAS was in the same directory as the toolkit.
      cublas_paths = cuda_paths
    cublas_version = os.environ.get("TF_CUBLAS_VERSION", "")
    result.update(
        _find_cublas_config(cublas_paths, cublas_version, cuda_version))

    cusolver_paths = base_paths
    if tuple(int(v) for v in cuda_version.split(".")) < (11, 0):
      cusolver_paths = cuda_paths
    cusolver_version = os.environ.get("TF_CUSOLVER_VERSION", "")
    result.update(
        _find_cusolver_config(cusolver_paths, cusolver_version, cuda_version))

    curand_paths = base_paths
    if tuple(int(v) for v in cuda_version.split(".")) < (11, 0):
      curand_paths = cuda_paths
    curand_version = os.environ.get("TF_CURAND_VERSION", "")
    result.update(
        _find_curand_config(curand_paths, curand_version, cuda_version))

    cufft_paths = base_paths
    if tuple(int(v) for v in cuda_version.split(".")) < (11, 0):
      cufft_paths = cuda_paths
    cufft_version = os.environ.get("TF_CUFFT_VERSION", "")
    result.update(_find_cufft_config(cufft_paths, cufft_version, cuda_version))

    cusparse_paths = base_paths
    if tuple(int(v) for v in cuda_version.split(".")) < (11, 0):
      cusparse_paths = cuda_paths
    cusparse_version = os.environ.get("TF_CUSPARSE_VERSION", "")
    result.update(
        _find_cusparse_config(cusparse_paths, cusparse_version, cuda_version))

  if "cudnn" in libraries:
    cudnn_paths = _get_legacy_path("CUDNN_INSTALL_PATH", base_paths)
    cudnn_version = os.environ.get("TF_CUDNN_VERSION", "")
    result.update(_find_cudnn_config(cudnn_paths, cudnn_version))

  if "nccl" in libraries:
    nccl_paths = _get_legacy_path("NCCL_INSTALL_PATH", base_paths)
    nccl_version = os.environ.get("TF_NCCL_VERSION", "")
    result.update(_find_nccl_config(nccl_paths, nccl_version))

  if "tensorrt" in libraries:
    tensorrt_paths = _get_legacy_path("TENSORRT_INSTALL_PATH", base_paths)
    tensorrt_version = os.environ.get("TF_TENSORRT_VERSION", "")
    result.update(_find_tensorrt_config(tensorrt_paths, tensorrt_version))

  for k, v in result.items():
    if k.endswith("_dir") or k.endswith("_path"):
      result[k] = _normalize_path(v)

  return result