# backends/python/server/text_embeddings_server/utils/device.py
import os
from loguru import logger
import importlib.metadata
import importlib.util
from packaging import version
import torch
import subprocess
# Whether fp16/bf16 reduced-precision math reductions may be enabled.
# Controlled by the ALLOW_REDUCED_PRECISION_FP16_BF16 env var; defaults to on.
_reduced_precision_flag = os.getenv("ALLOW_REDUCED_PRECISION_FP16_BF16", "true")
ALLOW_REDUCED_PRECISION = _reduced_precision_flag.lower() in ("true", "1")
def _is_ipex_available():
def get_major_and_minor_from_version(full_version):
return (
str(version.parse(full_version).major)
+ "."
+ str(version.parse(full_version).minor)
)
_torch_version = importlib.metadata.version("torch")
if importlib.util.find_spec("intel_extension_for_pytorch") is None:
return False
_ipex_version = "N/A"
try:
_ipex_version = importlib.metadata.version("intel_extension_for_pytorch")
except importlib.metadata.PackageNotFoundError:
return False
torch_major_and_minor = get_major_and_minor_from_version(_torch_version)
ipex_major_and_minor = get_major_and_minor_from_version(_ipex_version)
if torch_major_and_minor != ipex_major_and_minor:
logger.warning(
f"Intel Extension for PyTorch {ipex_major_and_minor} needs to work with PyTorch {ipex_major_and_minor}.*,"
f" but PyTorch {_torch_version} is found. Please switch to the matching version and run again."
)
return False
return True
def is_hpu() -> bool:
    """Return True when an Intel Gaudi (HPU) accelerator appears reachable.

    Probes the Habana management CLI ``hl-smi``; if the binary is missing
    or exits with a non-zero status, no HPU is assumed to be present.
    """
    try:
        # check=True raises CalledProcessError on a non-zero exit status.
        subprocess.run(["hl-smi"], capture_output=True, check=True)
    except (OSError, subprocess.CalledProcessError):
        # OSError covers FileNotFoundError (hl-smi not installed) and
        # permission errors. The previous bare `except:` also swallowed
        # KeyboardInterrupt/SystemExit, masking shutdown requests.
        return False
    return True
def use_ipex() -> bool:
    """Return True when USE_IPEX opts in (default "True") AND a version-
    compatible intel_extension_for_pytorch installation is detected."""
    opted_in = os.environ.get("USE_IPEX", "True").lower() in ("true", "1")
    # Short-circuit: only probe the IPEX installation when opted in.
    return opted_in and _is_ipex_available()
def get_device():
    """Pick the torch device to run on, preferring cuda > hpu > xpu > cpu.

    Accelerator frameworks (habana_frameworks / ipex) are imported lazily,
    only on the branch whose hardware was detected.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")

    if is_hpu():
        # Imported for its side effects — presumably makes the "hpu"
        # backend visible on torch; the name itself is unused.
        import habana_frameworks.torch.core as htcore  # noqa: F401

        # WA for perf degradation from pytorch 2.5
        if ALLOW_REDUCED_PRECISION:
            torch._C._set_math_sdp_allow_fp16_bf16_reduction(True)
        if hasattr(torch, "hpu") and torch.hpu.is_available():  # type: ignore
            return torch.device("hpu")
    elif use_ipex():
        # Imported for its side effects — presumably makes the "xpu"
        # backend visible on torch; the name itself is unused.
        import intel_extension_for_pytorch as ipex  # noqa: F401

        if hasattr(torch, "xpu") and torch.xpu.is_available():
            return torch.device("xpu")

    # Fallback when no accelerator was detected or its backend is unusable.
    return torch.device("cpu")