def generate()

in deep_gemm/jit/template.py [0:0]


def generate(includes: Iterable[str], arg_defs: Iterable[Tuple], body: str) -> str:
    """Assemble the CUDA source text of an ``extern "C" void launch(...)`` wrapper.

    The generated text is, in order: a header comment, the system
    ``#include`` lines, the package ``#include`` lines, the ``launch``
    signature, optional ``reinterpret_cast`` lines, the indented ``body`` and
    the closing brace.

    Args:
        includes: Extra include targets, written exactly as they should follow
            ``#include``. Entries starting with ``<`` join the system group and
            entries starting with ``"`` join the package group; any other entry
            is ignored. Each group is merged with its preloaded defaults,
            de-duplicated and sorted. Any iterable is accepted except a bare
            ``str`` (which would be iterated character by character).
        arg_defs: ``(name, type_key)`` pairs, one per ``launch`` parameter.
            ``type_key`` is looked up in ``genc_map`` (defined outside this
            function): element ``[0]`` of the entry is the type used in the
            signature. When it differs from element ``[1]``, the parameter is
            declared as ``__raw_<name>`` and a local ``<name>`` is created by
            ``reinterpret_cast`` to element ``[1]``.
        body: Source text of the function body. Every non-empty line is
            indented by four spaces; empty lines are left empty.

    Returns:
        The complete source text. The signature always ends with an
        ``int& __return_code`` parameter.

    Raises:
        TypeError: If ``includes`` is a single string instead of an iterable
            of strings.
    """
    if isinstance(includes, str):
        raise TypeError('`includes` must be an iterable of include strings, not a single string')

    # Both inputs are traversed twice below, so snapshot them first; a one-shot
    # iterator would otherwise come up empty on the second pass.
    includes = list(includes)
    arg_defs = list(arg_defs)

    # Common prefix
    code = '// DeepGEMM auto-generated JIT CUDA source file\n\n'

    # Includes
    preload_sys_includes = ['<cuda.h>', '<cuda_fp8.h>', '<cuda_runtime.h>', '<iostream>']
    preload_package_includes = ['"cutlass/cutlass.h"']
    sys_includes = sorted({*preload_sys_includes, *(inc for inc in includes if inc.startswith('<'))})
    package_includes = sorted({*preload_package_includes, *(inc for inc in includes if inc.startswith('"'))})
    code += '\n'.join(f'#include {inc}' for inc in sys_includes) + '\n\n'
    code += '\n'.join(f'#include {inc}' for inc in package_includes) + '\n\n'

    # Function signature
    raw = '__raw_'

    def needs_cast(arg_type) -> bool:
        # A cast is needed when the signature type differs from the type used in the body
        return genc_map[arg_type][0] != genc_map[arg_type][1]

    def declare(arg_name, arg_type) -> str:
        # Parameters that need a cast get the raw prefix so the cast result can take the plain name
        prefix = raw if needs_cast(arg_type) else ''
        return f'{genc_map[arg_type][0]} {prefix}{arg_name}'

    params = [declare(arg_name, arg_type) for arg_name, arg_type in arg_defs]
    params.append('int& __return_code')
    code += 'extern "C" void launch('
    code += ', '.join(params)
    code += ') {\n'

    # Cast raw types
    code += '    // Cast raw types (if needed)\n'
    code += ''.join(
        f'    auto {arg_name} = reinterpret_cast<{genc_map[arg_type][1]}>({raw}{arg_name});\n'
        for arg_name, arg_type in arg_defs
        if needs_cast(arg_type)
    )

    # Function body
    code += '\n'.join(('    ' + line) if line else line for line in body.split('\n'))

    # End the function; force a line break first so that a body whose last
    # line is a `//` comment cannot swallow the closing brace
    if not code.endswith('\n'):
        code += '\n'
    code += '}\n\n'

    # Debug print
    if os.getenv('DG_JIT_DEBUG', None):
        print(f'Generated code:\n{code}')

    return code