// maga_transformer/cpp/cuda/cuda_fp8_utils.cu
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "cuda_fp8_utils.h"
#include "maga_transformer/cpp/cuda/reduce_kernel_utils.cuh"
#include "maga_transformer/cpp/cuda/cuda_type_utils.cuh"
#include <algorithm>
#include <cstdio>
#include <cuda_fp16.h>
#include <limits>
#include <type_traits>
using namespace rtp_llm;
namespace tensorrt_llm
{
namespace common
{
#ifdef ENABLE_FP8
constexpr int CTA_SIZE = 256;
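// scale<true> divides by the scale (quantize); scale<false> multiplies by it (dequantize).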
template <bool QUANTIZE>
__inline__ __device__ float scale(float a, float b)
{
return QUANTIZE ? a / b : a * b;
}
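// Grid-stride kernel that rescales every element of a row-major matrix with `lda` elements
// per row. The scale index depends on the granularity: per-channel uses the column (i % lda),
// per-token uses the row (i / lda), and per-tensor uses a single scalar.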
template <QuantizeMode QUANTIZE_MODE, bool QUANTIZE, typename T_OUT, typename T_S, typename T_IN>
__global__ void scaleMatrix(T_OUT* output, T_S const* input_scale, T_IN const* input, int64_t numel, int64_t lda)
{
for (int64_t i = threadIdx.x + blockIdx.x * blockDim.x; i < numel; i += blockDim.x * gridDim.x)
{
if (QUANTIZE_MODE == QuantizeMode::PER_CHANNEL)
{
output[i] = T_OUT(scale<QUANTIZE>(static_cast<float>(input[i]), static_cast<float>(input_scale[i % lda])));
}
else if (QUANTIZE_MODE == QuantizeMode::PER_TOKEN)
{
output[i] = T_OUT(scale<QUANTIZE>(static_cast<float>(input[i]), static_cast<float>(input_scale[i / lda])));
}
else if (QUANTIZE_MODE == QuantizeMode::PER_TENSOR)
{
output[i] = T_OUT(scale<QUANTIZE>(static_cast<float>(input[i]), static_cast<float>(input_scale[0])));
}
}
}
template <typename T_OUT, typename T_S, typename T_IN>
void invokeQuantizeMatrix(T_OUT* output, T_S const* input_scale, T_IN const* input, int64_t numel, int64_t lda,
QuantizeMode quantize_mode, cudaStream_t stream)
{
dim3 grid(1024);
dim3 block(CTA_SIZE);
if (quantize_mode == QuantizeMode::PER_CHANNEL)
{
scaleMatrix<QuantizeMode::PER_CHANNEL, true>
<<<grid, block, 0, stream>>>(output, input_scale, input, numel, lda);
}
else if (quantize_mode == QuantizeMode::PER_TOKEN)
{
scaleMatrix<QuantizeMode::PER_TOKEN, true><<<grid, block, 0, stream>>>(output, input_scale, input, numel, lda);
}
else if (quantize_mode == QuantizeMode::PER_TENSOR)
{
scaleMatrix<QuantizeMode::PER_TENSOR, true><<<grid, block, 0, stream>>>(output, input_scale, input, numel, lda);
}
sync_check_cuda_error();
}
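// Dequantizes `input` into `output` by multiplying each element by the corresponding scale.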
template <typename T_OUT, typename T_S, typename T_IN>
void invokeDequantizeMatrix(T_OUT* output, T_S const* input_scale, T_IN const* input, int64_t numel, int64_t lda,
QuantizeMode quantize_mode, cudaStream_t stream)
{
dim3 grid(1024);
dim3 block(CTA_SIZE);
if (quantize_mode == QuantizeMode::PER_CHANNEL)
{
scaleMatrix<QuantizeMode::PER_CHANNEL, false>
<<<grid, block, 0, stream>>>(output, input_scale, input, numel, lda);
}
else if (quantize_mode == QuantizeMode::PER_TOKEN)
{
scaleMatrix<QuantizeMode::PER_TOKEN, false><<<grid, block, 0, stream>>>(output, input_scale, input, numel, lda);
}
else if (quantize_mode == QuantizeMode::PER_TENSOR)
{
scaleMatrix<QuantizeMode::PER_TENSOR, false>
<<<grid, block, 0, stream>>>(output, input_scale, input, numel, lda);
}
sync_check_cuda_error();
}
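// "Fake" quantization: casts each element to T_FAKE (e.g. __nv_fp8_e4m3) and back through
// float, emulating the precision loss without changing the storage type of `dst`.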
template <typename T_FAKE, typename T_OUT, typename T_IN>
__global__ void fakeQuantize(T_OUT* dst, const T_IN* src, const int64_t numel)
{
for (int64_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < numel; tid += blockDim.x * gridDim.x)
{
T_FAKE tmp = (T_FAKE) (static_cast<float>(src[tid]));
dst[tid] = (T_OUT) (static_cast<float>(tmp));
}
}
template <typename T_FAKE, typename T_OUT, typename T_IN>
void invokeFakeQuantize(T_OUT* dst, const T_IN* src, const int64_t numel, cudaStream_t stream)
{
fakeQuantize<T_FAKE><<<1024, CTA_SIZE, 0, stream>>>(dst, src, numel);
sync_check_cuda_error();
}
template void invokeFakeQuantize<__nv_fp8_e4m3, float, float>(
float* dst, float const* src, const int64_t numel, cudaStream_t stream);
template void invokeFakeQuantize<float, float, __nv_fp8_e4m3>(
float* dst, __nv_fp8_e4m3 const* src, const int64_t numel, cudaStream_t stream);
template void invokeFakeQuantize<__nv_fp8_e4m3, half, half>(
half* dst, half const* src, const int64_t numel, cudaStream_t stream);
template void invokeFakeQuantize<__nv_fp8_e4m3, __nv_bfloat16, __nv_bfloat16>(
__nv_bfloat16* dst, __nv_bfloat16 const* src, const int64_t numel, cudaStream_t stream);
template void invokeFakeQuantize<float, half, float>(
half* dst, float const* src, const int64_t numel, cudaStream_t stream);
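// Atomic max for non-negative floats: their bit patterns, viewed as unsigned integers,
// order the same way as the float values, so an unsigned atomicMax suffices.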
__device__ float atomicMaxExtd(float* address, float val)
{
assert(val >= 0);
unsigned int* address_as_u = reinterpret_cast<unsigned int*>(address);
unsigned int old = atomicMax(address_as_u, __float_as_uint(val));
return __uint_as_float(old);
}
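// Vectorized 16-bit atomic max for Hopper (sm_90+): the value is packed into the correct
// lane of the 4-byte aligned pair containing `address` and a single packed atom.max is issued.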
template <typename T>
inline __device__ T atomicMaxExtdV2(T* address, T val)
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
static_assert(std::is_same_v<T, half> || std::is_same_v<T, __nv_bfloat16>, "T needs to be either half or bfloat16");
// The address in 64 bits.
uint64_t address_u64 = reinterpret_cast<uint64_t const&>(address);
// Pack the input value into 32 bits.
union
{
T v[2];
uint16_t u[2];
} old, tmp = {};
int const loc = (address_u64 & 0x2) >> 1;
tmp.v[loc] = val;
// 4B aligned pointer.
auto aligned_address = reinterpret_cast<T*>(address_u64 & ~0x3ull);
if constexpr (std::is_same_v<T, half>)
{
asm volatile("atom.global.v2.f16.max.noftz {%0, %1}, [%2], {%3, %4};"
: "=h"(old.u[0]), "=h"(old.u[1])
: "l"(aligned_address), "h"(tmp.u[0]), "h"(tmp.u[1]));
}
if constexpr (std::is_same_v<T, __nv_bfloat16>)
{
asm volatile("atom.global.v2.bf16.max.noftz {%0, %1}, [%2], {%3, %4};"
: "=h"(old.u[0]), "=h"(old.u[1])
: "l"(aligned_address), "h"(tmp.u[0]), "h"(tmp.u[1]));
}
// Return the correct half.
return old.v[loc];
#else
// Fallback for pre-sm_90 compilation: callers only use the vectorized atomic on sm_90+,
// but provide a well-defined return value for any other instantiation.
assert(false);
return T{};
#endif
}
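// Compare-and-swap fallback for an atomic max on half.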
__device__ half atomicMaxExtd(half* address, half val)
{
unsigned short int* address_as_u = reinterpret_cast<unsigned short int*>(address);
unsigned short int old = *address_as_u, assumed;
while (val > __ushort_as_half(old))
{
assumed = old;
old = atomicCAS(address_as_u, assumed, __half_as_ushort(val));
}
return __ushort_as_half(old);
}
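// Compare-and-swap atomic max on bfloat16; traps on architectures older than sm_80.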
__device__ __nv_bfloat16 atomicMaxExtd(__nv_bfloat16* address, __nv_bfloat16 val)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
unsigned short int* address_as_u = reinterpret_cast<unsigned short int*>(address);
unsigned short int old = *address_as_u, assumed;
while (val > __ushort_as_bfloat16(old))
{
assumed = old;
old = atomicCAS(address_as_u, assumed, __bfloat16_as_ushort(val));
}
return __ushort_as_bfloat16(old);
#else
assert(0);
asm volatile("brkpt;\n" ::);
return __nv_bfloat16(0);
#endif
}
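// Computes FP8 scales as amax / FP8_E4M3_MAX, clamped below by min_scaling_factor so an
// all-zero input never produces a zero scale. Per-channel reduces each column across blocks
// via atomic max, per-token reduces each row within one block, and per-tensor reduces the
// whole tensor into a single scalar via atomic max.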
template <QuantizeMode QUANTIZE_MODE, typename T_S, typename T_W>
__global__ void computeFP8QuantizeScale(T_S* quant_ptr, const T_W* weights, const int64_t size, const int64_t n)
{
constexpr float min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);
if (QUANTIZE_MODE == QuantizeMode::PER_CHANNEL)
{
for (int64_t col = threadIdx.x; col < n; col += blockDim.x)
{
float max = 0.f;
for (int64_t i = col + n * blockIdx.x; i < size; i += gridDim.x * n)
{
auto val = fabs(static_cast<float>(weights[i]));
max = max > val ? max : val;
}
auto const scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor);
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
if constexpr (std::is_same_v<T_S, float>)
{
atomicMaxExtd(quant_ptr + col, scale);
}
else
{
auto const address_u64 = reinterpret_cast<uint64_t>(quant_ptr + col);
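// Boundary columns whose 4-byte-aligned partner element would fall outside the scale
// buffer cannot use the vectorized atomic; fall back to the scalar CAS version.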
if ((col == 0 && address_u64 % 4 != 0) || (col == n - 1 && address_u64 % 4 == 0))
atomicMaxExtd(quant_ptr + col, scale);
else
atomicMaxExtdV2(quant_ptr + col, scale);
}
#else // Vector atomics require __CUDA_ARCH__ >= 900
atomicMaxExtd(quant_ptr + col, scale);
#endif
}
}
else if (QUANTIZE_MODE == QuantizeMode::PER_TOKEN)
{
auto const nrows = size / n;
for (int64_t row = blockIdx.x; row < nrows; row += gridDim.x)
{
float max = 0.f;
for (int64_t i = threadIdx.x; i < n; i += blockDim.x)
{
auto val = fabs(static_cast<float>(weights[row * n + i]));
max = max > val ? max : val;
}
max = blockReduceMax<float>(max);
if (threadIdx.x == 0)
{
auto const scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor);
quant_ptr[row] = scale;
}
}
}
else if (QUANTIZE_MODE == QuantizeMode::PER_TENSOR)
{
float max = 0.f;
for (int64_t i = threadIdx.x + blockIdx.x * blockDim.x; i < size; i += gridDim.x * blockDim.x)
{
auto val = fabs(static_cast<float>(weights[i]));
max = max > val ? max : val;
}
max = blockReduceMax<float>(max);
if (threadIdx.x == 0)
{
auto const scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor);
atomicMaxExtd(quant_ptr, scale);
}
}
}
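// Launches computeFP8QuantizeScale with a grid suited to the granularity. The scale buffer
// is zeroed first for the modes that accumulate with atomic max (per-channel and per-tensor).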
template <typename T_S, typename T_W>
void invokeComputeFP8QuantizeScale(T_S* quant_ptr, const T_W* weights, const int64_t numel, const int64_t lda,
QuantizeMode quantize_mode, cudaStream_t stream)
{
if (quantize_mode == QuantizeMode::PER_TOKEN)
{
dim3 block(CTA_SIZE);
dim3 grid(numel / lda);
computeFP8QuantizeScale<QuantizeMode::PER_TOKEN><<<grid, block, 0, stream>>>(quant_ptr, weights, numel, lda);
}
else if (quantize_mode == QuantizeMode::PER_CHANNEL)
{
dim3 block(CTA_SIZE);
dim3 grid((lda + CTA_SIZE - 1) / CTA_SIZE);
cudaMemsetAsync(quant_ptr, 0, lda * sizeof(T_S), stream);
sync_check_cuda_error();
computeFP8QuantizeScale<QuantizeMode::PER_CHANNEL><<<grid, block, 0, stream>>>(quant_ptr, weights, numel, lda);
}
else if (quantize_mode == QuantizeMode::PER_TENSOR)
{
dim3 block(1024);
dim3 grid(1024);
cudaMemsetAsync(quant_ptr, 0, sizeof(T_S), stream);
sync_check_cuda_error();
computeFP8QuantizeScale<QuantizeMode::PER_TENSOR><<<grid, block, 0, stream>>>(quant_ptr, weights, numel, lda);
}
sync_check_cuda_error();
}
#define DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(type_scale, type_in) \
template void invokeComputeFP8QuantizeScale<type_scale, type_in>(type_scale * input_scale, type_in const* weights, \
int64_t numel, int64_t lda, QuantizeMode quantize_mode, cudaStream_t stream);
DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(half, half);
DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(float, half);
DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(float, float);
#ifdef ENABLE_BF16
DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(__nv_bfloat16, __nv_bfloat16);
DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(float, __nv_bfloat16);
#endif
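// Fused per-token quantization: each block stages one row in dynamic shared memory, computes
// the row amax, derives the scale, then writes the quantized row and the scale in one pass.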
template <typename T_OUT, typename T_S, typename T_IN>
__global__ void dynamicQuantizeMatrixPerToken(
T_OUT* output, T_S* quant_ptr, T_IN const* input, int64_t numel, int64_t lda)
{
extern __shared__ __align__(sizeof(float)) char _shmem[];
T_IN* shmem = reinterpret_cast<T_IN*>(_shmem);
constexpr float min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);
auto const nrows = numel / lda;
for (int64_t row = blockIdx.x; row < nrows; row += gridDim.x)
{
float max = 0.f;
for (int64_t i = threadIdx.x; i < lda; i += blockDim.x)
{
auto const in = input[row * lda + i];
shmem[i] = in;
auto val = fabs(static_cast<float>(in));
max = max > val ? max : val;
}
max = blockAllReduceMax<float>(max); // __syncthreads() called so we can read shmem
auto const s = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor);
for (int64_t i = threadIdx.x; i < lda; i += blockDim.x)
{
// true means we are quantizing
output[row * lda + i] = (T_OUT) scale<true>(static_cast<float>(shmem[i]), static_cast<float>(s));
}
if (threadIdx.x == 0)
{
quant_ptr[row] = s;
}
}
}
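// Computes scales and quantizes in one launch when possible. For per-token mode the fused
// kernel is used if the required dynamic shared memory can be granted (rows above 48 KB need
// cudaFuncSetAttribute); otherwise, and for the other modes, the scales are computed first
// and invokeQuantizeMatrix runs as a second pass.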
template <typename T_OUT, typename T_S, typename T_IN>
void invokeComputeScalesAndQuantizeMatrix(T_OUT* output, T_S* quant_ptr, const T_IN* input, const int64_t numel,
const int64_t lda, QuantizeMode quantize_mode, cudaStream_t stream)
{
if (quantize_mode == QuantizeMode::PER_TOKEN)
{
dim3 grid(numel / lda);
bool use_shmem = true;
auto const shmem_size = lda * sizeof(T_IN);
if (shmem_size >= (48 << 10))
{
cudaError_t ret = cudaFuncSetAttribute(dynamicQuantizeMatrixPerToken<T_OUT, T_S, T_IN>,
cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size);
use_shmem = ret == cudaSuccess;
}
if (use_shmem)
{
// ensure the threadblock is as large as possible to increase occupancy
dim3 block(std::min((lda + 31) / 32 * 32, static_cast<int64_t>(1024)));
dynamicQuantizeMatrixPerToken<<<grid, block, shmem_size, stream>>>(output, quant_ptr, input, numel, lda);
}
else
{
dim3 block(CTA_SIZE);
computeFP8QuantizeScale<QuantizeMode::PER_TOKEN><<<grid, block, 0, stream>>>(quant_ptr, input, numel, lda);
sync_check_cuda_error();
invokeQuantizeMatrix(output, quant_ptr, input, numel, lda, quantize_mode, stream);
}
}
else if (quantize_mode == QuantizeMode::PER_CHANNEL)
{
dim3 block(CTA_SIZE);
dim3 grid((lda + CTA_SIZE - 1) / CTA_SIZE);
cudaMemsetAsync(quant_ptr, 0, lda * sizeof(T_S), stream);
sync_check_cuda_error();
computeFP8QuantizeScale<QuantizeMode::PER_CHANNEL><<<grid, block, 0, stream>>>(quant_ptr, input, numel, lda);
sync_check_cuda_error();
invokeQuantizeMatrix(output, quant_ptr, input, numel, lda, quantize_mode, stream);
}
else if (quantize_mode == QuantizeMode::PER_TENSOR)
{
dim3 block(1024);
dim3 grid(1024);
cudaMemsetAsync(quant_ptr, 0, sizeof(T_S), stream);
sync_check_cuda_error();
computeFP8QuantizeScale<QuantizeMode::PER_TENSOR><<<grid, block, 0, stream>>>(quant_ptr, input, numel, lda);
sync_check_cuda_error();
invokeQuantizeMatrix(output, quant_ptr, input, numel, lda, quantize_mode, stream);
}
sync_check_cuda_error();
}
#define DEFINE_INVOKE_QUANTIZE_MATRIX(type_out, type_scale, type_in) \
template void invokeQuantizeMatrix<type_out, type_scale, type_in>(type_out * output, \
type_scale const* input_scale, type_in const* input, int64_t numel, int64_t lda, QuantizeMode quantize_mode, \
cudaStream_t stream); \
template void invokeDequantizeMatrix<type_out, type_scale, type_in>(type_out * output, \
type_scale const* input_scale, type_in const* input, int64_t numel, int64_t lda, QuantizeMode quantize_mode, \
cudaStream_t stream); \
template void invokeComputeScalesAndQuantizeMatrix<type_out, type_scale, type_in>(type_out * output, \
type_scale * input_scale, type_in const* input, int64_t numel, int64_t lda, QuantizeMode quantize_mode, \
cudaStream_t stream);
#ifdef ENABLE_FP8
DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_fp8_e4m3, float, float);
DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_fp8_e4m3, float, half);
DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_fp8_e4m3, half, half);
DEFINE_INVOKE_QUANTIZE_MATRIX(half, half, __nv_fp8_e4m3);
DEFINE_INVOKE_QUANTIZE_MATRIX(float, float, __nv_fp8_e4m3);
DEFINE_INVOKE_QUANTIZE_MATRIX(half, float, __nv_fp8_e4m3);
#ifdef ENABLE_BF16
DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_fp8_e4m3, __nv_bfloat16, __nv_bfloat16);
DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_fp8_e4m3, float, __nv_bfloat16);
DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_bfloat16, __nv_bfloat16, __nv_fp8_e4m3);
DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_bfloat16, float, __nv_fp8_e4m3);
#endif
#endif
#endif // ENABLE_FP8
#ifdef ENABLE_FP8
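// Packed bfloat16 helpers for the 128-element block quantization below: absolute-max
// reduction, broadcast multiply by a scalar, and conversion to packed FP8.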
inline __device__ __nv_bfloat16 max_abs_op(bf16_4_t v) {
return cuda_max(cuda_max<__nv_bfloat16>(cuda_abs(v.x)), cuda_max<__nv_bfloat16>(cuda_abs(v.y)));
}
inline __device__ __nv_bfloat16 max_abs_op(bf16_8_t v) {
return cuda_max<__nv_bfloat16>(max_abs_op(bf16_4_t{v.x, v.y}), max_abs_op(bf16_4_t{v.z, v.w}));
}
inline __device__ __nv_bfloat162 mul(__nv_bfloat162 v, __nv_bfloat16 scale) {
return bf16hmul2(v, bf162bf162(scale));
}
inline __device__ bf16_4_t mul(bf16_4_t v, __nv_bfloat16 scale) {
bf16_4_t n;
n.x = mul(v.x, scale);
n.y = mul(v.y, scale);
return n;
}
inline __device__ bf16_8_t mul(bf16_8_t v, __nv_bfloat16 scale) {
bf16_8_t n;
n.x = mul(v.x, scale);
n.y = mul(v.y, scale);
n.z = mul(v.z, scale);
n.w = mul(v.w, scale);
return n;
}
inline __device__ void convert_to_fp8(fp8_4_t* v, const bf16_4_t u) {
v[0] = fp8_4_t(u.x, u.y);
}
inline __device__ void convert_to_fp8(fp8_8_t* v, const bf16_8_t u) {
v[0].x = fp8_2_t(u.x);
v[0].y = fp8_2_t(u.y);
v[0].z = fp8_2_t(u.z);
v[0].w = fp8_2_t(u.w);
}
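// Block-wise (128-element) FP8 quantization: each group of 128 / ELEM_PER_THREAD threads owns
// 128 contiguous bf16 values, reduces their absolute max via warp shuffles, multiplies them by
// FP8_E4M3_MAX / amax, and stores amax / FP8_E4M3_MAX as the per-block dequantization scale.
// With COL_MAJOR_SCALE the launcher passes dim1 as the number of 128-wide blocks per row and
// the scales are written in column-major order (quant_ptr[col * dim0 + row]).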
template<typename T_S, bool COL_MAJOR_SCALE, int ELEM_PER_THREAD>
__global__ void computeFP8Quantize128Kernel(__nv_fp8_e4m3* fp8_output,
T_S* quant_ptr,
const __nv_bfloat16* weights,
const int64_t dim0,
const int64_t dim1,
const int64_t size) {
const int64_t global_idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
using InputElem = typename packed_type<__nv_bfloat16, ELEM_PER_THREAD>::type;
using OutputElem = typename packed_type<__nv_fp8_e4m3, ELEM_PER_THREAD>::type;
auto weights_vec = reinterpret_cast<InputElem const*>(weights);
auto output_vec = reinterpret_cast<OutputElem *>(fp8_output);
if (global_idx * ELEM_PER_THREAD >= size) {
return;
}
auto w8 = weights_vec[global_idx];
float scale = cuda_max((float)1e-4, (float)max_abs_op(w8));
static constexpr int THREADS_PER_ROW = 128 / ELEM_PER_THREAD;
#pragma unroll
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
scale = max(scale, __shfl_xor_sync(0xFFFFFFFF, scale, mask, THREADS_PER_ROW));
}
scale = scale / FP8_E4M3_MAX;
w8 = mul(w8, (__nv_bfloat16)(1 / scale));
convert_to_fp8(output_vec + global_idx, w8);
if (threadIdx.x % THREADS_PER_ROW == 0) {
if constexpr (COL_MAJOR_SCALE) {
const int64_t now_idx = global_idx / THREADS_PER_ROW;
const int64_t row_idx = now_idx / dim1;
const int64_t col_idx = now_idx % dim1;
quant_ptr[col_idx * dim0 + row_idx] = scale;
} else {
quant_ptr[global_idx / THREADS_PER_ROW] = scale;
}
}
}
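// Host launcher for the 128-element block quantization; dim1 must be a multiple of 128.
// Each CTA of CTA_SIZE threads covers CTA_SIZE * ELEM_PER_THREAD input elements.
// Hypothetical usage sketch: for an [M, K] bf16 weight with K % 128 == 0,
//   invokeComputeFP8Quantize128(out, scales, w, M, K, M * K, /*col_major_scale=*/false, stream);
// would produce M * K / 128 scales, one per contiguous 128-element block.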
void invokeComputeFP8Quantize128(__nv_fp8_e4m3* fp8_output,
float* quant_ptr,
const __nv_bfloat16* weights,
const int64_t dim0,
const int64_t dim1,
const int64_t size,
bool col_major_scale,
cudaStream_t stream) {
RTP_LLM_CHECK(dim1 % 128 == 0);
static constexpr int ELEM_PER_THREAD = 8;
const int num_per_grid = CTA_SIZE / (128 / ELEM_PER_THREAD);
dim3 grid((size / 128 + num_per_grid - 1) / num_per_grid);
dim3 block(CTA_SIZE);
if (col_major_scale) {
computeFP8Quantize128Kernel<float, true, ELEM_PER_THREAD><<<grid, block, 0, stream>>>(fp8_output, quant_ptr, weights, dim0, dim1 / 128, size);
} else {
computeFP8Quantize128Kernel<float, false, ELEM_PER_THREAD><<<grid, block, 0, stream>>>(fp8_output, quant_ptr, weights, dim0, dim1, size);
}
}
#endif // ENABLE_FP8
} // namespace common
} // namespace tensorrt_llm