# optimum/quanto/library/extensions/cuda/__init__.py
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import torch

from ..extension import Extension, register_extension


__all__ = []


def get_max_cuda_arch():
    """Select the maximum CUDA arch supported.

    This is a combination of the CUDA/PyTorch build (the architectures torch was compiled for)
    and the compute capabilities of all detected devices.
    """
    capability_list = []
    supported_sm = [int(arch.split("_")[1]) for arch in torch.cuda.get_arch_list() if "sm_" in arch]
    if supported_sm:
        max_supported_sm = max((sm // 10, sm % 10) for sm in supported_sm)
        for i in range(torch.cuda.device_count()):
            capability = torch.cuda.get_device_capability(i)
            # The capability of the device may be higher than what's supported by the user's
            # NVCC, causing compilation errors. The user's NVCC is expected to match the one
            # used to build pytorch, so we use the maximum capability supported by pytorch
            # to clamp the device capability.
            capability = min(max_supported_sm, capability)
            if capability not in capability_list:
                capability_list.append(capability)
    max_capability = max(sorted(capability_list)) if len(capability_list) > 0 else (0, 0)
    return f"{max_capability[0]}{max_capability[1]}0"


extra_cflags = ["-g", "-O3"]
extra_cuda_cflags = [
    "--expt-extended-lambda",
    "--use_fast_math",
]
# We need to know the target CUDA arch at compile time to select only the relevant kernels,
# but we cannot rely on __CUDA_ARCH__ as it is only set in device code (not in host code).
quanto_cuda_arch = get_max_cuda_arch()
extra_cuda_cflags += [f"-DQUANTO_CUDA_ARCH={quanto_cuda_arch}"]

module_path = os.path.dirname(__file__)
sources = [
    "unpack.cu",
    "awq/v2/gemm_cuda.cu",
    "awq/v2/gemv_cuda.cu",
    "marlin/fp8_marlin.cu",
    "marlin/gptq_marlin_repack.cu",
    "marlin/marlin_cuda.cpp",
    "marlin/marlin_cuda_kernel.cu",
    "pybind_module.cpp",
]
ext = Extension(
    "quanto_cuda",
    root_dir=module_path,
    sources=sources,
    extra_cflags=extra_cflags,
    extra_cuda_cflags=extra_cuda_cflags,
)
register_extension(ext)
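
# Note: the kernels are not compiled at import time. Assuming the shared Extension helper
# wraps torch.utils.cpp_extension's lazy build (as for the other quanto backends), the first
# access to `ext.lib` below compiles the sources listed above with the flags computed for this
# machine and caches the resulting module for subsequent calls.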


@torch.library.impl("quanto::unpack", ["CUDA"])
def unpack_cuda(t: torch.Tensor, bits: int):
    return ext.lib.unpack(t, bits)
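
# Illustrative usage (assumes the quanto::unpack schema defined elsewhere in the library):
#   unpacked = torch.ops.quanto.unpack(packed, 4)
# where `packed` is a packed uint8 tensor on a CUDA device; the dispatcher routes the call
# to the CUDA implementation registered above.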


torch.library.define(
    "quanto::gemm_f16i4_awq",
    "(Tensor input,"
    " Tensor other,"
    " Tensor other_scale,"
    " Tensor other_shift,"
    " int rows,"
    " int out_cols,"
    " int in_cols,"
    " int bits,"
    " int group_size)"
    " -> Tensor",
)


@torch.library.impl("quanto::gemm_f16i4_awq", ["CUDA"])
def gemm_f16i4_awq(
    input: torch.Tensor,
    other: torch.Tensor,
    scales: torch.Tensor,
    shift: torch.Tensor,
    rows: int,
    out_cols: int,
    in_cols: int,
    bits: int,
    group_size: int,
):
    assert out_cols >= 128
    assert input.dtype == torch.float16
    assert input.numel() == rows * in_cols
    assert other.dtype == torch.int16
    assert scales.dtype == torch.float16
    assert scales.shape[-1] == out_cols
    assert shift.dtype == torch.float16
    assert shift.shape[-1] == out_cols
    assert bits == 4
    assert group_size == 128
    if rows < 8:
        # For a small number of rows, the dedicated GEMV kernel is used instead of the GEMM kernel.
        return ext.lib.awq_v2_gemv_f16i4(input, other, scales, shift, rows, out_cols, in_cols, group_size)
    return ext.lib.awq_v2_gemm_f16i4(input, other, scales, shift)
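
# Illustrative call (names and shapes are assumptions, not executed here): for a fp16
# activation `x` of shape (rows, in_cols) and AWQ v2 packed int16 weights `w_packed` with
# fp16 `scales` and `shift` whose last dimension is out_cols,
#   y = torch.ops.quanto.gemm_f16i4_awq(x, w_packed, scales, shift, rows, out_cols, in_cols, 4, 128)
# returns the fp16 product; small batches (rows < 8) take the GEMV path above.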


torch.library.define(
    "quanto::gemm_f16f8_marlin",
    "(Tensor a,"
    " Tensor b_q_weight,"
    " Tensor b_scales,"
    " Tensor workspace,"
    " int num_bits,"
    " int size_m,"
    " int size_n,"
    " int size_k)"
    " -> Tensor",
)


@torch.library.impl("quanto::gemm_f16f8_marlin", ["CUDA"])
def fp8_marlin_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_scales: torch.Tensor,
    workspace: torch.Tensor,
    num_bits: int,
    size_m: int,
    size_n: int,
    size_k: int,
) -> torch.Tensor:
    assert b_scales.dtype in (torch.float16, torch.bfloat16)
    assert b_q_weight.dim() == 2
    assert b_q_weight.dtype == torch.int32
    return ext.lib.fp8_marlin_gemm(a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k)
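
# Illustrative call (names are assumptions): with fp16 or bf16 activations `a` of shape
# (size_m, size_k), fp8 weights packed into an int32 tensor `b_q_weight`, per-channel
# `b_scales` and a Marlin `workspace` buffer,
#   y = torch.ops.quanto.gemm_f16f8_marlin(a, b_q_weight, b_scales, workspace, 8, size_m, size_n, size_k)
# computes a (size_m, size_n) matmul with the fp8 Marlin kernel (num_bits is 8 for fp8 weights).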


torch.library.define(
    "quanto::pack_fp8_marlin",
    "(Tensor b_q_weight, Tensor perm, int size_k, int size_n, int num_bits) -> Tensor",
)


@torch.library.impl("quanto::pack_fp8_marlin", ["CUDA"])
def gptq_marlin_repack(
    b_q_weight: torch.Tensor, perm: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
    assert b_q_weight.dim() == 2
    assert b_q_weight.dtype == torch.int32
    return ext.lib.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits)
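
# Note (assumption based on the op and kernel names): this repacks int32-packed quantized
# weights of logical shape (size_k, size_n) into the tile layout expected by the Marlin
# kernels, applying the row permutation `perm`. It is typically called once when a module's
# weights are prepared for the quanto::gemm_f16f8_marlin op above.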


torch.library.define(
    "quanto::gemm_f16i4_marlin",
    "(Tensor input, Tensor other, Tensor other_scale, Tensor other_shift, Tensor workspace) -> Tensor",
)


@torch.library.impl("quanto::gemm_f16i4_marlin", ["CUDA"])
def gemm_f16i4_marlin(
    input: torch.Tensor, other: torch.Tensor, scales: torch.Tensor, zeropoint: torch.Tensor, workspace: torch.Tensor
) -> torch.Tensor:
    assert input.dtype == torch.float16
    assert other.dtype == torch.int32
    assert scales.dtype == torch.float16
    assert zeropoint.dtype == torch.float16
    assert workspace.dtype == torch.int32
    output = torch.empty(
        input.shape[:-1] + (scales.shape[1],),
        dtype=input.dtype,
        device=input.device,
    )
    ext.lib.marlin_gemm_f16i4(
        input.reshape((-1, input.shape[-1])),
        other,
        output.reshape((-1, output.shape[-1])),
        scales,
        zeropoint,
        workspace,
        -1,
        -1,
        -1,
        16,
    )
    return output
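
# Note: the output feature dimension is taken from scales.shape[1], and input/output are
# flattened to 2D before the kernel call so that any leading batch dimensions are preserved.
# The trailing (-1, -1, -1, 16) arguments are assumed to be Marlin tuning parameters
# (thread tile sizes and SM count left to auto-selection, plus a parallelization limit);
# their exact meaning is defined by marlin_gemm_f16i4 in marlin/marlin_cuda.cpp.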