def get_plugin()

in stylegan2_ada_pytorch/torch_utils/custom_ops.py [0:0]


def get_plugin(module_name, sources, **build_kwargs):
    """Build (if needed), load, and return a PyTorch C++/CUDA extension module.

    Results are memoized in the module-level ``_cached_plugins`` dict, so a given
    ``module_name`` is set up at most once per process. Progress output is
    controlled by the module-level ``verbosity`` setting ("none", "brief" or
    "full").

    Args:
        module_name: Name of the extension; used as the cache key, passed as
            ``name`` to ``torch.utils.cpp_extension.load`` and then imported
            with ``importlib.import_module``.
        sources: Paths of the source files to compile.
        **build_kwargs: Extra keyword arguments forwarded verbatim to
            ``torch.utils.cpp_extension.load``.

    Returns:
        The imported extension module.

    Raises:
        RuntimeError: On Windows, when ``cl.exe`` is not on ``PATH`` and
            ``_find_compiler_bindir()`` returns ``None``.
        Any exception raised while building or importing the extension is
        re-raised unchanged (after printing "Failed!" in "brief" mode).
    """
    import tempfile  # Only needed for the incremental-build staging below.

    assert verbosity in ["none", "brief", "full"]

    # Already cached?
    if module_name in _cached_plugins:
        return _cached_plugins[module_name]

    # Print status.
    if verbosity == "full":
        print(f'Setting up PyTorch plugin "{module_name}"...')
    elif verbosity == "brief":
        print(f'Setting up PyTorch plugin "{module_name}"... ', end="", flush=True)

    try:  # pylint: disable=too-many-nested-blocks
        # On Windows, make sure the compiler binaries can be found: if cl.exe is
        # not already on PATH, ask _find_compiler_bindir() and append its result.
        if os.name == "nt" and os.system("where cl.exe >nul 2>nul") != 0:
            compiler_bindir = _find_compiler_bindir()
            if compiler_bindir is None:
                raise RuntimeError(
                    f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".'
                )
            os.environ["PATH"] += os.pathsep + compiler_bindir

        # Compile and load.
        verbose_build = verbosity == "full"

        # Incremental build md5sum trickery. Copies all the input source files
        # into a cached build directory named after a combined md5 digest of the
        # input source files. Copying is done only if the combined digest has
        # changed. This keeps input file timestamps and filenames the same as in
        # previous extension builds, allowing for fast incremental rebuilds.
        #
        # This optimization is done only when all the source files reside in a
        # single directory (just for simplicity) and the TORCH_EXTENSIONS_DIR
        # environment variable is set (taken as a signal that the user actually
        # cares about this).
        source_dirs_set = {os.path.dirname(source) for source in sources}
        if len(source_dirs_set) == 1 and "TORCH_EXTENSIONS_DIR" in os.environ:
            (source_dir,) = source_dirs_set
            all_source_files = sorted(
                x for x in Path(source_dir).iterdir() if x.is_file()
            )

            # Combined hash digest over every file in the custom op directory,
            # not just the ones listed in `sources`.
            hash_md5 = hashlib.md5()
            for src in all_source_files:
                with open(src, "rb") as f:
                    hash_md5.update(f.read())
            build_dir = torch.utils.cpp_extension._get_build_directory(  # pylint: disable=protected-access
                module_name, verbose=verbose_build
            )
            digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest())

            if not os.path.isdir(digest_build_dir):
                # Populate a private staging directory first and publish it with
                # an atomic rename. This way digest_build_dir only ever appears
                # fully populated: concurrent processes never see a half-copied
                # directory, and a failed copy never leaves a broken cache behind.
                os.makedirs(build_dir, exist_ok=True)
                staging_dir = tempfile.mkdtemp(prefix="srctmp-", dir=build_dir)
                try:
                    for src in all_source_files:
                        shutil.copyfile(
                            src,
                            os.path.join(staging_dir, os.path.basename(src)),
                        )
                except BaseException:
                    shutil.rmtree(staging_dir, ignore_errors=True)
                    raise
                try:
                    os.replace(staging_dir, digest_build_dir)
                except OSError:
                    # Most likely another process published the same digest
                    # directory first; its contents are identical, so drop ours.
                    shutil.rmtree(staging_dir, ignore_errors=True)
                    if not os.path.isdir(digest_build_dir):
                        raise

            digest_sources = [
                os.path.join(digest_build_dir, os.path.basename(x)) for x in sources
            ]
            torch.utils.cpp_extension.load(
                name=module_name,
                build_directory=build_dir,
                verbose=verbose_build,
                sources=digest_sources,
                **build_kwargs,
            )
        else:
            torch.utils.cpp_extension.load(
                name=module_name, verbose=verbose_build, sources=sources, **build_kwargs
            )
        module = importlib.import_module(module_name)

    except BaseException:
        # Report the failure in "brief" mode (which left the status line open),
        # then propagate the original exception untouched.
        if verbosity == "brief":
            print("Failed!")
        raise

    # Print status and add to cache.
    if verbosity == "full":
        print(f'Done setting up PyTorch plugin "{module_name}".')
    elif verbosity == "brief":
        print("Done.")
    _cached_plugins[module_name] = module
    return module