in modules/SwissArmyTransformer/sat/ops/ops_builder/builder.py [0:0]
def compute_capability_args(self, cross_compile_archs=None):
    """
    Returns nvcc compute capability compile flags.

    1. `TORCH_CUDA_ARCH_LIST` takes priority over `cross_compile_archs`.
       An unset or blank `TORCH_CUDA_ARCH_LIST` is ignored.
    2. If neither is set default compute capabilities will be used.
    3. Under `jit_mode` compute capabilities of all visible cards will be
       used, plus PTX for the highest of them.

    Format:

    - `TORCH_CUDA_ARCH_LIST` may use ; or whitespace separators. Examples:

        TORCH_CUDA_ARCH_LIST="6.1;7.5;8.6" pip install ...
        TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6+PTX" pip install ...

    - `cross_compile_archs` uses ; separator (whitespace is also accepted).

    Each entry looks like ``MAJOR.MINOR[suffix][+PTX]`` (e.g. ``8.6``,
    ``9.0a``, ``10.0+PTX``); empty entries are skipped.

    Args:
        cross_compile_archs: optional `;`-separated compute capability list,
            used only when not in `jit_mode` and `TORCH_CUDA_ARCH_LIST` is
            unset or blank.

    Returns:
        List of ``-gencode=...`` nvcc arguments: one ``code=sm_XX`` entry per
        retained capability, plus a ``code=compute_XX`` entry for each entry
        ending in ``+PTX``.

    Raises:
        RuntimeError: if no compute capabilities remain after
            ``self.filter_ccs`` (including jit mode with no visible devices).
    """
    if self.jit_mode:
        # Compile for the architectures of the visible devices, since those
        # are known at runtime. Deduplicate and sort numerically so that e.g.
        # 10.0 ranks above 8.6 (a plain string sort would not).
        device_ccs = set()
        for i in range(torch.cuda.device_count()):
            CC_MAJOR, CC_MINOR = torch.cuda.get_device_capability(i)
            device_ccs.add((CC_MAJOR, CC_MINOR))
        ccs = [f"{major}.{minor}" for major, minor in sorted(device_ccs)]
        # Add PTX for the highest architecture. With no visible devices the
        # list is empty; the post-filter check below then reports it instead
        # of crashing with an IndexError here.
        if ccs:
            ccs[-1] += '+PTX'
    else:
        # Cross-compile mode, compile for various architectures.
        # The env override takes priority; a blank value counts as unset.
        cross_compile_archs_env = os.environ.get('TORCH_CUDA_ARCH_LIST', None)
        if cross_compile_archs_env is not None and cross_compile_archs_env.strip():
            if cross_compile_archs is not None:
                print(
                    f"{WARNING} env var `TORCH_CUDA_ARCH_LIST={cross_compile_archs_env}` overrides `cross_compile_archs={cross_compile_archs}`"
                )
            cross_compile_archs = cross_compile_archs_env
        elif cross_compile_archs is None:
            cross_compile_archs = get_default_compute_capabilities()
        # Accept ';' and any whitespace as separators and drop empty entries
        # (e.g. from a trailing ';' or doubled spaces).
        ccs = cross_compile_archs.replace(';', ' ').split()

    ccs = self.filter_ccs(ccs)
    if not ccs:
        raise RuntimeError(
            f"Unable to load {self.name} op due to no compute capabilities remaining after filtering")

    args = []
    for cc in ccs:
        # "8.6+PTX" -> "86", "10.0" -> "100", "9.0a" -> "90a". Parsing around
        # the '.' (rather than indexing cc[0] + cc[2]) keeps multi-digit majors
        # and arch suffixes intact.
        major, _, minor = cc.split('+', 1)[0].strip().partition('.')
        num = major + minor
        args.append(f'-gencode=arch=compute_{num},code=sm_{num}')
        if cc.endswith('+PTX'):
            args.append(f'-gencode=arch=compute_{num},code=compute_{num}')
    return args