// maga_transformer/cpp/kernels/rmsnormKernels.cu
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "maga_transformer/cpp/cuda/cuda_type_utils.cuh"
#include "maga_transformer/cpp/cuda/reduce_kernel_utils.cuh"
#include "maga_transformer/cpp/kernels/rmsnormKernels.h"
namespace rtp_llm
{
#if USING_ROCM
using namespace rocm;
#endif
template <typename Tf, typename T, bool IS_BETA>
__inline__ __device__ Tf compute_rmsnorm(Tf val, float s_variance, const T* gamma, const T* beta, int i)
{
Tf ret = val * s_variance * cuda_cast<Tf>(gamma[i]);
if (IS_BETA)
{
ret = ret + cuda_cast<Tf>(beta[i]);
}
return ret;
}
/* Computes RMSNorm (https://pytorch.org/docs/stable/generated/torch.nn.rmsnorm.html):
 * normed_output <- ( input / sqrt(E[input²] + eps) ) * gamma + beta
 * input is [tokens, hidden_dim]; the mean of squares is computed per row (i.e. per token).
 *
 * One CTA handles one row. The accumulated row (bias/residual/input) is cached in dynamic
 * shared memory so the normalization pass can reuse it without re-reading global memory.
 *
 * Optional: with dynamic (per-token) scaling, the normalization pass doesn't write immediately
 * but tracks the amax per row. A final pass then quantizes to QUANT_OUT_T accordingly and
 * writes the result to normed_output_quant.
 */
template <typename T, bool IS_OUTPUT, bool IS_BIAS, bool RESIDUAL, bool IS_BETA, typename QUANT_OUT_T>
__global__ void generalRmsNorm(T* output, T* normed_output, const T* input, const T* bias, const T* residual1, const T* residual2,
const T* gamma, const T* beta, const float eps, int tokens, int hidden_dim,
const float* scale_orig_quant_per_tensor, float* scale_orig_quant_per_token, QUANT_OUT_T* normed_output_quant)
{
constexpr auto num_elems_T = num_elems<T>::value;
using quant_packed_t = typename packed_as<QUANT_OUT_T, num_elems_T>::type;
using Int32_Packed_T = typename packed_as<int32_t, num_elems<T>::value>::type;
using float_packed_t = typename packed_as<float, num_elems_T>::type;
using T_scalar = typename packed_as<T, 1>::type;
extern __shared__ __align__(sizeof(float)) char _shmem[];
T* shmem = reinterpret_cast<T*>(_shmem);
__shared__ float s_variance;
const int tidx = threadIdx.x;
const int bidx = blockIdx.x;
float variance = 0.0f;
float local_var_sum = 0.0f;
const int n_elems = hidden_dim / num_elems_T;
const bool with_per_token_scaling = scale_orig_quant_per_token != nullptr;
const bool with_per_tensor_scaling = scale_orig_quant_per_tensor != nullptr;
const bool has_residual2 = residual2 != nullptr;
const float_packed_t scale_orig_quant
= cuda_cast<float_packed_t>(with_per_tensor_scaling ? *scale_orig_quant_per_tensor : 0.0f);
T_scalar amax = getAmax<QUANT_OUT_T>();
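    // Pass 1: accumulate bias + residual(s) (+ input when IS_OUTPUT), cache the result in
    // shared memory for the normalization pass, and accumulate the per-row sum of squares.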
for (int i = tidx; i < n_elems; i += blockDim.x)
{
const int index = bidx * n_elems + i;
        T val = cuda_cast<T>(0.0f); // accumulator for bias + residual(s) + input
if (IS_BIAS)
{
val = add(val, ldg(&bias[i]));
}
if (RESIDUAL)
{
val = add(val, ldg(&residual1[index]));
if (has_residual2) {
val = add(val, ldg(&residual2[index]));
}
}
if (IS_OUTPUT)
{
T in_val = input[index];
val = add(val, in_val);
}
shmem[i] = val;
if (IS_OUTPUT)
{
output[index] = val;
}
const float_packed_t val_f = cuda_cast<float_packed_t>(val);
local_var_sum += cuda_sum<float>(val_f * val_f);
}
float packed[1] = {local_var_sum};
blockReduceSumV2<float, 1>(packed);
variance = packed[0];
if (threadIdx.x == 0)
{
        variance = variance / hidden_dim; // mean of squares E[x²] (RMSNorm does not subtract the mean)
s_variance = rsqrtf(variance + eps);
}
__syncthreads();
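    // Pass 2: normalize the cached values. Depending on the quantization mode, write
    // normed_output directly, quantize with the per-tensor scale, or keep the normalized value
    // in shared memory while tracking the row-wise absolute maximum for per-token scaling.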
const float scale_factor = getScaleFactor<QUANT_OUT_T>();
for (int i = tidx; i < n_elems; i += blockDim.x)
{
const int index = bidx * n_elems + i;
const float_packed_t val_f = cuda_cast<float_packed_t>(shmem[i]);
const T val = cuda_cast<T>(compute_rmsnorm<float_packed_t, T, IS_BETA>(val_f, s_variance, gamma, beta, i));
if (with_per_token_scaling)
{
amax = cuda_max(cuda_max<T_scalar, T>(cuda_abs(val)), amax);
shmem[i] = val;
}
else if (with_per_tensor_scaling)
{
reinterpret_cast<quant_packed_t*>(normed_output_quant)[index]
= cuda_cast<quant_packed_t>(cuda_cast<float_packed_t>(val) * scale_orig_quant);
}
else
{
normed_output[index] = val;
}
}
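    // Pass 3 (dynamic per-token quantization only): reduce the row amax across the block,
    // quantize the cached normalized values, and record the per-token dequantization scale.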
if (with_per_token_scaling)
{
float abs_max_f = blockAllReduceMax(cuda_cast<float>(amax));
const float dynamic_per_token_scale = scale_factor / abs_max_f;
for (int i = tidx; i < n_elems; i += blockDim.x)
{
const int index = bidx * n_elems + i;
float_packed_t val_f = cuda_cast<float_packed_t>(shmem[i]);
reinterpret_cast<quant_packed_t*>(normed_output_quant)[index]
= cuda_cast<quant_packed_t>(val_f * cuda_cast<float_packed_t>(dynamic_per_token_scale));
}
if (tidx == 0)
{
scale_orig_quant_per_token[bidx] = abs_max_f / scale_factor;
}
}
}
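/* RMS-norms each contiguous group of `norm_size` elements within a strided row.
 * A row contains n normalized elements (n / norm_size groups) but is laid out with pitches
 * in_stride / out_stride. One warp handles one (sample, group) pair, e.g. one attention head
 * of one token.
 */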
template<typename T, bool IS_BIAS>
__global__ void rmsNormWithStride(T* __restrict output,
const int out_stride,
const T* __restrict input,
const int in_stride,
const T* __restrict gamma,
const T* __restrict bias,
const float eps,
const int n,
const int norm_size) {
constexpr auto num_elems_T = num_elems<T>::value;
using float_packed_t = typename packed_as<float, num_elems_T>::type;
constexpr int vec_size = num_elems<T>::value;
constexpr int warp_size = 32;
const int elements_per_thread = norm_size / (warp_size * vec_size);
const int sample_idx = blockIdx.x / (n / norm_size);
const int group_idx = blockIdx.x % (n / norm_size);
const T* group_start = input + sample_idx * (in_stride / vec_size) + group_idx * (norm_size / vec_size);
T* dest_start = output + sample_idx * (out_stride / vec_size) + group_idx * (norm_size / vec_size);
__shared__ float smem_scale;
float square_sum = 0.0f;
for (int i = 0; i < elements_per_thread; ++i) {
const int elem_idx = i * warp_size + threadIdx.x;
T packed_val = group_start[elem_idx];
auto val = cuda_cast<float_packed_t>(packed_val);
square_sum += cuda_sum<float>(val * val);
}
float variance = warpReduceSum(square_sum) / norm_size;
if (threadIdx.x == 0) {
smem_scale = rsqrtf(variance + eps);
}
__syncthreads();
for (int i = 0; i < elements_per_thread; ++i) {
const int elem_idx = i * warp_size + threadIdx.x;
T packed_val = group_start[elem_idx];
const float_packed_t val_f = cuda_cast<float_packed_t>(packed_val);
        const T val = cuda_cast<T>(compute_rmsnorm<float_packed_t, T, IS_BIAS>(val_f, smem_scale, gamma, bias, elem_idx));
        dest_start[elem_idx] = val;
}
}
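/* The dispatch_* helpers below convert runtime properties (beta / bias / residual being null,
 * plus the is_output flag) into compile-time template parameters, so generalRmsNorm is
 * instantiated without per-element branches on those conditions.
 */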
template <typename T, bool IS_OUTPUT, bool IS_BIAS, bool RESIDUAL, bool IS_BETA, typename QUANT_OUT_T=int8_t>
void dispatch_rmsnorm_type_square_method(T* output, T* normed_output, const T* input, const T* bias, const T* residual1, const T* residual2,
const T* gamma, const T* beta, const float eps, int tokens, int hidden_dim,
const float* scale_orig_quant_per_tensor, float* scale_orig_quant_per_token, QUANT_OUT_T* normed_output_quant,
const dim3 grid, const dim3 block, const size_t shmem_size, cudaStream_t stream)
{
    if (shmem_size >= (48 << 10))
    {
#if USING_CUDA
        // Dynamic shared memory beyond 48KB must be opted into explicitly. If this call fails,
        // the subsequent launch of generalRmsNorm will report the error.
        cudaFuncSetAttribute(generalRmsNorm<T, IS_OUTPUT, IS_BIAS, RESIDUAL, IS_BETA, QUANT_OUT_T>,
            cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size);
#endif
    }
generalRmsNorm<T, IS_OUTPUT, IS_BIAS, RESIDUAL, IS_BETA, QUANT_OUT_T><<<grid, block, shmem_size, stream>>>(output, normed_output,
input, bias, residual1, residual2, gamma, beta, eps, tokens, hidden_dim, scale_orig_quant_per_tensor,
scale_orig_quant_per_token, normed_output_quant);
}
template <typename T, bool IS_OUTPUT, bool IS_BIAS, bool RESIDUAL, typename QUANT_OUT_T=int8_t>
void dispatch_rmsnorm_beta(T* output, T* normed_output, const T* input, const T* bias, const T* residual1, const T* residual2,
const T* gamma, const T* beta, const float eps, int tokens, int hidden_dim,
const float* scale_orig_quant_per_tensor, float* scale_orig_quant_per_token, QUANT_OUT_T* normed_output_quant,
const dim3 grid, const dim3 block, const size_t shmem_size, cudaStream_t stream)
{
if (beta != nullptr)
{
dispatch_rmsnorm_type_square_method<T, IS_OUTPUT, IS_BIAS, RESIDUAL, true, QUANT_OUT_T>(output, normed_output, input, bias,
residual1, residual2, gamma, beta, eps, tokens, hidden_dim, scale_orig_quant_per_tensor, scale_orig_quant_per_token,
normed_output_quant, grid, block, shmem_size, stream);
}
else
{
dispatch_rmsnorm_type_square_method<T, IS_OUTPUT, IS_BIAS, RESIDUAL, false, QUANT_OUT_T>(output, normed_output, input, bias,
residual1, residual2, gamma, beta, eps, tokens, hidden_dim, scale_orig_quant_per_tensor, scale_orig_quant_per_token,
normed_output_quant, grid, block, shmem_size, stream);
}
}
template <typename T, bool IS_OUTPUT, bool IS_BIAS, typename QUANT_OUT_T>
void dispatch_rmsnorm_residual(T* output, T* normed_output, const T* input, const T* bias, const T* residual1, const T* residual2,
const T* gamma, const T* beta, const float eps, int tokens, int hidden_dim,
const float* scale_orig_quant_per_tensor, float* scale_orig_quant_per_token, QUANT_OUT_T* normed_output_quant,
const dim3 grid, const dim3 block, const size_t shmem_size, cudaStream_t stream)
{
if (residual1 != nullptr)
{
dispatch_rmsnorm_beta<T, IS_OUTPUT, IS_BIAS, true, QUANT_OUT_T>(output, normed_output, input, bias, residual1, residual2, gamma, beta,
eps, tokens, hidden_dim, scale_orig_quant_per_tensor, scale_orig_quant_per_token, normed_output_quant, grid,
block, shmem_size, stream);
}
else
{
dispatch_rmsnorm_beta<T, IS_OUTPUT, IS_BIAS, false, QUANT_OUT_T>(output, normed_output, input, bias, residual1, residual2, gamma, beta,
eps, tokens, hidden_dim, scale_orig_quant_per_tensor, scale_orig_quant_per_token, normed_output_quant, grid,
block, shmem_size, stream);
}
}
template <typename T, bool IS_OUTPUT, typename QUANT_OUT_T>
void dispatch_rmsnorm_bias(T* output, T* normed_output, const T* input, const T* bias, const T* residual1, const T* residual2,
const T* gamma, const T* beta, const float eps, int tokens, int hidden_dim,
const float* scale_orig_quant_per_tensor, float* scale_orig_quant_per_token, QUANT_OUT_T* normed_output_quant,
const dim3 grid, const dim3 block, const size_t shmem_size, cudaStream_t stream)
{
if (bias != nullptr)
{
dispatch_rmsnorm_residual<T, IS_OUTPUT, true, QUANT_OUT_T>(output, normed_output, input, bias, residual1, residual2, gamma, beta, eps,
tokens, hidden_dim, scale_orig_quant_per_tensor, scale_orig_quant_per_token, normed_output_quant, grid,
block, shmem_size, stream);
}
else
{
dispatch_rmsnorm_residual<T, IS_OUTPUT, false, QUANT_OUT_T>(output, normed_output, input, bias, residual1, residual2, gamma, beta, eps,
tokens, hidden_dim, scale_orig_quant_per_tensor, scale_orig_quant_per_token, normed_output_quant, grid,
block, shmem_size, stream);
}
}
template <typename T, typename QUANT_OUT_T>
void dispatch_rmsnorm_output(T* output, T* normed_output, const T* input, const T* bias, const T* residual1, const T* residual2,
const T* gamma, const T* beta, const float eps, int tokens, int hidden_dim,
const float* scale_orig_quant_per_tensor, float* scale_orig_quant_per_token, QUANT_OUT_T* normed_output_quant,
const dim3 grid, const dim3 block, const size_t shmem_size, cudaStream_t stream, bool is_output)
{
if (is_output)
{
dispatch_rmsnorm_bias<T, true, QUANT_OUT_T>(output, normed_output, input, bias, residual1, residual2, gamma, beta, eps, tokens,
hidden_dim, scale_orig_quant_per_tensor, scale_orig_quant_per_token, normed_output_quant, grid, block,
shmem_size, stream);
}
else
{
dispatch_rmsnorm_bias<T, false, QUANT_OUT_T>(output, normed_output, input, bias, residual1, residual2, gamma, beta, eps, tokens,
hidden_dim, scale_orig_quant_per_tensor, scale_orig_quant_per_token, normed_output_quant, grid, block,
shmem_size, stream);
}
}
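/* Entry point: plain RMSNorm of `input` into `out` (no bias and no residual add), with
 * optional per-tensor or dynamic per-token quantization into normed_output_quant.
 */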
template <typename T, typename QUANT_OUT_T>
void invokeGeneralRmsNorm(T* out, const T* input, const T* gamma, const T* beta, const float eps, const int tokens,
const int hidden_dim, cudaStream_t stream, const float* scale, float* dynamic_scale, QUANT_OUT_T* normed_output_quant)
{
dim3 grid(tokens);
dim3 block(min(hidden_dim, 1024));
// Make sure block.x is multiple of 32 for warp shuffle to work
block.x = 32 * ((block.x + 31) / 32);
constexpr size_t vec_size = 2;
const size_t shmem_size = hidden_dim * sizeof(T);
const bool use_vec_type = (hidden_dim % vec_size == 0)
&& (std::is_same<T, half>::value
#ifdef ENABLE_BF16
|| std::is_same<T, __nv_bfloat16>::value
#endif
);
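    // This entry point has no residual, so the real input is passed through the residual1 slot
    // with is_output = false; in that configuration the kernel never reads its `input` argument,
    // which is why `out` is used as a placeholder below.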
if (use_vec_type)
{
using Tp = typename packed_as<T, vec_size>::type;
dispatch_rmsnorm_output(reinterpret_cast<Tp*>(out), reinterpret_cast<Tp*>(out), reinterpret_cast<Tp*>(out),
(const Tp*) nullptr, reinterpret_cast<const Tp*>(input), (const Tp*) nullptr, reinterpret_cast<const Tp*>(gamma),
reinterpret_cast<const Tp*>(beta), eps, tokens, hidden_dim, scale, dynamic_scale, normed_output_quant, grid,
block, shmem_size, stream, false);
}
else
{
dispatch_rmsnorm_output(out, out, (const T*) out, (const T*) nullptr, input, (const T*) nullptr, gamma, beta, eps, tokens,
hidden_dim, scale, dynamic_scale, normed_output_quant, grid, block, shmem_size, stream, false);
}
}
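// Illustrative usage (a sketch with hypothetical device buffers, not part of this file's API):
//   // d_out, d_in: [tokens, hidden_dim] half buffers on the device; d_gamma: [hidden_dim]
//   invokeGeneralRmsNorm<half, int8_t>(d_out, d_in, d_gamma, /*beta=*/nullptr, /*eps=*/1e-6f,
//                                      tokens, hidden_dim, stream, /*scale=*/nullptr,
//                                      /*dynamic_scale=*/nullptr, /*normed_output_quant=*/nullptr);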
template<typename T>
void invokeRmsNormWithStride(T* __restrict output,
const int out_stride,
const T* __restrict input,
const int in_stride,
const T* __restrict gamma,
const T* __restrict beta,
const float layernorm_eps,
const int m,
const int n,
const int norm_size,
cudaStream_t stream)
{
constexpr size_t vec_size = 2;
constexpr size_t warp_size = 32;
    // Parameter validation
if (n % norm_size != 0) {
throw std::invalid_argument("n must be divisible by norm_size");
}
if (norm_size % (warp_size * vec_size) != 0) {
throw std::invalid_argument("norm_size must be multiple of " +
std::to_string(warp_size * vec_size));
}
const int num_heads = n / norm_size;
    dim3 grid(m * num_heads); // one block (a single warp) per (sample, head) pair
dim3 block(warp_size);
using Tp = typename packed_as<T, vec_size>::type;
bool is_bias = beta != nullptr;
if (is_bias) {
rmsNormWithStride<Tp, true><<<grid, block, 0, stream>>>(reinterpret_cast<Tp*>(output),
out_stride,
reinterpret_cast<const Tp*>(input),
in_stride,
reinterpret_cast<const Tp*>(gamma),
reinterpret_cast<const Tp*>(beta),
layernorm_eps,
n,
norm_size);
} else {
rmsNormWithStride<Tp, false><<<grid, block, 0, stream>>>(reinterpret_cast<Tp*>(output),
out_stride,
reinterpret_cast<const Tp*>(input),
in_stride,
reinterpret_cast<const Tp*>(gamma),
reinterpret_cast<const Tp*>(beta),
layernorm_eps,
n,
norm_size);
}
}
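// Illustrative usage (a sketch with hypothetical names/shapes): normalize each head of a packed
// projection independently, where a row holds num_heads * head_dim elements to normalize and
// row_stride is the full row pitch in elements (head_dim must be a multiple of 64 here):
//   invokeRmsNormWithStride<half>(d_q_out, /*out_stride=*/row_stride, d_q_in, /*in_stride=*/row_stride,
//                                 d_gamma, /*beta=*/nullptr, /*eps=*/1e-6f, /*m=*/tokens,
//                                 /*n=*/num_heads * head_dim, /*norm_size=*/head_dim, stream);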
template <typename T, typename QUANT_OUT_T>
void invokeAddBiasResidualRmsNorm(T* output, T* normed_output, const T* input, const T* bias, const T* residual, const T* residual2,
const T* gamma, const T* beta, const float eps, const int tokens, const int hidden_dim, cudaStream_t stream,
const float* scale, float* dynamic_scale, QUANT_OUT_T* normed_output_quant)
{
dim3 grid(tokens);
dim3 block(min(hidden_dim, 1024));
// Make sure block.x is multiple of 32 for warp shuffle to work
block.x = 32 * ((block.x + 31) / 32);
constexpr size_t vec_size = 2;
const size_t shmem_size = hidden_dim * sizeof(T);
const bool use_vec_type = (hidden_dim % vec_size == 0)
&& (std::is_same<T, half>::value
#ifdef ENABLE_BF16
|| std::is_same<T, __nv_bfloat16>::value
#endif
);
if (use_vec_type)
{
using Tp = typename packed_as<T, vec_size>::type;
dispatch_rmsnorm_output(reinterpret_cast<Tp*>(output), reinterpret_cast<Tp*>(normed_output),
reinterpret_cast<const Tp*>(input), reinterpret_cast<const Tp*>(bias),
reinterpret_cast<const Tp*>(residual), reinterpret_cast<const Tp*>(residual2),
reinterpret_cast<const Tp*>(gamma), reinterpret_cast<const Tp*>(beta),
eps, tokens, hidden_dim, scale, dynamic_scale, normed_output_quant, grid,
block, shmem_size, stream, true);
}
else
{
dispatch_rmsnorm_output(output, normed_output, input, bias, residual, residual2, gamma, beta, eps, tokens, hidden_dim,
scale, dynamic_scale, normed_output_quant, grid, block, shmem_size, stream, true);
}
}
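// Illustrative usage (a sketch with hypothetical buffers): fused residual add + RMSNorm, writing
// the pre-norm sum to d_added and the normalized result to d_normed:
//   invokeAddBiasResidualRmsNorm<half, int8_t>(d_added, d_normed, d_hidden, /*bias=*/nullptr,
//       d_residual, /*residual2=*/nullptr, d_gamma, /*beta=*/nullptr, /*eps=*/1e-6f, tokens,
//       hidden_dim, stream, /*scale=*/nullptr, /*dynamic_scale=*/nullptr,
//       /*normed_output_quant=*/nullptr);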
#define INSTANTIATE_GENERAL_RMSNORM(T, QUANT_OUT_T) \
template void invokeGeneralRmsNorm(T* out, const T* input, const T* gamma, const T* beta, const float eps, \
const int tokens, const int hidden_dim, cudaStream_t stream, const float* scale, float* dynamic_scale, \
QUANT_OUT_T* normed_output_quant);
INSTANTIATE_GENERAL_RMSNORM(float, int8_t);
INSTANTIATE_GENERAL_RMSNORM(half, int8_t);
#ifdef ENABLE_BF16
INSTANTIATE_GENERAL_RMSNORM(__nv_bfloat16, int8_t);
#endif
#ifdef ENABLE_FP8
INSTANTIATE_GENERAL_RMSNORM(float, __nv_fp8_e4m3);
INSTANTIATE_GENERAL_RMSNORM(half, __nv_fp8_e4m3);
#ifdef ENABLE_BF16
INSTANTIATE_GENERAL_RMSNORM(__nv_bfloat16, __nv_fp8_e4m3);
#endif // ENABLE_BF16
#endif // ENABLE_FP8
#define INSTANTIATE_ADD_BIAS_RESL_RMSNORM(T, QUANT_OUT_T) \
template void invokeAddBiasResidualRmsNorm(T* output, T* normed_output, const T* input, const T* bias, \
    const T* residual, const T* residual2, const T* gamma, const T* beta, const float eps, const int tokens, const int hidden_dim, \
cudaStream_t stream, const float* scale, float* dynamic_scale, QUANT_OUT_T* normed_output_quant);
INSTANTIATE_ADD_BIAS_RESL_RMSNORM(float, int8_t);
INSTANTIATE_ADD_BIAS_RESL_RMSNORM(half, int8_t);
#ifdef ENABLE_BF16
INSTANTIATE_ADD_BIAS_RESL_RMSNORM(__nv_bfloat16, int8_t);
#endif
#ifdef ENABLE_FP8
INSTANTIATE_ADD_BIAS_RESL_RMSNORM(float, __nv_fp8_e4m3);
INSTANTIATE_ADD_BIAS_RESL_RMSNORM(half, __nv_fp8_e4m3);
#ifdef ENABLE_BF16
INSTANTIATE_ADD_BIAS_RESL_RMSNORM(__nv_bfloat16, __nv_fp8_e4m3);
#endif // ENABLE_BF16
#endif // ENABLE_FP8
#define INSTANTIATE_STRIDED_RMSNORM(T) \
template void invokeRmsNormWithStride(T* __restrict output, \
const int out_stride, \
const T* __restrict input, \
const int in_stride, \
const T* __restrict gamma, \
const T* __restrict beta, \
const float layernorm_eps, \
const int m, \
const int n, \
const int norm_size, \
cudaStream_t stream);
INSTANTIATE_STRIDED_RMSNORM(float);
INSTANTIATE_STRIDED_RMSNORM(half);
#ifdef ENABLE_BF16
INSTANTIATE_STRIDED_RMSNORM(__nv_bfloat16);
#endif
#undef INSTANTIATE_STRIDED_RMSNORM
} // namespace rtp_llm