maga_transformer/cpp/kernels/rocm/quantization_rocm.cu:

#include "maga_transformer/cpp/utils/AssertUtils.h" #include "maga_transformer/cpp/kernels/rocm/quantization_rocm.h" #include "maga_transformer/cpp/cuda/reduce_kernel_utils.cuh" #include "maga_transformer/cpp/rocm/hip_utils.h" namespace rtp_llm { using namespace rocm; ///////////////////////////////////////////////////////////////////////////////////////////////// // int4 col quant /////////////////////////////////////////////////////////////////////////////// template<typename T> __global__ void perColQuantization(const T* src, const int64_t numRows, const int64_t numCols, const int64_t groupSize, uint8_t* weightPtr, half* scalePtr, half* zerosPtr, float* dbgfp = nullptr, int* dbgint = nullptr) { uint32_t colPckIdx = blockIdx.y; uint32_t rowGrpIdx = blockIdx.x; float vall = cuda_cast<float>(src[(rowGrpIdx * groupSize + threadIdx.x) * numCols + colPckIdx * 2 + 0]); float valh = cuda_cast<float>(src[(rowGrpIdx * groupSize + threadIdx.x) * numCols + colPckIdx * 2 + 1]); const float groupMaxl = blockAllReduceMax(vall); const float groupMaxh = blockAllReduceMax(valh); if (threadIdx.x == 0) { scalePtr[rowGrpIdx * numCols + colPckIdx * 2 + 0] = groupMaxl / 7.0f; scalePtr[rowGrpIdx * numCols + colPckIdx * 2 + 1] = groupMaxh / 7.0f; zerosPtr[rowGrpIdx * numCols + colPckIdx * 2 + 0] = 0; zerosPtr[rowGrpIdx * numCols + colPckIdx * 2 + 1] = 0; } const float scaleOrigQuantl = 7.f / groupMaxl; const float scaleOrigQuanth = 7.f / groupMaxh; int8_t tmpi8l = cuda_cast<int8_t>(cuda_cast<float>(vall) * scaleOrigQuantl); int8_t tmpi8h = cuda_cast<int8_t>(cuda_cast<float>(valh) * scaleOrigQuanth); uint8_t tmpu4l = tmpi8l & 0x0F; uint8_t tmpu4h = tmpi8h & 0x0F; uint8_t tmpu8 = tmpu4h; tmpu8 = tmpu8 << 4; tmpu8 = tmpu8 | tmpu4l; weightPtr[(rowGrpIdx * groupSize + threadIdx.x) * numCols / 2 + colPckIdx] = tmpu8; } template<typename T> void invokePerColQuantizationInt4x2(const T* src, const int64_t numRows, const int64_t numCols, const int64_t groupSize, uint8_t* weightPtr, half* scalePtr, half* zerosPtr, cudaStream_t stream) { assert(numRows % groupSize == 0); const dim3 block(groupSize); const dim3 grid(numRows / groupSize, numCols / 2, 1); perColQuantization<T><<<grid, block, 0, stream>>>(src, numRows, numCols, groupSize, weightPtr, scalePtr, zerosPtr); } #define INSTANTIATE_INVOKE_PER_COL_QUANTIZATION_INT4X2(T) \ template void invokePerColQuantizationInt4x2(const T* src, \ const int64_t numRows, \ const int64_t numCols, \ const int64_t groupSize, \ uint8_t* weightPtr, \ half* scalePtr, \ half* zerosPtr, \ cudaStream_t stream) INSTANTIATE_INVOKE_PER_COL_QUANTIZATION_INT4X2(float); INSTANTIATE_INVOKE_PER_COL_QUANTIZATION_INT4X2(half); #ifdef ENABLE_BF16 INSTANTIATE_INVOKE_PER_COL_QUANTIZATION_INT4X2(__nv_bfloat16); #endif ///////////////////////////////////////////////////////////////////////////////////////////////// // int4 col dequant ///////////////////////////////////////////////////////////////////////////// template<typename T> __global__ void perColDequantization(T* dst, const int64_t numRows, const int64_t numCols, const int64_t groupSize, const char4* weightPtr, const half* scalePtr, const half* zerosPtr, float* dbgfp = nullptr, int* dbgint = nullptr) { const uint8_t* pWeight = (const uint8_t*)weightPtr; uint32_t colPckIdx = blockIdx.x * blockDim.x + threadIdx.x; uint32_t rowIdx = blockIdx.y; uint32_t rowGrpIdx = rowIdx / groupSize; if (colPckIdx >= numCols / 2) return; float scalel = cuda_cast<float>(scalePtr[rowGrpIdx * numCols + colPckIdx * 2 + 0]); float scaleh = 

/////////////////////////////////////////////////////////////////////////////////////////////////
// int4 col dequant /////////////////////////////////////////////////////////////////////////////

template<typename T>
__global__ void perColDequantization(T*            dst,
                                     const int64_t numRows,
                                     const int64_t numCols,
                                     const int64_t groupSize,
                                     const char4*  weightPtr,
                                     const half*   scalePtr,
                                     const half*   zerosPtr,
                                     float*        dbgfp  = nullptr,
                                     int*          dbgint = nullptr) {
    const uint8_t* pWeight = (const uint8_t*)weightPtr;

    uint32_t colPckIdx = blockIdx.x * blockDim.x + threadIdx.x;
    uint32_t rowIdx    = blockIdx.y;
    uint32_t rowGrpIdx = rowIdx / groupSize;

    if (colPckIdx >= numCols / 2)
        return;

    float scalel = cuda_cast<float>(scalePtr[rowGrpIdx * numCols + colPckIdx * 2 + 0]);
    float scaleh = cuda_cast<float>(scalePtr[rowGrpIdx * numCols + colPckIdx * 2 + 1]);
    float zerosl = cuda_cast<float>(zerosPtr[rowGrpIdx * numCols + colPckIdx * 2 + 0]);
    float zerosh = cuda_cast<float>(zerosPtr[rowGrpIdx * numCols + colPckIdx * 2 + 1]);

    uint8_t tmpu8  = pWeight[rowIdx * numCols / 2 + colPckIdx];
    uint8_t tmpu4l = tmpu8 & 0x0F;
    uint8_t tmpu4h = (tmpu8 >> 4) & 0x0F;
    // sign-extend the two 4-bit values
    if (tmpu4l & 0x08)
        tmpu4l |= 0xF0;
    if (tmpu4h & 0x08)
        tmpu4h |= 0xF0;
    int8_t tmpi4l = tmpu4l;
    int8_t tmpi4h = tmpu4h;
    float  tmpfpl = cuda_cast<float>(tmpi4l);
    float  tmpfph = cuda_cast<float>(tmpi4h);

    T vall = cuda_cast<T>(tmpfpl * scalel + zerosl);
    T valh = cuda_cast<T>(tmpfph * scaleh + zerosh);

    dst[rowIdx * numCols + colPckIdx * 2 + 0] = vall;
    dst[rowIdx * numCols + colPckIdx * 2 + 1] = valh;
}

template<typename T>
void invokePerColDequantizationInt4x2(T*            dst,
                                      const int64_t numRows,
                                      const int64_t numCols,
                                      const int64_t groupSize,
                                      const int8_t* weightPtr,
                                      half*         scalePtr,
                                      half*         zerosPtr,
                                      cudaStream_t  stream) {
    assert(numRows % groupSize == 0);
    const dim3 block(numCols / 2 < 512 ? numCols / 2 : 512);
    const dim3 grid((numCols / 2 + block.x - 1) / block.x, numRows, 1);
    perColDequantization<T>
        <<<grid, block, 0, stream>>>(dst, numRows, numCols, groupSize, (char4*)weightPtr, scalePtr, zerosPtr);
}

#define INSTANTIATE_INVOKE_PER_COL_DEQUANTIZATION_INT4X2(T)                    \
    template void invokePerColDequantizationInt4x2(T*            dst,          \
                                                   const int64_t numRows,      \
                                                   const int64_t numCols,      \
                                                   const int64_t groupSize,    \
                                                   const int8_t* weightPtr,    \
                                                   half*         scalePtr,     \
                                                   half*         zerosPtr,     \
                                                   cudaStream_t  stream)
INSTANTIATE_INVOKE_PER_COL_DEQUANTIZATION_INT4X2(float);
INSTANTIATE_INVOKE_PER_COL_DEQUANTIZATION_INT4X2(half);
#ifdef ENABLE_BF16
INSTANTIATE_INVOKE_PER_COL_DEQUANTIZATION_INT4X2(__nv_bfloat16);
#endif

/////////////////////////////////////////////////////////////////////////////////////////////////

__global__ void quantizedKernel(char4* dst, const float4* src, const int64_t sizeDiv4, const float* scalePtr) {
    for (int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < sizeDiv4; idx += blockDim.x * gridDim.x) {
        const float scale = __ldg(scalePtr);

        char4        tmp;
        const float4 floatTmp = __ldg(src + idx);
        tmp.x                 = cuda_cast<int8_t>(floatTmp.x * scale);
        tmp.y                 = cuda_cast<int8_t>(floatTmp.y * scale);
        tmp.z                 = cuda_cast<int8_t>(floatTmp.z * scale);
        tmp.w                 = cuda_cast<int8_t>(floatTmp.w * scale);
        dst[idx]              = tmp;
    }
}

__global__ void quantizedKernel(char4* dst, const half2* src, const int64_t sizeDiv4, const float* scalePtr) {
    for (int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < sizeDiv4; idx += blockDim.x * gridDim.x) {
        const float scale = __ldg(scalePtr);

        char4 tmp;
        int   srcId = idx << 1;

        const uint2 h2 = __ldg(reinterpret_cast<const uint2*>(src + srcId));

        const half2 half2Tmp  = reinterpret_cast<const half2&>(h2.x);
        const half2 half2Tmp2 = reinterpret_cast<const half2&>(h2.y);

        tmp.x    = cuda_cast<int8_t>(cuda_cast<float>(half2Tmp.x) * scale);
        tmp.y    = cuda_cast<int8_t>(cuda_cast<float>(half2Tmp.y) * scale);
        tmp.z    = cuda_cast<int8_t>(cuda_cast<float>(half2Tmp2.x) * scale);
        tmp.w    = cuda_cast<int8_t>(cuda_cast<float>(half2Tmp2.y) * scale);
        dst[idx] = tmp;
    }
}
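
// quantizedKernel: whole-tensor symmetric int8 quantization. Each thread converts four
// elements per loop iteration (one float4, or two half2, into a char4) using a single
// scale read from device memory via __ldg, so the scale can be produced by an earlier
// kernel without a host round trip. invokeQuantization below selects the overload from
// the element type; only float and half inputs are dispatched.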

template<typename T>
void invokeQuantization(
    int8_t* dst, const T* src, const int64_t size, const float* scalePtr, cudaStream_t stream, int maxGridSize) {
    RTP_LLM_CHECK_WITH_INFO(size % 4 == 0, "[ERROR][invokeQuantization] size should be a multiple of 4.\n");

    int numBlocks{static_cast<int>((size + 255) / 256)};
    if (maxGridSize == -1) {
        maxGridSize = numBlocks;
    }
    dim3 grid(std::min(numBlocks, maxGridSize));
    RTP_LLM_CHECK_WITH_INFO(grid.x <= maxGridSize, "[ERROR][invokeQuantization] grid max size is exceeded\n");
    dim3 block(64);

    if (std::is_same_v<T, float>) {
        quantizedKernel<<<grid, block, 0, stream>>>((char4*)dst, (const float4*)src, size / 4, scalePtr);
    }
    else if (std::is_same_v<T, half>) {
        quantizedKernel<<<grid, block, 0, stream>>>((char4*)dst, (const half2*)src, size / 4, scalePtr);
    }
}

#define INSTANTIATE_INVOKE_QUANTIZATION(T)                                                                            \
    template void invokeQuantization(                                                                                 \
        int8_t* dst, const T* src, const int64_t size, const float* scalePtr, cudaStream_t stream, int maxGridSize);
INSTANTIATE_INVOKE_QUANTIZATION(float);
INSTANTIATE_INVOKE_QUANTIZATION(half);
#ifdef ENABLE_BF16
INSTANTIATE_INVOKE_QUANTIZATION(__nv_bfloat16);
#endif

template<typename T, bool IS_SMOOTHER, bool IS_SHIFT>
__global__ void perTokenQuantization(int8_t*       dst,
                                     const T*      src,
                                     const int64_t numRows,
                                     const int64_t numCols,
                                     float*        scalePtr,
                                     const float*  smoother,
                                     const float*  shift) {
    const T* srcRow = src + blockIdx.x * numCols;
    int8_t*  dstRow = dst + blockIdx.x * numCols;

    T localMax = 1e-6f;
    for (int i = threadIdx.x; i < numCols; i += blockDim.x) {
        T val = srcRow[i];
        if (IS_SMOOTHER) {
            val = cuda_cast<T>(val / cuda_cast<T>(smoother[i]));
        }
        if (IS_SHIFT) {
            val = cuda_cast<T>(val + cuda_cast<T>(shift[i]));
        }
        localMax = cuda_max(localMax, cuda_abs(val));
    }

    const float rowMax = blockAllReduceMax(cuda_cast<float>(localMax));
    if (threadIdx.x == 0) {
        scalePtr[blockIdx.x] = rowMax / 127.f;
    }

    const float scaleOrigQuant = 127.f / rowMax;
    for (int i = threadIdx.x; i < numCols; i += blockDim.x) {
        T val = srcRow[i];
        if (IS_SMOOTHER) {
            val = val / cuda_cast<T>(smoother[i]);
        }
        if (IS_SHIFT) {
            val = cuda_cast<T>(val + cuda_cast<T>(shift[i]));
        }
        dstRow[i] = cuda_cast<int8_t>(cuda_cast<float>(val) * scaleOrigQuant);
    }
}

template<typename T, bool IS_SMOOTHER>
void dispatch_per_token_quantization_shift(int8_t*       dst,
                                           const T*      src,
                                           const int64_t numRows,
                                           const int64_t numCols,
                                           float*        scalePtr,
                                           const float*  smoother,
                                           const float*  shift,
                                           cudaStream_t  stream) {
    // each block is responsible for a single row
    const dim3 block(512);
    const dim3 grid(numRows);

    if (shift != nullptr) {
        perTokenQuantization<T, IS_SMOOTHER, true>
            <<<grid, block, 0, stream>>>(dst, src, numRows, numCols, scalePtr, smoother, shift);
    }
    else {
        perTokenQuantization<T, IS_SMOOTHER, false>
            <<<grid, block, 0, stream>>>(dst, src, numRows, numCols, scalePtr, smoother, nullptr);
    }
}

template<typename T>
void invokePerTokenQuantization(int8_t*       dst,
                                const T*      src,
                                const int64_t numRows,
                                const int64_t numCols,
                                float*        scalePtr,
                                const float*  smoother,
                                const float*  shift,
                                cudaStream_t  stream) {
    if (smoother != nullptr) {
        dispatch_per_token_quantization_shift<T, true>(dst, src, numRows, numCols, scalePtr, smoother, shift, stream);
    }
    else {
        dispatch_per_token_quantization_shift<T, false>(dst, src, numRows, numCols, scalePtr, nullptr, shift, stream);
    }
}

#define INSTANTIATE_INVOKE_PER_TOKEN_QUANTIZATION(T)                    \
    template void invokePerTokenQuantization(int8_t*       dst,         \
                                             const T*      src,         \
                                             const int64_t numRows,     \
                                             const int64_t numCols,     \
                                             float*        scalePtr,    \
                                             const float*  smoother,    \
                                             const float*  shift,       \
                                             cudaStream_t  stream)
INSTANTIATE_INVOKE_PER_TOKEN_QUANTIZATION(float);
INSTANTIATE_INVOKE_PER_TOKEN_QUANTIZATION(half);
#ifdef ENABLE_BF16
INSTANTIATE_INVOKE_PER_TOKEN_QUANTIZATION(__nv_bfloat16);
#endif
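
// invokePerTokenQuantization: per-row (per-token) symmetric int8 quantization. The optional
// per-channel smoother is divided out and the optional shift added before the row max is
// taken; scalePtr receives one float per row (rowMax / 127).
// Illustrative call (buffer names are hypothetical):
//   invokePerTokenQuantization(actInt8, actFp16, numTokens, hiddenSize, rowScales,
//                              /*smoother=*/nullptr, /*shift=*/nullptr, stream);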

/////////////////////////////////////////////////////////////////////////////////////////////////
// int8 col quant ///////////////////////////////////////////////////////////////////////////////

template<typename T, bool IS_SMOOTHER, bool IS_SHIFT>
__global__ void perColQuantization(int8_t*       dst,
                                   const T*      src,
                                   const int64_t numRows,
                                   const int64_t numCols,
                                   half*         scalePtr,
                                   const float*  smoother,
                                   const float*  shift,
                                   float*        dbgfp  = nullptr,
                                   int*          dbgint = nullptr) {
    uint32_t colIdx = blockIdx.x;

    const T* srcCol = src + colIdx;
    int8_t*  dstCol = dst + colIdx;

    T localMax = 1e-6f;
    for (int rowIdx = threadIdx.x; rowIdx < numRows; rowIdx += blockDim.x) {
        T val = srcCol[rowIdx * numCols];
        if (IS_SMOOTHER) {
            val = cuda_cast<T>(val / cuda_cast<T>(smoother[rowIdx]));
        }
        if (IS_SHIFT) {
            val = cuda_cast<T>(val + cuda_cast<T>(shift[rowIdx]));
        }
        localMax = cuda_max(localMax, cuda_abs(val));
    }

    const float colMax = blockAllReduceMax(cuda_cast<float>(localMax));
    if (threadIdx.x == 0) {
        scalePtr[colIdx] = cuda_cast<half>(colMax / 128.f);
    }

    const float scaleOrigQuant = 128.f / colMax;
    for (int rowIdx = threadIdx.x; rowIdx < numRows; rowIdx += blockDim.x) {
        T val = srcCol[rowIdx * numCols];
        if (IS_SMOOTHER) {
            val = val / cuda_cast<T>(smoother[rowIdx]);
        }
        if (IS_SHIFT) {
            val = cuda_cast<T>(val + cuda_cast<T>(shift[rowIdx]));
        }
        dstCol[rowIdx * numCols] = cuda_cast<int8_t>(cuda_cast<float>(val) * scaleOrigQuant);
    }
}

template<typename T, bool IS_SMOOTHER>
void dispatch_per_col_quantization_shift(int8_t*       dst,
                                         const T*      src,
                                         const int64_t numRows,
                                         const int64_t numCols,
                                         half*         scalePtr,
                                         const float*  smoother,
                                         const float*  shift,
                                         cudaStream_t  stream) {
    // each block is responsible for a single column
    const dim3 block(512);
    const dim3 grid(numCols);

    if (shift != nullptr) {
        perColQuantization<T, IS_SMOOTHER, true>
            <<<grid, block, 0, stream>>>(dst, src, numRows, numCols, scalePtr, smoother, shift);
    }
    else {
        perColQuantization<T, IS_SMOOTHER, false>
            <<<grid, block, 0, stream>>>(dst, src, numRows, numCols, scalePtr, smoother, nullptr);
    }
}

template<typename T>
void invokePerColQuantizationInt8(int8_t*       dst,
                                  const T*      src,
                                  const int64_t numRows,
                                  const int64_t numCols,
                                  half*         scalePtr,
                                  const float*  smoother,
                                  const float*  shift,
                                  cudaStream_t  stream) {
    if (smoother != nullptr) {
        dispatch_per_col_quantization_shift<T, true>(dst, src, numRows, numCols, scalePtr, smoother, shift, stream);
    }
    else {
        dispatch_per_col_quantization_shift<T, false>(dst, src, numRows, numCols, scalePtr, nullptr, shift, stream);
    }
}

#define INSTANTIATE_INVOKE_PER_COL_QUANTIZATION_INT8(T)                    \
    template void invokePerColQuantizationInt8(int8_t*       dst,          \
                                               const T*      src,          \
                                               const int64_t numRows,      \
                                               const int64_t numCols,      \
                                               half*         scalePtr,     \
                                               const float*  smoother,     \
                                               const float*  shift,        \
                                               cudaStream_t  stream)
INSTANTIATE_INVOKE_PER_COL_QUANTIZATION_INT8(float);
INSTANTIATE_INVOKE_PER_COL_QUANTIZATION_INT8(half);
#ifdef ENABLE_BF16
INSTANTIATE_INVOKE_PER_COL_QUANTIZATION_INT8(__nv_bfloat16);
#endif
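
// Per-column int8 layout: scalePtr holds numCols half values with scale[c] = colMax / 128;
// dst keeps the row-major [numRows, numCols] shape with one int8 per element. When given,
// the smoother and shift vectors are indexed by row here (length numRows).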

/////////////////////////////////////////////////////////////////////////////////////////////////
// int8 col dequant /////////////////////////////////////////////////////////////////////////////

template<typename T, bool IS_SMOOTHER, bool IS_SHIFT>
__global__ void perColDequantization(T*            dst,
                                     const int8_t* src,
                                     const int64_t numRows,
                                     const int64_t numCols,
                                     const half*   scalePtr,
                                     const float*  smoother,
                                     const float*  shift,
                                     float*        dbgfp  = nullptr,
                                     int*          dbgint = nullptr) {
    uint32_t colIdx = blockIdx.x;

    const int8_t* srcCol = src + colIdx;
    T*            dstCol = dst + colIdx;

    float scaleOrigQuant = cuda_cast<float>(scalePtr[colIdx]);
    if (IS_SMOOTHER) {
        scaleOrigQuant = scaleOrigQuant * smoother[colIdx];
    }
    if (IS_SHIFT) {
        scaleOrigQuant = scaleOrigQuant - shift[colIdx];
    }

    for (int rowIdx = threadIdx.x; rowIdx < numRows; rowIdx += blockDim.x) {
        int8_t tmpi8 = srcCol[rowIdx * numCols];
        T      val   = cuda_cast<T>(cuda_cast<float>(tmpi8) * scaleOrigQuant);
        if (IS_SMOOTHER) {
            val = val * cuda_cast<T>(smoother[rowIdx]);
        }
        if (IS_SHIFT) {
            val = cuda_cast<T>(val - cuda_cast<T>(shift[rowIdx]));
        }
        dstCol[rowIdx * numCols] = val;
    }
}

template<typename T, bool IS_SMOOTHER>
void dispatch_per_col_dequantization_shift(T*            dst,
                                           const int8_t* src,
                                           const int64_t numRows,
                                           const int64_t numCols,
                                           half*         scalePtr,
                                           const float*  smoother,
                                           const float*  shift,
                                           cudaStream_t  stream) {
    // each block is responsible for a single col
    const dim3 block(512);
    const dim3 grid(numCols);

    if (shift != nullptr) {
        perColDequantization<T, IS_SMOOTHER, true>
            <<<grid, block, 0, stream>>>(dst, src, numRows, numCols, scalePtr, smoother, shift);
    }
    else {
        perColDequantization<T, IS_SMOOTHER, false>
            <<<grid, block, 0, stream>>>(dst, src, numRows, numCols, scalePtr, smoother, nullptr);
    }
}

template<typename T>
void invokePerColDequantizationInt8(T*            dst,
                                    const int8_t* src,
                                    const int64_t numRows,
                                    const int64_t numCols,
                                    half*         scalePtr,
                                    const float*  smoother,
                                    const float*  shift,
                                    cudaStream_t  stream) {
    if (smoother != nullptr) {
        dispatch_per_col_dequantization_shift<T, true>(dst, src, numRows, numCols, scalePtr, smoother, shift, stream);
    }
    else {
        dispatch_per_col_dequantization_shift<T, false>(dst, src, numRows, numCols, scalePtr, nullptr, shift, stream);
    }
}

#define INSTANTIATE_INVOKE_PER_COL_DEQUANTIZATION_INT8(T)                    \
    template void invokePerColDequantizationInt8(T*            dst,          \
                                                 const int8_t* src,          \
                                                 const int64_t numRows,      \
                                                 const int64_t numCols,      \
                                                 half*         scalePtr,     \
                                                 const float*  smoother,     \
                                                 const float*  shift,        \
                                                 cudaStream_t  stream)
INSTANTIATE_INVOKE_PER_COL_DEQUANTIZATION_INT8(float);
INSTANTIATE_INVOKE_PER_COL_DEQUANTIZATION_INT8(half);
#ifdef ENABLE_BF16
INSTANTIATE_INVOKE_PER_COL_DEQUANTIZATION_INT8(__nv_bfloat16);
#endif
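
// invokePerColDequantizationInt8 mirrors the per-column quantization above: one block per
// column, each int8 value is multiplied by that column's scale, with the optional smoother
// multiplied back in and the shift subtracted.
// Illustrative call (buffer names are hypothetical):
//   invokePerColDequantizationInt8(dstFp16, srcInt8, numRows, numCols, colScales,
//                                  /*smoother=*/nullptr, /*shift=*/nullptr, stream);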

/////////////////////////////////////////////////////////////////////////////////////////////////
// int4 row dequant /////////////////////////////////////////////////////////////////////////////

template<typename T>
__global__ void perRowDequantization(T*            dst,
                                     const char4*  src,
                                     const int64_t numRows,
                                     const int64_t numCols,
                                     const half*   scalePtr,
                                     const half*   zerosPtr,
                                     const int64_t groupSize,
                                     float*        dbgfp  = nullptr,
                                     int*          dbgint = nullptr) {
    const uint8_t* pSrc = (const uint8_t*)src;

    uint32_t rowIdx    = blockIdx.y;
    uint32_t colGrpIdx = blockIdx.x;
    uint32_t colGrpNum = numCols / groupSize;

    float scale = cuda_cast<float>(scalePtr[rowIdx * colGrpNum + colGrpIdx]);
    float zeros = cuda_cast<float>(zerosPtr[rowIdx * colGrpNum + colGrpIdx]);
    // scale = 1.0f;
    // zeros = 0;

    uint8_t tmpu8  = pSrc[rowIdx * numCols / 2 + colGrpIdx * groupSize / 2 + threadIdx.x];
    uint8_t tmpu4l = tmpu8 & 0x0F;
    uint8_t tmpu4h = (tmpu8 >> 4) & 0x0F;
    // sign-extend the two 4-bit values
    if (tmpu4l & 0x08)
        tmpu4l |= 0xF0;
    if (tmpu4h & 0x08)
        tmpu4h |= 0xF0;
    int8_t tmpi4l = tmpu4l;
    int8_t tmpi4h = tmpu4h;
    float  tmpfpl = cuda_cast<float>(tmpi4l);
    float  tmpfph = cuda_cast<float>(tmpi4h);

    T vall = cuda_cast<T>(tmpfpl * scale);
    T valh = cuda_cast<T>(tmpfph * scale);

    dst[rowIdx * numCols + colGrpIdx * groupSize + threadIdx.x * 2 + 0] = vall;
    dst[rowIdx * numCols + colGrpIdx * groupSize + threadIdx.x * 2 + 1] = valh;
}

template<typename T>
void invokePerRowDequantizationInt4x2(T*            dst,
                                      const int8_t* src,
                                      const int64_t numRows,
                                      const int64_t numCols,
                                      half*         scalePtr,
                                      half*         zerosPtr,
                                      const int64_t groupSize,
                                      cudaStream_t  stream) {
    const dim3 block(groupSize / 2);
    const dim3 grid(numCols / groupSize, numRows, 1);
    perRowDequantization<T>
        <<<grid, block, 0, stream>>>(dst, (char4*)src, numRows, numCols, scalePtr, zerosPtr, groupSize);
}

#define INSTANTIATE_INVOKE_PER_ROW_DEQUANTIZATION_INT4X2(T)                    \
    template void invokePerRowDequantizationInt4x2(T*            dst,          \
                                                   const int8_t* src,          \
                                                   const int64_t numRows,      \
                                                   const int64_t numCols,      \
                                                   half*         scalePtr,     \
                                                   half*         zerosPtr,     \
                                                   const int64_t groupSize,    \
                                                   cudaStream_t  stream)
INSTANTIATE_INVOKE_PER_ROW_DEQUANTIZATION_INT4X2(float);
INSTANTIATE_INVOKE_PER_ROW_DEQUANTIZATION_INT4X2(half);
#ifdef ENABLE_BF16
INSTANTIATE_INVOKE_PER_ROW_DEQUANTIZATION_INT4X2(__nv_bfloat16);
#endif

}  // namespace rtp_llm