def _find_cublas_config()

in gpu/find_cuda_config.py [0:0]


def _find_cublas_config(base_paths, required_version, cuda_version):
  """Finds the cuBLAS header and library and reports their directories.

  For CUDA 10.1 and later the cuBLAS version is read from the version macros
  in "cublas_api.h" and its major component is checked against `cuda_version`.
  For earlier CUDA versions the header is only located, and
  `required_version` is used as the cuBLAS library version.

  Args:
    base_paths: Paths forwarded to the header, file and library finders.
    required_version: Requested cuBLAS version. Passed to the header search
      for CUDA >= 10.1; used directly as the library version otherwise.
    cuda_version: CUDA version string that selects the detection strategy and
      that the detected cuBLAS major version is matched against.

  Returns:
    A dict with the keys "cublas_include_dir" and "cublas_library_dir".

  Raises:
    ConfigError: If the cuBLAS major version read from the header does not
      match `cuda_version`.
  """
  if not _at_least_version(cuda_version, "10.1"):
    # Headers older than CUDA 10.1 carry no version info, so only locate the
    # file.
    header_path = _find_file(base_paths, _header_paths(), "cublas_api.h")
    # Here the cuBLAS version is taken to be the requested one (x.y).
    cublas_version = required_version
  else:
    version_macros = ("CUBLAS_VER_MAJOR", "CUBLAS_VER_MINOR",
                      "CUBLAS_VER_PATCH")

    def get_header_version(path):
      # Assemble "major.minor.patch" from the three macros in the header.
      parts = [_get_header_version(path, macro) for macro in version_macros]
      return ".".join(parts)

    header_path, full_version = _find_header(base_paths, "cublas_api.h",
                                             required_version,
                                             get_header_version)
    # Only the major component is used as the cuBLAS version.
    cublas_version = full_version.partition(".")[0]

    if not _matches_version(cuda_version, cublas_version):
      mismatch = (cublas_version, cuda_version)
      raise ConfigError(
          "cuBLAS version %s does not match CUDA version %s" % mismatch)

  library_path = _find_library(base_paths, "cublas", cublas_version)

  include_dir = os.path.dirname(header_path)
  library_dir = os.path.dirname(library_path)
  return {
      "cublas_include_dir": include_dir,
      "cublas_library_dir": library_dir,
  }