in stylegan2_ada_pytorch/torch_utils/custom_ops.py [0:0]
def get_plugin(module_name, sources, **build_kwargs):
    """Build (if necessary), import and return the PyTorch extension `module_name`.

    Successfully loaded modules are memoized in the module-level
    `_cached_plugins` dict, so repeated calls with the same `module_name`
    return the cached module without rebuilding. Status messages are printed
    according to the module-level `verbosity` setting, which must be one of
    "none", "brief" or "full".

    Args:
        module_name: Name given to `torch.utils.cpp_extension.load()` and then
            imported with `importlib.import_module()`.
        sources: Paths of the source files to compile.
        **build_kwargs: Forwarded unchanged to
            `torch.utils.cpp_extension.load()`.

    Returns:
        The imported extension module.

    Raises:
        RuntimeError: On Windows, when `cl.exe` is not on PATH and
            `_find_compiler_bindir()` returns None.
        Exception: Any error raised while staging, building or importing the
            extension is re-raised (after printing "Failed!" in "brief" mode).
    """
    # Local import: only needed for naming the staging directory below.
    import uuid

    assert verbosity in ["none", "brief", "full"]

    # Already cached?
    if module_name in _cached_plugins:
        return _cached_plugins[module_name]

    # Print status.
    if verbosity == "full":
        print(f'Setting up PyTorch plugin "{module_name}"...')
    elif verbosity == "brief":
        print(f'Setting up PyTorch plugin "{module_name}"... ', end="", flush=True)

    try:  # pylint: disable=too-many-nested-blocks
        # Make sure we can find the necessary compiler binaries.
        if os.name == "nt" and os.system("where cl.exe >nul 2>nul") != 0:
            compiler_bindir = _find_compiler_bindir()
            if compiler_bindir is None:
                raise RuntimeError(
                    f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".'
                )
            os.environ["PATH"] += ";" + compiler_bindir

        # Compile and load.
        verbose_build = verbosity == "full"

        # Incremental build md5sum trickery. Copies all the input source files
        # into a cached build directory under a combined md5 digest of the input
        # source files. Copying is done only if the combined digest has changed.
        # This keeps input file timestamps and filenames the same as in previous
        # extension builds, allowing for fast incremental rebuilds.
        #
        # This optimization is done only in case all the source files reside in
        # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
        # environment variable is set (we take this as a signal that the user
        # actually cares about this.)
        source_dirs_set = {os.path.dirname(source) for source in sources}
        if len(source_dirs_set) == 1 and ("TORCH_EXTENSIONS_DIR" in os.environ):
            source_dir = Path(next(iter(source_dirs_set)))
            all_source_files = sorted(x for x in source_dir.iterdir() if x.is_file())

            # Compute a combined hash digest for all source files in the same
            # custom op directory (usually .cu, .cpp, .py and .h files).
            hash_md5 = hashlib.md5()
            for src in all_source_files:
                with open(src, "rb") as f:
                    hash_md5.update(f.read())
            build_dir = torch.utils.cpp_extension._get_build_directory(
                module_name, verbose=verbose_build
            )  # pylint: disable=protected-access
            digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest())

            if not os.path.isdir(digest_build_dir):
                # Stage the copies in a uniquely named sibling directory and
                # publish it with a single rename. Creating the digest directory
                # first and filling it afterwards would let a concurrent process
                # see the directory, skip the copy, and start building from
                # missing or partially written source files.
                staging_dir = os.path.join(build_dir, f"srctmp-{uuid.uuid4().hex}")
                os.makedirs(staging_dir)
                try:
                    for src in all_source_files:
                        shutil.copyfile(
                            src,
                            os.path.join(staging_dir, os.path.basename(src)),
                        )
                    try:
                        os.replace(staging_dir, digest_build_dir)
                    except OSError:
                        # Another process published the same digest directory
                        # first; its contents are equivalent, so carry on.
                        # Anything else is a genuine failure.
                        if not os.path.isdir(digest_build_dir):
                            raise
                finally:
                    # No-op after a successful rename; otherwise removes the
                    # leftover staging directory.
                    shutil.rmtree(staging_dir, ignore_errors=True)

            # Build from the staged copies so that file paths and timestamps
            # stay stable across runs.
            digest_sources = [
                os.path.join(digest_build_dir, os.path.basename(x)) for x in sources
            ]
            torch.utils.cpp_extension.load(
                name=module_name,
                build_directory=build_dir,
                verbose=verbose_build,
                sources=digest_sources,
                **build_kwargs,
            )
        else:
            torch.utils.cpp_extension.load(
                name=module_name, verbose=verbose_build, sources=sources, **build_kwargs
            )
        module = importlib.import_module(module_name)

    except BaseException:
        # Deliberately broad: only terminates the "brief" status line, then
        # always re-raises the original exception.
        if verbosity == "brief":
            print("Failed!")
        raise

    # Print status and add to cache.
    if verbosity == "full":
        print(f'Done setting up PyTorch plugin "{module_name}".')
    elif verbosity == "brief":
        print("Done.")
    _cached_plugins[module_name] = module
    return module