in src/nv-wavenet/setup.py [0:0]
def build_extensions(self):
    """Build all extensions, routing CUDA sources through ``nvcc``.

    Steps, in order:

    1. Run the ABI check and, for every extension in ``self.extensions``,
       add the ``-DTORCH_API_INCLUDE_EXTENSION_H`` compile flag, define the
       torch extension name and add the GNU C++ ABI flag.
    2. Register ``.cu`` and ``.cuh`` as source extensions on
       ``self.compiler``.
    3. Replace the compiler's compile hook (``compile`` for ``msvc``,
       ``_compile`` otherwise) with a wrapper that invokes ``nvcc`` for
       CUDA files.
    4. Delegate the actual build to ``build_ext.build_extensions``.

    The ``extra_postargs`` received by the wrappers may be either a flat
    list of flags or a dict whose ``'cxx'`` and ``'nvcc'`` entries hold the
    flags for the host compiler and for ``nvcc`` respectively.

    The replaced compile hook is left installed on ``self.compiler`` after
    this method returns.
    """
    self._check_abi()
    for extension in self.extensions:
        self._add_compile_flag(extension, '-DTORCH_API_INCLUDE_EXTENSION_H')
        self._define_torch_extension_name(extension)
        self._add_gnu_cpp_abi_flag(extension)

    # Register .cu and .cuh as valid source extensions.
    self.compiler.src_extensions += ['.cu', '.cuh']
    # Save the original _compile method for later.
    if self.compiler.compiler_type == 'msvc':
        self.compiler._cpp_extensions += ['.cu', '.cuh']
        original_compile = self.compiler.compile
        original_spawn = self.compiler.spawn
    else:
        original_compile = self.compiler._compile

    def unix_wrap_compile(obj, src, ext, cc_args, extra_postargs, pp_opts):
        """Compile a single source file, using ``nvcc`` for CUDA files.

        The ``compiler_so`` executable is swapped to ``nvcc`` only for the
        duration of this call and is always restored afterwards.
        """
        # Copy before we make any modifications.
        cflags = copy.deepcopy(extra_postargs)
        # Capture the current compiler *before* entering the ``try`` so the
        # ``finally`` clause always has a bound value to restore; otherwise
        # a failure here would surface as an UnboundLocalError instead.
        original_compiler = self.compiler.compiler_so
        try:
            if _is_cuda_file(src):
                nvcc = _join_cuda_home('bin', 'nvcc')
                if not isinstance(nvcc, list):
                    nvcc = [nvcc]
                self.compiler.set_executable('compiler_so', nvcc)
                if isinstance(cflags, dict):
                    cflags = cflags['nvcc']
                # Forward -fPIC to the host compiler that nvcc drives.
                cflags = ['--compiler-options', "'-fPIC'"] + cflags
            elif isinstance(cflags, dict):
                cflags = cflags['cxx']
            # NVCC does not allow multiple -std to be passed, so we avoid
            # overriding the option if the user explicitly passed it.
            if not any(flag.startswith('-std=') for flag in cflags):
                cflags.append('-std=c++11')

            original_compile(obj, src, ext, cc_args, cflags, pp_opts)
        finally:
            # Put the original compiler back in place.
            self.compiler.set_executable('compiler_so', original_compiler)

    def win_wrap_compile(sources,
                         output_dir=None,
                         macros=None,
                         include_dirs=None,
                         debug=0,
                         extra_preargs=None,
                         extra_postargs=None,
                         depends=None):
        """Run the original ``compile`` with ``spawn`` temporarily replaced.

        The extra flags are withheld from the original ``compile`` call
        (``extra_postargs`` is passed on as ``None``) and are instead
        appended per command inside ``spawn``, where the source file of
        each command is known.
        """
        self.cflags = copy.deepcopy(extra_postargs)
        extra_postargs = None

        # Compile the patterns once per compile() call rather than once per
        # spawned command.
        src_regex = re.compile('/T(p|c)(.*)')
        obj_regex = re.compile('/Fo(.*)')
        include_regex = re.compile(r'((\-|\/)I.*)')

        def spawn(cmd):
            """Rewrite a compiler command line, then run it.

            CUDA sources get a freshly built ``nvcc`` command; other
            sources keep the original command with the extra flags and
            ``/MD`` appended. Commands in which no source (``/Tp``/``/Tc``)
            or no object (``/Fo``) argument is found are run unchanged.
            """
            # Using regex to match src, obj and include files
            src_list = [
                m.group(2) for m in (src_regex.match(elem) for elem in cmd)
                if m
            ]
            obj_list = [
                m.group(1) for m in (obj_regex.match(elem) for elem in cmd)
                if m
            ]
            include_list = [
                m.group(1)
                for m in (include_regex.match(elem) for elem in cmd) if m
            ]

            if len(src_list) >= 1 and len(obj_list) >= 1:
                src = src_list[0]
                obj = obj_list[0]
                if _is_cuda_file(src):
                    nvcc = _join_cuda_home('bin', 'nvcc')
                    if isinstance(self.cflags, dict):
                        cflags = self.cflags['nvcc']
                    elif isinstance(self.cflags, list):
                        cflags = self.cflags
                    else:
                        cflags = []
                    # Replace the whole command: only the include arguments
                    # of the original command line are carried over.
                    cmd = [
                        nvcc, '-c', src, '-o', obj, '-Xcompiler',
                        '/wd4819', '-Xcompiler', '/MD'
                    ] + include_list + cflags
                elif isinstance(self.cflags, dict):
                    cflags = self.cflags['cxx'] + ['/MD']
                    cmd += cflags
                elif isinstance(self.cflags, list):
                    cflags = self.cflags + ['/MD']
                    cmd += cflags

            return original_spawn(cmd)

        try:
            self.compiler.spawn = spawn
            return original_compile(sources, output_dir, macros,
                                    include_dirs, debug, extra_preargs,
                                    extra_postargs, depends)
        finally:
            # Always restore the real spawn, even if compilation failed.
            self.compiler.spawn = original_spawn

    # Monkey-patch the _compile method.
    if self.compiler.compiler_type == 'msvc':
        self.compiler.compile = win_wrap_compile
    else:
        self.compiler._compile = unix_wrap_compile

    build_ext.build_extensions(self)