def map_ctype()

in deep_gemm/jit/template.py [0:0]


def map_ctype(value: Any) -> Any:
    """Convert a Python/PyTorch value into the corresponding ctypes object.

    Conversion rules, checked in this order:

    * Objects exposing ``data_ptr()`` (e.g. torch tensors) become a
      ``ctypes.c_void_p`` holding that address. The element dtype plays no
      role: every such object is passed as an untyped pointer.
    * Objects exposing a ``cuda_stream`` attribute (e.g. ``torch.cuda.Stream``)
      become a ``ctypes.c_void_p`` holding that raw handle.
    * ``bool`` -> ``ctypes.c_bool``, ``int`` -> ``ctypes.c_int``,
      ``float`` -> ``ctypes.c_float``.
    * Anything else is looked up by its exact type in ``ctype_map`` (a
      global defined elsewhere in this module), and the looked-up converter
      is called with ``value``.

    Args:
        value: The argument to convert.

    Returns:
        A ctypes instance, or whatever the ``ctype_map`` converter returns.

    Raises:
        KeyError: presumably, when ``value`` matches none of the rules above
            and its type is not a key of ``ctype_map`` (assumes ``ctype_map``
            is a plain dict -- confirm against its definition).
    """
    if hasattr(value, 'data_ptr'):
        # Every dtype is passed the same way (raw address), so no per-dtype
        # dispatch is needed. Avoiding one also avoids referencing dtypes such
        # as ``torch.float8_e4m3fn`` that do not exist on older torch builds.
        return ctypes.c_void_p(value.data_ptr())

    if hasattr(value, 'cuda_stream'):
        return ctypes.c_void_p(value.cuda_stream)

    # ``bool`` must be tested before ``int``: bool is a subclass of int and
    # would otherwise be converted to c_int.
    if isinstance(value, bool):
        return ctypes.c_bool(value)
    if isinstance(value, int):
        # NOTE: ctypes.c_int performs no overflow checking; values outside
        # the C int range wrap silently.
        return ctypes.c_int(value)
    if isinstance(value, float):
        return ctypes.c_float(value)

    return ctype_map[type(value)](value)