in sam2/utils/misc.py [0:0]
def get_sdpa_settings():
    """Pick the attention-kernel flags appropriate for the current environment.

    The decision is based on whether CUDA is available, the CUDA capability
    major version of device 0, and the installed PyTorch major/minor version.

    Returns:
        A 3-tuple of booleans ``(old_gpu, use_flash_attn, math_kernel_on)``:

        * ``old_gpu`` -- True when CUDA is unavailable or device 0 has a
          capability major version below 7.
        * ``use_flash_attn`` -- True only when device 0 has a capability
          major version of 8 or higher.
        * ``math_kernel_on`` -- True when CUDA is unavailable, when Flash
          Attention is not used, or when PyTorch is older than 2.2.

    Warns:
        UserWarning: if Flash Attention is disabled because of the GPU
            capability, and/or if PyTorch is older than 2.2.
    """
    if not torch.cuda.is_available():
        # Without CUDA: report an "old" GPU, no Flash Attention, math kernel on.
        return True, False, True

    capability_major = torch.cuda.get_device_properties(0).major
    old_gpu = capability_major < 7
    # Flash Attention is only used on Ampere (8.0) or newer GPUs.
    use_flash_attn = capability_major >= 8
    if not use_flash_attn:
        # stacklevel=2 attributes the warning to the caller of this function.
        warnings.warn(
            "Flash Attention is disabled as it requires a GPU with Ampere (8.0) CUDA capability.",
            category=UserWarning,
            stacklevel=2,
        )

    # Flash Attention v2 is only available on PyTorch 2.2+, while Flash
    # Attention v1 cannot handle all cases, so older versions keep the math
    # kernel enabled.
    major_minor = tuple(map(int, torch.__version__.split(".")[:2]))
    predates_flash_v2 = major_minor < (2, 2)
    if predates_flash_v2:
        warnings.warn(
            f"You are using PyTorch {torch.__version__} without Flash Attention v2 support. "
            "Consider upgrading to PyTorch 2.2+ for Flash Attention v2 (which could be faster).",
            category=UserWarning,
            stacklevel=2,
        )

    math_kernel_on = predates_flash_v2 or not use_flash_attn
    return old_gpu, use_flash_attn, math_kernel_on