in gpu/find_cuda_config.py [0:0]
def _find_cublas_config(base_paths, required_version, cuda_version):
    """Locate the cuBLAS header and library and report their directories.

    Args:
      base_paths: search roots, forwarded unchanged to the file, header and
        library lookup helpers.
      required_version: version forwarded to the header lookup; for CUDA
        releases before 10.1 it is also used as the library version.
      cuda_version: CUDA version string; selects the lookup strategy and is
        checked against the detected cuBLAS major version.

    Returns:
      A dict with the keys "cublas_include_dir" and "cublas_library_dir",
      each holding the directory of the file that was found.

    Raises:
      ConfigError: if, for CUDA 10.1 or newer, the cuBLAS major version read
        from the header does not match `cuda_version`.
    """
    if not _at_least_version(cuda_version, "10.1"):
        # Before CUDA 10.1 the header carries no version info, so the file is
        # simply located by name.
        include_file = _find_file(base_paths, _header_paths(), "cublas_api.h")
        # cuBLAS version is the same as the CUDA version (x.y).
        library_version = required_version
    else:
        version_macros = (
            "CUBLAS_VER_MAJOR",
            "CUBLAS_VER_MINOR",
            "CUBLAS_VER_PATCH",
        )

        def read_cublas_version(path):
            # Assemble "major.minor.patch" from the three header defines.
            parts = [_get_header_version(path, macro) for macro in version_macros]
            return ".".join(parts)

        include_file, detected_version = _find_header(
            base_paths, "cublas_api.h", required_version, read_cublas_version)

        # Only the major component is used as the cuBLAS library version.
        library_version = detected_version.partition(".")[0]

        versions_agree = _matches_version(cuda_version, library_version)
        if not versions_agree:
            raise ConfigError(
                "cuBLAS version %s does not match CUDA version %s" %
                (library_version, cuda_version))

    library_file = _find_library(base_paths, "cublas", library_version)

    include_dir = os.path.dirname(include_file)
    library_dir = os.path.dirname(library_file)
    return {
        "cublas_include_dir": include_dir,
        "cublas_library_dir": library_dir,
    }