def prepare_extension_options()

in setup.py [0:0]


def prepare_extension_options():
    """Assemble the keyword options used to build the native extension.

    Reads the compile switches from ``parse_compile_options()``, regenerates
    the SINGA config and protobuf files, and then collects the source files,
    include/library directories, link libraries and per-compiler flags.

    Returns:
        dict: with the keys ``'sources'``, ``'include_dirs'``,
        ``'library_dirs'``, ``'libraries'``, ``'runtime_library_dirs'`` and
        ``'extra_compile_args'``. ``'extra_compile_args'`` is itself a dict
        keyed by compiler name: ``'gcc'`` always, plus ``'nvcc'`` when CUDA
        is enabled.

    Raises:
        CompileError: if CUDA is enabled and ``cuda_major`` is below 9.
    """
    # Only the CUDA/NCCL switches are consumed here; the test/debug flags are
    # unpacked solely to match the shape of the returned tuple.
    with_cuda, with_nccl, _with_test, _with_debug = parse_compile_options()

    generate_singa_config(with_cuda, with_nccl)
    generate_proto_files()

    # Resolve the NumPy header directory once, *before* it is used, so that
    # the fallback for NumPy builds lacking ``get_include`` is actually
    # reachable and the directory is listed a single time.
    try:
        np_include = np.get_include()
    except AttributeError:
        np_include = np.get_numpy_include()

    link_libs = ['glog', 'protobuf', 'openblas', 'dnnl']

    sources = path_to_str([
        # 'core' is searched recursively; the other directories are not.
        *(SINGA_SRC / 'core').rglob('*.cc'),
        *(SINGA_SRC / 'model/operation').glob('*.cc'),
        *(SINGA_SRC / 'utils').glob('*.cc'),
        SINGA_SRC / 'proto/core.pb.cc',
        SINGA_SRC / 'api/singa.i',
    ])
    include_dirs = path_to_str([
        SINGA_HDR,
        SINGA_HDR / 'singa/proto',
        np_include,
        '/usr/include',
        '/usr/include/openblas',
        '/usr/local/include',
    ])

    # No library directory is added by default; the CUDA branch adds one.
    library_dirs = []

    if with_cuda:
        link_libs.extend(['cudart', 'cudnn', 'curand', 'cublas', 'cnmem'])
        include_dirs.append('/usr/local/cuda/include')
        library_dirs.append('/usr/local/cuda/lib64')
        sources.append(str(SINGA_SRC / 'core/tensor/math_kernel.cu'))
        # NCCL is only considered when CUDA is enabled.
        if with_nccl:
            link_libs.extend(['nccl', 'cusparse', 'mpicxx', 'mpi'])
            sources.append(str(SINGA_SRC / 'io/communicator.cc'))

    runtime_library_dirs = ['.'] + library_dirs
    extra_compile_args = {'gcc': get_cpp_flags()}

    if with_cuda:
        # compute_35 and compute_50 are removed because 1. they do not support
        # half float; 2. google colab's GPU has been updated from K80
        # (compute_35) to T4 (compute_75).
        cuda9_gencode = (' -gencode arch=compute_60,code=sm_60'
                         ' -gencode arch=compute_70,code=sm_70')
        cuda10_gencode = ' -gencode arch=compute_75,code=sm_75'
        cuda11_gencode = ' -gencode arch=compute_80,code=sm_80'
        cuda9_ptx = ' -gencode arch=compute_70,code=compute_70'
        cuda10_ptx = ' -gencode arch=compute_75,code=compute_75'
        cuda11_ptx = ' -gencode arch=compute_80,code=compute_80'

        # NOTE(review): ``cuda_major`` is not assigned in this function; it is
        # presumably a module-level value set elsewhere in setup.py -- confirm.
        # Each branch emits the SASS targets available up to that CUDA major
        # version plus PTX for the newest one.
        if cuda_major >= 11:
            gencode = (cuda9_gencode + cuda10_gencode + cuda11_gencode +
                       cuda11_ptx)
        elif cuda_major >= 10:
            gencode = cuda9_gencode + cuda10_gencode + cuda10_ptx
        elif cuda_major >= 9:
            gencode = cuda9_gencode + cuda9_ptx
        else:
            raise CompileError(
                'CUDA version must be >=9.0, the current version is {}'.format(
                    cuda_major))

        extra_compile_args['nvcc'] = shlex.split(gencode) + [
            '-Xcompiler', '-fPIC'
        ]

    return {
        'sources': sources,
        'include_dirs': include_dirs,
        'library_dirs': library_dirs,
        'libraries': link_libs,
        'runtime_library_dirs': runtime_library_dirs,
        'extra_compile_args': extra_compile_args,
    }