def compute_capability_args()

in modules/SwissArmyTransformer/sat/ops/ops_builder/builder.py [0:0]


    def compute_capability_args(self, cross_compile_archs=None):
        """
        Returns nvcc compute capability compile flags.

        1. `TORCH_CUDA_ARCH_LIST` takes priority over `cross_compile_archs`.
        2. If neither is set, default compute capabilities will be used.
        3. Under `jit_mode` compute capabilities of all visible cards will be used,
           plus PTX for the numerically highest one.

        Format:

        - `TORCH_CUDA_ARCH_LIST` may use ; or whitespace separators. Examples:

        TORCH_CUDA_ARCH_LIST="6.1;7.5;8.6" pip install ...
        TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6+PTX" pip install ...

        - `cross_compile_archs` uses ; separator.

        Empty entries (repeated, leading or trailing separators) are ignored.

        Args:
            cross_compile_archs: Optional string of compute capabilities such as
                ``"6.1;7.5+PTX"``. Not consulted under `jit_mode`, and overridden
                by the `TORCH_CUDA_ARCH_LIST` environment variable when that is set.

        Returns:
            A list of ``-gencode=arch=compute_XY,code=sm_XY`` strings, one per
            retained compute capability, followed by an extra
            ``-gencode=arch=compute_XY,code=compute_XY`` entry for every
            capability that ends with ``+PTX``.

        Raises:
            RuntimeError: If no compute capabilities remain after `self.filter_ccs`.
        """
        if self.jit_mode:
            # Compile for underlying architectures since we know those at runtime.
            # Collect (major, minor) tuples so the sort is numeric: a plain string
            # sort would place "10.0" before "8.0" and attach PTX to the wrong arch.
            capabilities = sorted({torch.cuda.get_device_capability(i) for i in range(torch.cuda.device_count())})
            ccs = [f"{major}.{minor}" for major, minor in capabilities]
            # With no visible device the list is empty; skip the PTX suffix and let
            # the emptiness check below report the problem instead of an IndexError.
            if ccs:
                ccs[-1] += '+PTX'
        else:
            # Cross-compile mode, compile for various architectures
            # env override takes priority
            cross_compile_archs_env = os.environ.get('TORCH_CUDA_ARCH_LIST', None)
            if cross_compile_archs_env is not None:
                if cross_compile_archs is not None:
                    print(
                        f"{WARNING} env var `TORCH_CUDA_ARCH_LIST={cross_compile_archs_env}` overrides `cross_compile_archs={cross_compile_archs}`"
                    )
                cross_compile_archs = cross_compile_archs_env
            elif cross_compile_archs is None:
                cross_compile_archs = get_default_compute_capabilities()
            # Treat ';' and any whitespace run as separators; str.split() without
            # arguments also drops the empty tokens produced by doubled or trailing
            # separators, which would otherwise crash the flag formatting below.
            ccs = cross_compile_archs.replace(';', ' ').split()

        ccs = self.filter_ccs(ccs)
        if not ccs:
            raise RuntimeError(
                f"Unable to load {self.name} op due to no compute capabilities remaining after filtering")

        args = []
        for cc in ccs:
            # NOTE(review): assumes filter_ccs returns "MAJOR.MINOR[+PTX]" strings,
            # as produced above - confirm for subclasses that override it.
            version = cc.partition('+')[0]
            major, _, minor = version.partition('.')
            # Join the full major and minor parts rather than indexing single
            # characters, so multi-digit majors work (e.g. "10.0" -> "100").
            num = major + minor
            args.append(f'-gencode=arch=compute_{num},code=sm_{num}')
            if cc.endswith('+PTX'):
                args.append(f'-gencode=arch=compute_{num},code=compute_{num}')

        return args