backends/python/server/text_embeddings_server/utils/device.py (62 lines of code) (raw):

import importlib.metadata
import importlib.util
import os
import subprocess

import torch
from loguru import logger
from packaging import version

# Read once at import time: "true"/"1" (case-insensitive) in the
# ALLOW_REDUCED_PRECISION_FP16_BF16 environment variable enables the
# reduced-precision math-SDP workaround applied in get_device(). Defaults to on.
ALLOW_REDUCED_PRECISION = os.getenv(
    "ALLOW_REDUCED_PRECISION_FP16_BF16", "true"
).lower() in [
    "true",
    "1",
]


def _is_ipex_available() -> bool:
    """Return True when Intel Extension for PyTorch is installed and matches torch.

    The extension counts as available only if its module can be located, its
    distribution metadata can be read, and its ``major.minor`` version equals
    the installed torch ``major.minor`` version. A version mismatch is logged
    as a warning and reported as unavailable.
    """

    def get_major_and_minor_from_version(full_version):
        # Parse once and keep only "<major>.<minor>" for the compatibility check.
        parsed = version.parse(full_version)
        return str(parsed.major) + "." + str(parsed.minor)

    if importlib.util.find_spec("intel_extension_for_pytorch") is None:
        return False
    try:
        _ipex_version = importlib.metadata.version("intel_extension_for_pytorch")
    except importlib.metadata.PackageNotFoundError:
        return False

    # Look up the torch version only once IPEX is known to be present, so that
    # machines without IPEX never depend on torch's distribution metadata.
    _torch_version = importlib.metadata.version("torch")

    torch_major_and_minor = get_major_and_minor_from_version(_torch_version)
    ipex_major_and_minor = get_major_and_minor_from_version(_ipex_version)
    if torch_major_and_minor != ipex_major_and_minor:
        logger.warning(
            f"Intel Extension for PyTorch {ipex_major_and_minor} needs to work with PyTorch {ipex_major_and_minor}.*,"
            f" but PyTorch {_torch_version} is found. Please switch to the matching version and run again."
        )
        return False
    return True


def is_hpu() -> bool:
    """Return True when the ``hl-smi`` command runs successfully.

    ``hl-smi`` is executed with its output captured; a missing or
    non-executable binary, or a non-zero exit status, is reported as False.
    """
    try:
        subprocess.run(["hl-smi"], capture_output=True, check=True)
    except (OSError, subprocess.SubprocessError):
        # OSError: binary missing or not executable.
        # SubprocessError: non-zero exit status (CalledProcessError).
        # KeyboardInterrupt / SystemExit are deliberately left to propagate.
        return False
    return True


def use_ipex() -> bool:
    """Return True when IPEX is enabled via ``USE_IPEX`` and is actually usable.

    ``USE_IPEX`` defaults to enabled; only "true" or "1" (case-insensitive)
    turn it on. The availability probe is skipped when the variable disables it.
    """
    value = os.environ.get("USE_IPEX", "True").lower()
    return value in ["true", "1"] and _is_ipex_available()


def get_device() -> torch.device:
    """Pick the torch device to run on.

    Preference order: CUDA, then HPU (when ``hl-smi`` works), then XPU through
    IPEX (when enabled and available). Falls back to CPU when none applies or
    when the selected backend reports itself as unavailable.
    """
    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif is_hpu():
        # Imported for its side effects only (presumably registers torch.hpu).
        import habana_frameworks.torch.core as htcore  # noqa: F401

        # WA for perf degradation from pytorch 2.5
        # The setter is a private torch hook, so it is looked up defensively
        # and skipped on builds that do not provide it.
        set_reduction = getattr(
            torch._C, "_set_math_sdp_allow_fp16_bf16_reduction", None
        )
        if ALLOW_REDUCED_PRECISION and set_reduction is not None:
            set_reduction(True)
        if hasattr(torch, "hpu") and torch.hpu.is_available():  # type: ignore
            device = torch.device("hpu")
    elif use_ipex():
        # Imported for its side effects only (presumably registers torch.xpu).
        import intel_extension_for_pytorch as ipex  # noqa: F401

        if hasattr(torch, "xpu") and torch.xpu.is_available():
            device = torch.device("xpu")
    return device