def get_extensions()

in setup.py [0:0]


def get_extensions():
    """Build the list of C++/CUDA extension modules for ``xformers._C``.

    Sources are collected from ``xformers/components/attention/csrc`` next to
    this file: the top-level ``*.cpp`` files plus ``cpu/*.cpp`` and
    ``autograd/*.cpp``. When CUDA is usable (``torch.cuda.is_available()`` and
    ``CUDA_HOME`` is not ``None``) or ``FORCE_CUDA=1`` is set in the
    environment, the ``cuda/*.cu`` sources are added, ``third_party/sputnik``
    is put on the include path, and the extension is created with
    ``CUDAExtension`` instead of ``CppExtension``.

    Environment variables:
        FORCE_CUDA: set to ``"1"`` to select the CUDA extension even when no
            GPU is detected.
        NVCC_FLAGS: whitespace-separated extra flags passed to nvcc.

    Returns:
        A single-element list containing the configured extension.
    """
    this_dir = os.path.dirname(os.path.abspath(__file__))
    extensions_dir = os.path.join(
        this_dir, "xformers", "components", "attention", "csrc"
    )
    sputnik_dir = os.path.join(this_dir, "third_party", "sputnik")

    # extensions_dir is absolute, so every globbed path is already absolute
    # and needs no further joining before being handed to the extension.
    sources = (
        glob.glob(os.path.join(extensions_dir, "*.cpp"))
        + glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
        + glob.glob(os.path.join(extensions_dir, "autograd", "*.cpp"))
    )

    extension = CppExtension
    define_macros = []
    include_dirs = [extensions_dir]

    extra_compile_args = {"cxx": ["-O3"]}
    if sys.platform == "win32":
        define_macros.append(("xformers_EXPORTS", None))
        extra_compile_args["cxx"].append("/MP")
    elif "OpenMP not found" not in torch.__config__.parallel_info():
        # Only request OpenMP when the installed torch reports it is available.
        extra_compile_args["cxx"].append("-fopenmp")

    cuda_detected = torch.cuda.is_available() and CUDA_HOME is not None
    if cuda_detected or os.getenv("FORCE_CUDA", "0") == "1":
        extension = CUDAExtension
        sources += glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))
        include_dirs.append(sputnik_dir)
        # split() without an argument collapses runs of whitespace and returns
        # [] for an empty string, so no empty arguments are passed to nvcc.
        extra_compile_args["nvcc"] = os.getenv("NVCC_FLAGS", "").split()

    return [
        extension(
            "xformers._C",
            # Sorted so the source order does not depend on glob ordering.
            sorted(sources),
            include_dirs=include_dirs,
            define_macros=define_macros,
            extra_compile_args=extra_compile_args,
        )
    ]