bitsandbytes/cuda_specs.py (27 lines of code) (raw):

"""Helpers for describing the CUDA environment visible through PyTorch."""

import dataclasses
from typing import List, Optional, Tuple

import torch


@dataclasses.dataclass(frozen=True)
class CUDASpecs:
    """Immutable snapshot of the CUDA setup reported by PyTorch."""

    # Largest (major, minor) compute capability among the visible devices.
    highest_compute_capability: Tuple[int, int]
    # Major and minor CUDA version digits concatenated, e.g. "118" for 11.8.
    cuda_version_string: str
    # CUDA version as a (major, minor) tuple, e.g. (11, 8).
    cuda_version_tuple: Tuple[int, int]

    @property
    def has_cublaslt(self) -> bool:
        """Return True when the highest compute capability is at least (7, 5)."""
        return self.highest_compute_capability >= (7, 5)


def get_compute_capabilities() -> List[Tuple[int, int]]:
    """Return the compute capability of every visible CUDA device.

    Returns:
        A list of (major, minor) tuples sorted in ascending order, one entry
        per device index reported by ``torch.cuda.device_count()``.
    """
    # Pass the plain device index: it is a documented argument form for
    # ``get_device_capability`` and avoids building a context-manager object.
    return sorted(
        torch.cuda.get_device_capability(index)
        for index in range(torch.cuda.device_count())
    )


def get_cuda_version_tuple() -> Tuple[int, int]:
    """Return the CUDA version PyTorch was built with as (major, minor).

    Only the first two dot-separated components of ``torch.version.cuda`` are
    parsed, so a version carrying a patch component ("11.7.1") is accepted.

    Raises:
        AttributeError: if ``torch.version.cuda`` is None.
        ValueError: if the version has fewer than two components or a
            component is not an integer.
    """
    # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION
    major, minor = map(int, torch.version.cuda.split(".")[:2])
    return major, minor


def get_cuda_version_string() -> str:
    """Return the CUDA version as concatenated digits, e.g. "118" for 11.8."""
    major, minor = get_cuda_version_tuple()
    return f"{major}{minor}"


def get_cuda_specs() -> Optional[CUDASpecs]:
    """Collect the CUDA specs of the current environment.

    Returns:
        A ``CUDASpecs`` instance, or None when CUDA is not available or the
        PyTorch build reports no CUDA version.
    """
    if not torch.cuda.is_available():
        return None

    # ``torch.version.cuda`` is None on builds compiled without CUDA; such
    # builds (e.g. ROCm/HIP) may still report ``torch.cuda.is_available()``.
    # There is no CUDA version to describe in that case.
    if torch.version.cuda is None:
        return None

    # Parse the version once and derive the string form from the same tuple.
    major, minor = get_cuda_version_tuple()
    return CUDASpecs(
        highest_compute_capability=get_compute_capabilities()[-1],
        cuda_version_string=f"{major}{minor}",
        cuda_version_tuple=(major, minor),
    )