optimum/quanto/library/extensions/extension.py (54 lines of code) (raw):
import os
import shutil
import warnings
from typing import List, Optional

import torch
from torch.utils.cpp_extension import load
# Names exported by `from <this module> import *`.
__all__ = [
    "is_extension_available",
    "get_extension",
]
class Extension(object):
    """A PyTorch C++/CUDA extension that is compiled and loaded lazily.

    The sources are only handed to `torch.utils.cpp_extension.load` the first time
    `lib` is accessed. Build artifacts are kept in `<root_dir>/build`, together with a
    `pytorch_version.txt` file recording the torch version used for the build.
    """

    def __init__(
        self,
        name: str,
        root_dir: str,
        sources: List[str],
        extra_cflags: Optional[List[str]] = None,
        extra_cuda_cflags: Optional[List[str]] = None,
    ):
        """
        Args:
            name (`str`):
                The extension name, passed to `torch.utils.cpp_extension.load`.
            root_dir (`str`):
                The directory containing the sources; the build directory is created inside it.
            sources (`List[str]`):
                The source file paths, relative to `root_dir`.
            extra_cflags (`Optional[List[str]]`):
                Extra C++ compiler flags, forwarded to `load`.
            extra_cuda_cflags (`Optional[List[str]]`):
                Extra CUDA compiler flags, forwarded to `load`.
        """
        self.name = name
        # Build paths the same way as `build_directory` below.
        self.sources = [os.path.join(root_dir, source) for source in sources]
        self.extra_cflags = extra_cflags
        self.extra_cuda_cflags = extra_cuda_cflags
        self.build_directory = os.path.join(root_dir, "build")
        # Cached module returned by `load`; populated on first access to `lib`.
        self._lib = None

    @property
    def lib(self):
        """Return the extension module, loading (and, if needed, compiling) it on first access.

        If a previous build was made against a different torch version, its build
        directory is removed (with a warning) so that the extension is recompiled.
        """
        if self._lib is None:
            # We only load the extension when the lib is required
            version_file = os.path.join(self.build_directory, "pytorch_version.txt")
            self._invalidate_stale_build(version_file)
            os.makedirs(self.build_directory, exist_ok=True)
            self._lib = load(
                name=self.name,
                sources=self.sources,
                extra_cflags=self.extra_cflags,
                extra_cuda_cflags=self.extra_cuda_cflags,
                build_directory=self.build_directory,
            )
            # Record the torch version used for this build (only if not already recorded).
            if not os.path.exists(version_file):
                with open(version_file, "w", encoding="utf-8") as f:
                    f.write(torch.__version__)
        return self._lib

    def _invalidate_stale_build(self, version_file: str) -> None:
        """Remove the build directory if it was built with another torch version.

        Args:
            version_file (`str`):
                Path to the file recording the torch version of the existing build.
        """
        if not os.path.exists(version_file):
            return
        # The extension has already been built: check the torch version for which it was built.
        with open(version_file, "r", encoding="utf-8") as f:
            pytorch_build_version = f.read().rstrip()
        # The version file is closed at this point: removing a directory that still holds
        # an open file fails on Windows.
        if pytorch_build_version != torch.__version__:
            shutil.rmtree(self.build_directory)
            warnings.warn(
                f"{self.name} was compiled with pytorch {pytorch_build_version}, but {torch.__version__} is installed: it will be recompiled."
            )
# Registry of known extensions, keyed by extension name.
_extensions = {}


def register_extension(extension: Extension):
    """Register an extension so that it can be retrieved by name.

    Args:
        extension (`Extension`):
            The extension to register. It is stored under `extension.name`.
    Raises:
        ValueError: if an extension with the same name is already registered.
    """
    # Explicit check instead of `assert`, which is stripped when Python runs with -O.
    if extension.name in _extensions:
        raise ValueError(f"An extension named '{extension.name}' is already registered.")
    _extensions[extension.name] = extension
def get_extension(extension_type: str):
    """Get an extension

    Args:
        extension_type (`str`):
            The extension type, i.e. the name the extension was registered under.
    Returns:
        The corresponding extension.
    Raises:
        KeyError: if no extension is registered under `extension_type`.
    """
    try:
        return _extensions[extension_type]
    except KeyError:
        # Same exception type as before, but with a message naming what is available.
        available = ", ".join(sorted(_extensions)) or "none"
        raise KeyError(
            f"No extension registered under '{extension_type}' (available: {available})."
        ) from None
def is_extension_available(extension_type: str) -> bool:
    """Check if an extension is available

    Args:
        extension_type (`str`):
            The extension type, i.e. the name the extension was registered under.
    Returns:
        True if an extension is registered under `extension_type`, False otherwise.
    """
    return extension_type in _extensions