// maga_transformer/cpp/kernels/rocm/layernorm_kernels.cu
#include "maga_transformer/cpp/cuda/cuda_type_utils.cuh"
#include "maga_transformer/cpp/kernels/rocm/layernorm_kernels.h"
#include "maga_transformer/cpp/cuda/reduce_kernel_utils.cuh"
#include "maga_transformer/cpp/rocm/hip_utils.h"
namespace rtp_llm {
using namespace rocm;
__device__ __forceinline__ int64_t loadOffset(int head_num, int size_per_head) {
// [[q_head_1],[q_head_2]...[k_head_1],[k_head_2]...[v_head_1],[v_head_2]...]
int head_id = blockIdx.y;
int batch_id = blockIdx.x;
int offset = batch_id * head_num * size_per_head + size_per_head * head_id;
return offset;
}
__device__ __forceinline__ int64_t loadOffsetStrided(const int stride, const int n_elems) {
return blockIdx.x * stride / n_elems;
}
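/*
 * qkLayerNorm applies per-head layernorm (gamma only, no beta) in place on the
 * fused QKV activation, laid out per token as
 * [q_head_0 .. q_head_{Hq-1}, k_head_0 .. k_head_{Hkv-1}, v_head_0 .. v_head_{Hkv-1}].
 * One 32-thread warp handles one (token, head) pair in three passes: mean,
 * variance, then normalize and scale by that head's gamma slice.
 * loadOffset above returns the packed-element offset of the head: for a
 * hypothetical head_num = 12 (total heads in the fused row) and
 * size_per_head = 4 packed elements, token 2 / head 3 starts at
 * 2 * 12 * 4 + 3 * 4 = 108. loadOffsetStrided instead derives the row start of
 * blockIdx.x from a stride given in scalars, divided by the packing factor.
 * As launched by invokeQkLayerNorm below, gridDim.y covers only the Q and K
 * heads, so V heads are left untouched.
 */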
template<typename T>
__global__ void
qkLayerNorm(T* __restrict qkv, const T* __restrict gamma, const float layernorm_eps, int head_num, int size_per_head) {
constexpr auto num_elems_T = num_elems<T>::value;
constexpr size_t warp_size = 32;
const int vec_size_per_head = size_per_head / num_elems_T;
const int n_elems = vec_size_per_head / warp_size;
using float_packed_t = typename packed_as<float, num_elems_T>::type;
const int tid = threadIdx.x;
__shared__ float s_mean;
__shared__ float s_variance;
float mean = 0.0f;
float variance = 0.0f;
float local_sum = 0.0f;
for (int i = 0; i < n_elems; i++) {
auto index = loadOffset(head_num, vec_size_per_head) + tid * n_elems + i;
auto val_f = cuda_cast<float_packed_t>(ldg(&qkv[index]));
local_sum += cuda_sum<float>(val_f);
}
mean = warpReduceSum(local_sum);
if (threadIdx.x == 0) {
s_mean = mean / size_per_head;
}
__syncthreads();
float local_var_sum = 0.0f;
for (int i = 0; i < n_elems; i++) {
auto index = loadOffset(head_num, vec_size_per_head) + tid * n_elems + i;
auto val_f = cuda_cast<float_packed_t>(ldg(&qkv[index]));
auto diff = val_f - s_mean;
local_var_sum += cuda_sum<float>(diff * diff);
}
variance = warpReduceSum(local_var_sum);
if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / size_per_head + layernorm_eps);
}
__syncthreads();
for (int i = 0; i < n_elems; i++) {
auto index = loadOffset(head_num, vec_size_per_head) + tid * n_elems + i;
auto gamma_index = blockIdx.y * vec_size_per_head + tid * n_elems + i;
auto val_f = cuda_cast<float_packed_t>(ldg(&qkv[index]));
auto val_gamma = cuda_cast<float_packed_t>(gamma[gamma_index]);
qkv[index] = cuda_cast<T>((val_f - s_mean) * s_variance * val_gamma);
}
}
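/*
 * layerNormWithStride normalizes each norm_size-wide window of a row whose
 * rows are `stride` scalars apart, using one 32-thread warp per
 * (sample, window); blockIdx.x enumerates sample * (n / norm_size) + window.
 * gamma and beta are indexed only by the position inside the window, i.e. they
 * hold norm_size scalars shared by every window. The launch-side checks
 * (n divisible by norm_size, norm_size a multiple of warp_size * vec_size)
 * live in invokeLayerNormWithStride below.
 */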
template<typename T, bool IS_BIAS>
__global__ void layerNormWithStride(T* __restrict data,
const T* __restrict gamma,
const T* __restrict beta,
const float layernorm_eps,
const int n,         // total normalized feature dimension per row
const int norm_size, // size of each normalization window (per head)
const int stride)
{
constexpr auto num_elems_T = num_elems<T>::value; // scalar elements packed into one vectorized T
constexpr size_t warp_size = 32;
const int n_elems = norm_size / num_elems_T / warp_size;
using float_packed_t = typename packed_as<float, num_elems_T>::type;
const int tid = threadIdx.x;
const int sample_idx = blockIdx.x / (n / norm_size); // sample (row) index
const int head_idx = blockIdx.x % (n / norm_size); // head / window index within the row
__shared__ float s_mean;
__shared__ float s_variance;
// compute the start of the current normalization window
T* sample_start = data + sample_idx * (stride / num_elems_T);
T* head_start = sample_start + head_idx * (norm_size / num_elems_T);
// Stage 1: compute the mean
float local_sum = 0.0f;
#pragma unroll
for (int i = 0; i < n_elems; i++) {
int elem_idx = i * warp_size + tid;
auto val_f = cuda_cast<float_packed_t>(ldg(&head_start[elem_idx]));
local_sum += cuda_sum<float>(val_f);
}
float mean = warpReduceSum(local_sum);
if (tid == 0) {
s_mean = mean / norm_size;
}
__syncthreads();
float local_var_sum = 0.0f;
#pragma unroll
for (int i = 0; i < n_elems; i++) {
int elem_idx = i * warp_size + tid;
auto val_f = cuda_cast<float_packed_t>(ldg(&head_start[elem_idx]));
auto diff = val_f - s_mean;
local_var_sum += cuda_sum<float>(diff * diff);
}
float variance = warpReduceSum(local_var_sum);
if (tid == 0) {
s_variance = rsqrtf(variance / norm_size + layernorm_eps);
}
__syncthreads();
#pragma unroll
for (int i = 0; i < n_elems; i++) {
int elem_idx = i * warp_size + tid;
auto val_f = cuda_cast<float_packed_t>(ldg(&head_start[elem_idx]));
auto gamma_val = cuda_cast<float_packed_t>(gamma[elem_idx]);
if (IS_BIAS) {
auto beta_val = cuda_cast<float_packed_t>(beta[elem_idx]);
val_f = (val_f - s_mean) * s_variance * gamma_val + beta_val;
} else {
val_f = (val_f - s_mean) * s_variance * gamma_val;
}
head_start[elem_idx] = cuda_cast<T>(val_f);
}
}
template<typename T>
void invokeQkLayerNorm(T* __restrict qkv,
const T* __restrict gamma,
const float layernorm_eps,
const int tokens,
const int head_num,
const int head_num_kv,
const int size_per_head,
cudaStream_t stream) {
constexpr size_t vec_size = 2;
constexpr size_t warp_size = 32;
// each 32-thread warp covers size_per_head / vec_size packed values, so
// size_per_head must be a multiple of warp_size * vec_size
if (size_per_head % (warp_size * vec_size) != 0) {
throw std::invalid_argument("not supported size_per_head: " + std::to_string(size_per_head));
}
dim3 grid(tokens, head_num + head_num_kv);
dim3 block(warp_size);
int total_head_num = head_num + 2 * head_num_kv;
using Tp = typename packed_as<T, vec_size>::type;
qkLayerNorm<Tp><<<grid, block, 0, stream>>>(
reinterpret_cast<Tp*>(qkv), reinterpret_cast<const Tp*>(gamma), layernorm_eps, total_head_num, size_per_head);
}
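// Illustrative call (values are hypothetical): for a fused QKV activation of
// shape [tokens, (head_num + 2 * head_num_kv) * size_per_head],
//     invokeQkLayerNorm(qkv, gamma, 1e-6f, tokens,
//                       32 /* head_num */, 8 /* head_num_kv */,
//                       128 /* size_per_head */, stream);
// launches grid(tokens, 40) with one warp per (token, Q/K head). The gamma
// indexing in qkLayerNorm implies gamma holds
// (head_num + head_num_kv) * size_per_head scalars, one per normalized channel.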
template<typename T>
void invokeLayerNormWithStride(T* __restrict data,
const T* __restrict gamma,
const T* __restrict beta,
const float layernorm_eps,
const int m,
const int n,
const int norm_size,
const int stride,
const int offset,
cudaStream_t stream) {
constexpr size_t vec_size = 2;
constexpr size_t warp_size = 32;
data = data + offset;
// validate parameters
if (n % norm_size != 0) {
throw std::invalid_argument("n must be divisible by norm_size");
}
if (norm_size % (warp_size * vec_size) != 0) {
throw std::invalid_argument("norm_size must be multiple of " +
std::to_string(warp_size * vec_size));
}
const int num_heads = n / norm_size;
dim3 grid(m * num_heads); // one block per (sample, head) window
dim3 block(warp_size);
using Tp = typename packed_as<T, vec_size>::type;
bool is_bias = beta != nullptr;
if (is_bias) {
layerNormWithStride<Tp, true><<<grid, block, 0, stream>>>(reinterpret_cast<Tp*>(data), reinterpret_cast<const Tp*>(gamma),
reinterpret_cast<const Tp*>(beta),
layernorm_eps, n, norm_size, stride);
}
else {
layerNormWithStride<Tp, false><<<grid, block, 0, stream>>>(reinterpret_cast<Tp*>(data), reinterpret_cast<const Tp*>(gamma),
reinterpret_cast<const Tp*>(beta),
layernorm_eps, n, norm_size, stride);
}
}
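// Illustrative call (values are hypothetical): to normalize only the Q heads
// of a fused QKV buffer whose rows are (32 + 2 * 8) * 128 scalars wide,
//     invokeLayerNormWithStride(qkv, gamma, /* beta */ nullptr, 1e-6f,
//                               tokens /* m */, 32 * 128 /* n */,
//                               128 /* norm_size */,
//                               (32 + 2 * 8) * 128 /* stride */,
//                               0 /* offset */, stream);
// Each of the m * (n / norm_size) blocks normalizes one 128-wide window;
// a nonzero offset would instead select, e.g., the K portion of the row.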
#define INSTANTIATE_QK_LAYERNORM(T) \
template void invokeQkLayerNorm(T* __restrict qkv, \
const T* __restrict gamma, \
const float layernorm_eps, \
const int tokens, \
const int head_num, \
const int head_num_kv, \
const int size_per_head, \
cudaStream_t stream)
INSTANTIATE_QK_LAYERNORM(float);
INSTANTIATE_QK_LAYERNORM(half);
#ifdef ENABLE_BF16
INSTANTIATE_QK_LAYERNORM(__nv_bfloat16);
#endif
#undef INSTANTIATE_QK_LAYERNORM
#define INSTANTIATE_STRIDED_LAYERNORM(T) \
template void invokeLayerNormWithStride(T* __restrict data, \
const T* __restrict gamma, \
const T* __restrict beta, \
const float layernorm_eps, \
const int m, \
const int n, \
const int norm_size, \
const int stride, \
const int offset, \
cudaStream_t stream);
INSTANTIATE_STRIDED_LAYERNORM(float);
INSTANTIATE_STRIDED_LAYERNORM(half);
#ifdef ENABLE_BF16
INSTANTIATE_STRIDED_LAYERNORM(__nv_bfloat16);
#endif
#undef INSTANTIATE_STRIDED_LAYERNORM
template<typename Tf, typename T, bool IS_BETA>
__inline__ __device__ Tf
compute_layernorm(Tf val, float s_mean, float s_variance, const T* gamma, const T* beta, int i) {
Tf ret = (val - s_mean) * s_variance * cuda_cast<Tf>(gamma[i]);
if (IS_BETA) {
ret = ret + cuda_cast<Tf>(beta[i]);
}
return ret;
}
/* Computes the layernorm https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html
* normed_output <- ( (input - E[input]) / Sqrt(Var[input] + eps) ) * gamma + beta
* input is [tokens, hidden_dim]. Mean and Variance are per-row (i.e. per-token)
*
* One CTA handles one row.
*
* with USE_DIFF_OF_SQUARES set to false:
* First pass (loop) computes the mean.
* Second computes the variance via Var[x] = E[(x - E[x])²].
* Third pass computes and writes normed_output
*
* with USE_DIFF_OF_SQUARES set to true (may be faster but less accurate):
* First pass (loop) computes the mean and variance via Var[x] = E[x²] - E[x]²
* Second pass computes and writes normed_output
*
* use_shmem controls if we cache input values into shared memory
*
* Optional: with dynamic scaling, the last pass doesn't write immediately but finds the
* amax per row. A final pass scales to int8 accordingly, and writes output to
* normed_output_quant.
*/
template<typename T,
bool IS_OUTPUT,
bool IS_BIAS,
bool RESIDUAL,
bool IS_BETA,
bool RETURN_NORMED_OUTPUT,
bool USE_DIFF_OF_SQUARES = false>
__global__ void generalLayerNorm(T* output,
T* normed_output,
const T* input,
const T* bias,
const T* residual,
const T* gamma,
const T* beta,
const float eps,
int tokens,
int hidden_dim,
const float* scale_orig_quant_per_tensor,
float* scale_orig_quant_per_token,
int8_t* normed_output_quant) {
constexpr auto num_elems_T = num_elems<T>::value;
using int8_packed_t = typename packed_as<int8_t, num_elems_T>::type;
using Int32_Packed_T = typename packed_as<int32_t, num_elems<T>::value>::type;
using float_packed_t = typename packed_as<float, num_elems_T>::type;
using T_scalar = typename packed_as<T, 1>::type;
extern __shared__ __align__(sizeof(float)) char _shmem[];
T* shmem = reinterpret_cast<T*>(_shmem);
__shared__ float s_mean;
__shared__ float s_variance;
const int tidx = threadIdx.x;
const int bidx = blockIdx.x;
float mean = 0.0f;
float variance = 0.0f;
float local_sum = 0.0f;
float local_var_sum = 0.0f;
const bool with_per_token_scaling = scale_orig_quant_per_token != nullptr;
const bool with_per_tensor_scaling = scale_orig_quant_per_tensor != nullptr;
const float_packed_t scale_orig_quant =
cuda_cast<float_packed_t>(with_per_tensor_scaling ? *scale_orig_quant_per_tensor : 0.0f);
T_scalar amax(1e-6f);
const int n_elems = hidden_dim / num_elems_T;
for (int i = tidx; i < n_elems; i += blockDim.x) {
const int index = bidx * n_elems + i;
T val = input[index];
if (IS_BIAS) {
val = add(val, ldg(&bias[i]));
}
if (RESIDUAL) {
val = add(val, ldg(&residual[index]));
}
if (IS_OUTPUT && !RETURN_NORMED_OUTPUT) {
output[index] = val;
}
shmem[i] = val;
const float_packed_t val_f = cuda_cast<float_packed_t>(val);
local_sum += cuda_sum<float>(val_f);
if (USE_DIFF_OF_SQUARES) {
local_var_sum += cuda_sum<float>(val_f * val_f);
}
}
if (USE_DIFF_OF_SQUARES) {
float packed[2] = {local_sum, local_var_sum};
blockReduceSumV2<float, 2>(packed);
mean = packed[0];
variance = packed[1];
} else {
mean = blockReduceSum(local_sum);
}
if (threadIdx.x == 0) {
mean = mean / hidden_dim;
s_mean = mean;
if (USE_DIFF_OF_SQUARES) {
variance = (variance / hidden_dim) - (mean * mean); // Var[x] = E[x²] - E[x]²
s_variance = rsqrtf(variance + eps);
}
}
__syncthreads();
if (!USE_DIFF_OF_SQUARES) {
for (int i = tidx; i < n_elems; i += blockDim.x) {
const T val = shmem[i];
float_packed_t diff = cuda_cast<float_packed_t>(val) - s_mean;
local_var_sum += cuda_sum<float>(diff * diff);
}
variance = blockReduceSum(local_var_sum);
if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_dim + eps);
}
__syncthreads();
}
for (int i = tidx; i < n_elems; i += blockDim.x) {
const int index = bidx * n_elems + i;
const float_packed_t val_f = cuda_cast<float_packed_t>(shmem[i]);
const T val =
cuda_cast<T>(compute_layernorm<float_packed_t, T, IS_BETA>(val_f, s_mean, s_variance, gamma, beta, i));
if (RETURN_NORMED_OUTPUT && IS_OUTPUT) {
output[index] = val;
}
if (with_per_token_scaling) {
amax = cuda_max(cuda_max<T_scalar, T>(cuda_abs(val)), amax);
shmem[i] = val;
} else if (with_per_tensor_scaling) {
reinterpret_cast<int8_packed_t*>(normed_output_quant)[index] =
cuda_cast<int8_packed_t>(cuda_cast<float_packed_t>(val) * scale_orig_quant);
} else {
normed_output[index] = val;
}
}
if (with_per_token_scaling) {
float abs_max_f = blockAllReduceMax(cuda_cast<float>(amax));
const float dynamic_per_token_scale = 127.f / abs_max_f;
for (int i = tidx; i < n_elems; i += blockDim.x) {
const int index = bidx * n_elems + i;
float_packed_t val_f = cuda_cast<float_packed_t>(shmem[i]);
reinterpret_cast<int8_packed_t*>(normed_output_quant)[index] =
cuda_cast<int8_packed_t>(val_f * cuda_cast<float_packed_t>(dynamic_per_token_scale));
}
if (tidx == 0) {
scale_orig_quant_per_token[bidx] = abs_max_f / 127.f;
}
}
}
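// Quantization paths in generalLayerNorm: with per-tensor scaling each
// normalized value is converted to int8 after multiplying by
// *scale_orig_quant_per_tensor; with per-token (dynamic) scaling the row-wise
// abs-max is block-reduced first, values are scaled by 127 / amax before the
// int8 conversion, and the inverse scale amax / 127 is stored in
// scale_orig_quant_per_token[row] for later dequantization. When neither scale
// is given, normed_output receives the floating-point result.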
template<typename T,
bool IS_OUTPUT,
bool IS_BIAS,
bool RESIDUAL,
bool IS_BETA,
bool RETURN_NORMED_OUTPUT,
bool USE_DIFF_OF_SQUARES = false>
__global__ void generalLayerNormNoShmem(T* output,
T* normed_output,
const T* input,
const T* bias,
const T* residual,
const T* gamma,
const T* beta,
const float eps,
int tokens,
int hidden_dim,
const float* scale_orig_quant_per_tensor,
float* scale_orig_quant_per_token,
int8_t* normed_output_quant) {
constexpr auto num_elems_T = num_elems<T>::value;
using int8_packed_t = typename packed_as<int8_t, num_elems_T>::type;
using Int32_Packed_T = typename packed_as<int32_t, num_elems<T>::value>::type;
using float_packed_t = typename packed_as<float, num_elems_T>::type;
using T_scalar = typename packed_as<T, 1>::type;
T shmem_val;
__shared__ float s_mean;
__shared__ float s_variance;
const int n_elems = hidden_dim / num_elems_T;
const int tidx = threadIdx.x;
const int bidx = blockIdx.x;
const int glb_index = bidx * n_elems + tidx;
float mean = 0.0f;
float variance = 0.0f;
float local_sum = 0.0f;
float local_var_sum = 0.0f;
const bool with_per_token_scaling = scale_orig_quant_per_token != nullptr;
const bool with_per_tensor_scaling = scale_orig_quant_per_tensor != nullptr;
const float_packed_t scale_orig_quant =
cuda_cast<float_packed_t>(with_per_tensor_scaling ? *scale_orig_quant_per_tensor : 0.0f);
T_scalar amax(1e-6f);
if (tidx < n_elems) {
T val = input[glb_index];
if (IS_BIAS) {
val = add(val, ldg(&bias[tidx]));
}
if (RESIDUAL) {
val = add(val, ldg(&residual[glb_index]));
}
if (IS_OUTPUT && !RETURN_NORMED_OUTPUT) {
output[glb_index] = val;
}
// shmem[i] = val;
shmem_val = val;
const float_packed_t val_f = cuda_cast<float_packed_t>(val);
local_sum += cuda_sum<float>(val_f);
if (USE_DIFF_OF_SQUARES) {
local_var_sum += cuda_sum<float>(val_f * val_f);
}
}
if (USE_DIFF_OF_SQUARES) {
float packed[2] = {local_sum, local_var_sum};
blockReduceSumV2<float, 2>(packed);
mean = packed[0];
variance = packed[1];
} else {
mean = blockReduceSum(local_sum);
}
if (threadIdx.x == 0) {
mean = mean / hidden_dim;
s_mean = mean;
if (USE_DIFF_OF_SQUARES) {
variance = (variance / hidden_dim) - (mean * mean); // Var[x] = E[x²] - E[x]²
s_variance = rsqrtf(variance + eps);
}
}
__syncthreads();
if (!USE_DIFF_OF_SQUARES) {
if (tidx < n_elems) {
// const T val = shmem[i];
const T val = shmem_val;
float_packed_t diff = cuda_cast<float_packed_t>(val) - s_mean;
local_var_sum += cuda_sum<float>(diff * diff);
}
variance = blockReduceSum(local_var_sum);
if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_dim + eps);
}
__syncthreads();
}
if (tidx < n_elems) {
// const float_packed_t val_f = cuda_cast<float_packed_t>(shmem[i]);
const float_packed_t val_f = cuda_cast<float_packed_t>(shmem_val);
const T val =
cuda_cast<T>(compute_layernorm<float_packed_t, T, IS_BETA>(val_f, s_mean, s_variance, gamma, beta, tidx));
if (RETURN_NORMED_OUTPUT && IS_OUTPUT) {
output[glb_index] = val;
}
if (with_per_token_scaling) {
amax = cuda_max(cuda_max<T_scalar, T>(cuda_abs(val)), amax);
// shmem[i] = val;
shmem_val = val;
} else if (with_per_tensor_scaling) {
reinterpret_cast<int8_packed_t*>(normed_output_quant)[glb_index] =
cuda_cast<int8_packed_t>(cuda_cast<float_packed_t>(val) * scale_orig_quant);
} else {
normed_output[glb_index] = val;
}
}
if (with_per_token_scaling) {
float abs_max_f = blockAllReduceMax(cuda_cast<float>(amax));
const float dynamic_per_token_scale = 127.f / abs_max_f;
if (tidx < n_elems) {
// float_packed_t val_f = cuda_cast<float_packed_t>(shmem[i]);
float_packed_t val_f = cuda_cast<float_packed_t>(shmem_val);
reinterpret_cast<int8_packed_t*>(normed_output_quant)[glb_index] =
cuda_cast<int8_packed_t>(val_f * cuda_cast<float_packed_t>(dynamic_per_token_scale));
}
if (tidx == 0) {
scale_orig_quant_per_token[bidx] = abs_max_f / 127.f;
}
}
}
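// generalLayerNormNoShmem is the register-resident variant selected by the
// dispatch below when shmem_size == 0, i.e. when each thread can own exactly
// one packed element of the row (hidden_dim / num_elems_T <= blockDim.x); the
// value is kept in shmem_val instead of dynamic shared memory. The reduction
// and quantization logic mirrors generalLayerNorm above.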
template<typename T,
bool IS_OUTPUT,
bool IS_BIAS,
bool RESIDUAL,
bool IS_BETA,
bool RETURN_NORMED_OUTPUT,
bool USE_DIFF_OF_SQUARES>
void dispatch_layernorm_type_square_method(T* output,
T* normed_output,
const T* input,
const T* bias,
const T* residual,
const T* gamma,
const T* beta,
const float eps,
int tokens,
int hidden_dim,
const float* scale_orig_quant_per_tensor,
float* scale_orig_quant_per_token,
int8_t* normed_output_quant,
const dim3 grid,
const dim3 block,
const size_t shmem_size,
cudaStream_t stream,
int vec_size = 1) {
if (shmem_size == 0) {
dim3 _block(block);
_block.x /= vec_size;
generalLayerNormNoShmem<T, IS_OUTPUT, IS_BIAS, RESIDUAL, IS_BETA, RETURN_NORMED_OUTPUT, USE_DIFF_OF_SQUARES>
<<<grid, _block, 0, stream>>>(output,
normed_output,
input,
bias,
residual,
gamma,
beta,
eps,
tokens,
hidden_dim,
scale_orig_quant_per_tensor,
scale_orig_quant_per_token,
normed_output_quant);
} else {
generalLayerNorm<T, IS_OUTPUT, IS_BIAS, RESIDUAL, IS_BETA, RETURN_NORMED_OUTPUT, USE_DIFF_OF_SQUARES>
<<<grid, block, shmem_size, stream>>>(output,
normed_output,
input,
bias,
residual,
gamma,
beta,
eps,
tokens,
hidden_dim,
scale_orig_quant_per_tensor,
scale_orig_quant_per_token,
normed_output_quant);
}
}
template<typename T, bool IS_OUTPUT, bool IS_BIAS, bool RESIDUAL, bool IS_BETA, bool RETURN_NORMED_OUTPUT>
void dispatch_layernorm_return_normed(T* output,
T* normed_output,
const T* input,
const T* bias,
const T* residual,
const T* gamma,
const T* beta,
const float eps,
int tokens,
int hidden_dim,
const float* scale_orig_quant_per_tensor,
float* scale_orig_quant_per_token,
int8_t* normed_output_quant,
const dim3 grid,
const dim3 block,
const size_t shmem_size,
cudaStream_t stream,
bool use_diff_of_squares,
int vec_size = 1) {
if (use_diff_of_squares) {
dispatch_layernorm_type_square_method<T, IS_OUTPUT, IS_BIAS, RESIDUAL, IS_BETA, RETURN_NORMED_OUTPUT, true>(
output,
normed_output,
input,
bias,
residual,
gamma,
beta,
eps,
tokens,
hidden_dim,
scale_orig_quant_per_tensor,
scale_orig_quant_per_token,
normed_output_quant,
grid,
block,
shmem_size,
stream,
vec_size);
} else {
dispatch_layernorm_type_square_method<T, IS_OUTPUT, IS_BIAS, RESIDUAL, IS_BETA, RETURN_NORMED_OUTPUT, false>(
output,
normed_output,
input,
bias,
residual,
gamma,
beta,
eps,
tokens,
hidden_dim,
scale_orig_quant_per_tensor,
scale_orig_quant_per_token,
normed_output_quant,
grid,
block,
shmem_size,
stream,
vec_size);
}
}
template<typename T, bool IS_OUTPUT, bool IS_BIAS, bool RESIDUAL, bool IS_BETA>
void dispatch_layernorm_type(T* output,
T* normed_output,
const T* input,
const T* bias,
const T* residual,
const T* gamma,
const T* beta,
const float eps,
int tokens,
int hidden_dim,
const float* scale_orig_quant_per_tensor,
float* scale_orig_quant_per_token,
int8_t* normed_output_quant,
const dim3 grid,
const dim3 block,
const size_t shmem_size,
cudaStream_t stream,
bool use_diff_of_squares,
bool return_normed_output,
int vec_size = 1) {
if (return_normed_output) {
dispatch_layernorm_return_normed<T, IS_OUTPUT, IS_BIAS, RESIDUAL, IS_BETA, true>(output,
normed_output,
input,
bias,
residual,
gamma,
beta,
eps,
tokens,
hidden_dim,
scale_orig_quant_per_tensor,
scale_orig_quant_per_token,
normed_output_quant,
grid,
block,
shmem_size,
stream,
use_diff_of_squares,
vec_size);
} else {
dispatch_layernorm_return_normed<T, IS_OUTPUT, IS_BIAS, RESIDUAL, IS_BETA, false>(output,
normed_output,
input,
bias,
residual,
gamma,
beta,
eps,
tokens,
hidden_dim,
scale_orig_quant_per_tensor,
scale_orig_quant_per_token,
normed_output_quant,
grid,
block,
shmem_size,
stream,
use_diff_of_squares,
vec_size);
}
}
template<typename T, bool IS_OUTPUT, bool IS_BIAS, bool RESIDUAL>
void dispatch_layernorm_beta(T* output,
T* normed_output,
const T* input,
const T* bias,
const T* residual,
const T* gamma,
const T* beta,
const float eps,
int tokens,
int hidden_dim,
const float* scale_orig_quant_per_tensor,
float* scale_orig_quant_per_token,
int8_t* normed_output_quant,
const dim3 grid,
const dim3 block,
const size_t shmem_size,
cudaStream_t stream,
bool use_diff_of_squares,
bool return_normed_output,
int vec_size = 1) {
if (beta != nullptr) {
dispatch_layernorm_type<T, IS_OUTPUT, IS_BIAS, RESIDUAL, true>(output,
normed_output,
input,
bias,
residual,
gamma,
beta,
eps,
tokens,
hidden_dim,
scale_orig_quant_per_tensor,
scale_orig_quant_per_token,
normed_output_quant,
grid,
block,
shmem_size,
stream,
use_diff_of_squares,
return_normed_output,
vec_size);
} else {
dispatch_layernorm_type<T, IS_OUTPUT, IS_BIAS, RESIDUAL, false>(output,
normed_output,
input,
bias,
residual,
gamma,
beta,
eps,
tokens,
hidden_dim,
scale_orig_quant_per_tensor,
scale_orig_quant_per_token,
normed_output_quant,
grid,
block,
shmem_size,
stream,
use_diff_of_squares,
return_normed_output,
vec_size);
}
}
template<typename T, bool IS_OUTPUT, bool IS_BIAS>
void dispatch_layernorm_residual(T* output,
T* normed_output,
const T* input,
const T* bias,
const T* residual,
const T* gamma,
const T* beta,
const float eps,
int tokens,
int hidden_dim,
const float* scale_orig_quant_per_tensor,
float* scale_orig_quant_per_token,
int8_t* normed_output_quant,
const dim3 grid,
const dim3 block,
const size_t shmem_size,
cudaStream_t stream,
bool use_diff_of_squares,
bool return_normed_output,
int vec_size = 1) {
if (residual != nullptr) {
dispatch_layernorm_beta<T, IS_OUTPUT, IS_BIAS, true>(output,
normed_output,
input,
bias,
residual,
gamma,
beta,
eps,
tokens,
hidden_dim,
scale_orig_quant_per_tensor,
scale_orig_quant_per_token,
normed_output_quant,
grid,
block,
shmem_size,
stream,
use_diff_of_squares,
return_normed_output,
vec_size);
} else {
dispatch_layernorm_beta<T, IS_OUTPUT, IS_BIAS, false>(output,
normed_output,
input,
bias,
residual,
gamma,
beta,
eps,
tokens,
hidden_dim,
scale_orig_quant_per_tensor,
scale_orig_quant_per_token,
normed_output_quant,
grid,
block,
shmem_size,
stream,
use_diff_of_squares,
return_normed_output,
vec_size);
}
}
template<typename T, bool IS_OUTPUT>
void dispatch_layernorm_bias(T* output,
T* normed_output,
const T* input,
const T* bias,
const T* residual,
const T* gamma,
const T* beta,
const float eps,
int tokens,
int hidden_dim,
const float* scale_orig_quant_per_tensor,
float* scale_orig_quant_per_token,
int8_t* normed_output_quant,
const dim3 grid,
const dim3 block,
const size_t shmem_size,
cudaStream_t stream,
bool use_diff_of_squares,
bool return_normed_output,
int vec_size = 1) {
if (bias != nullptr) {
dispatch_layernorm_residual<T, IS_OUTPUT, true>(output,
normed_output,
input,
bias,
residual,
gamma,
beta,
eps,
tokens,
hidden_dim,
scale_orig_quant_per_tensor,
scale_orig_quant_per_token,
normed_output_quant,
grid,
block,
shmem_size,
stream,
use_diff_of_squares,
return_normed_output,
vec_size);
} else {
dispatch_layernorm_residual<T, IS_OUTPUT, false>(output,
normed_output,
input,
bias,
residual,
gamma,
beta,
eps,
tokens,
hidden_dim,
scale_orig_quant_per_tensor,
scale_orig_quant_per_token,
normed_output_quant,
grid,
block,
shmem_size,
stream,
use_diff_of_squares,
return_normed_output,
vec_size);
}
}
template<typename T>
void dispatch_layernorm_output(T* output,
T* normed_output,
const T* input,
const T* bias,
const T* residual,
const T* gamma,
const T* beta,
const float eps,
int tokens,
int hidden_dim,
const float* scale_orig_quant_per_tensor,
float* scale_orig_quant_per_token,
int8_t* normed_output_quant,
const dim3 grid,
const dim3 block,
const size_t shmem_size,
cudaStream_t stream,
bool use_diff_of_squares,
bool is_output,
bool return_normed_output,
int vec_size = 1) {
if (is_output) {
dispatch_layernorm_bias<T, true>(output,
normed_output,
input,
bias,
residual,
gamma,
beta,
eps,
tokens,
hidden_dim,
scale_orig_quant_per_tensor,
scale_orig_quant_per_token,
normed_output_quant,
grid,
block,
shmem_size,
stream,
use_diff_of_squares,
return_normed_output,
vec_size);
} else {
dispatch_layernorm_bias<T, false>(output,
normed_output,
input,
bias,
residual,
gamma,
beta,
eps,
tokens,
hidden_dim,
scale_orig_quant_per_tensor,
scale_orig_quant_per_token,
normed_output_quant,
grid,
block,
shmem_size,
stream,
use_diff_of_squares,
return_normed_output,
vec_size);
}
}
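// The dispatch_* chain above turns the runtime configuration (is_output,
// bias != nullptr, residual != nullptr, beta != nullptr, return_normed_output,
// use_diff_of_squares) into compile-time template booleans, so each
// combination gets a specialized kernel with unused branches compiled out.
// Call order, top to bottom:
//   dispatch_layernorm_output             -> IS_OUTPUT
//   dispatch_layernorm_bias               -> IS_BIAS
//   dispatch_layernorm_residual           -> RESIDUAL
//   dispatch_layernorm_beta               -> IS_BETA
//   dispatch_layernorm_type               -> RETURN_NORMED_OUTPUT
//   dispatch_layernorm_return_normed      -> USE_DIFF_OF_SQUARES
//   dispatch_layernorm_type_square_method -> shmem vs. no-shmem kernel launch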
template<typename T>
void invokeGeneralLayerNorm(T* out,
T* normed_output,
const T* input,
const T* gamma,
const T* beta,
const float eps,
const int tokens,
const int hidden_dim,
cudaStream_t stream,
bool use_diff_of_squares,
const float* scale,
float* dynamic_scale,
int8_t* out_quant,
bool return_normed_output) {
dim3 grid(tokens);
dim3 block(min(hidden_dim, 1024));
// Make sure block.x is multiple of 32 for warp shuffle to work
block.x = 32 * ((block.x + 31) / 32);
constexpr size_t vec_size = 2;
const size_t shmem_size = hidden_dim * sizeof(T);
const bool use_vec_type = (hidden_dim % vec_size == 0)
&& (std::is_same<T, half>::value
#ifdef ENABLE_BF16
|| std::is_same<T, __nv_bfloat16>::value
#endif
);
if (use_vec_type) {
using Tp = typename packed_as<T, vec_size>::type;
dispatch_layernorm_output(reinterpret_cast<Tp*>(out),
reinterpret_cast<Tp*>(normed_output),
reinterpret_cast<const Tp*>(input),
(const Tp*)nullptr,
(const Tp*)nullptr,
reinterpret_cast<const Tp*>(gamma),
reinterpret_cast<const Tp*>(beta),
eps,
tokens,
hidden_dim,
scale,
dynamic_scale,
out_quant,
grid,
block,
shmem_size,
stream,
use_diff_of_squares,
out != nullptr,
return_normed_output);
} else {
dispatch_layernorm_output(out,
normed_output,
(const T*)input,
(const T*)nullptr,
(const T*)nullptr,
gamma,
beta,
eps,
tokens,
hidden_dim,
scale,
dynamic_scale,
out_quant,
grid,
block,
shmem_size,
stream,
use_diff_of_squares,
out != nullptr,
return_normed_output);
}
}
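// Illustrative call (values are hypothetical): plain layernorm of a
// [tokens, hidden_dim] activation without quantization,
//     invokeGeneralLayerNorm<half>(nullptr /* out */, normed, input, gamma, beta,
//                                  1e-5f, tokens, 4096 /* hidden_dim */, stream,
//                                  true /* use_diff_of_squares */,
//                                  nullptr /* scale */, nullptr /* dynamic_scale */,
//                                  nullptr /* out_quant */,
//                                  false /* return_normed_output */);
// Note that shmem_size here is always hidden_dim * sizeof(T), so this entry
// point always takes the shared-memory kernel path.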
template<typename T>
void invokeGeneralAddBiasResidualLayerNorm(T* out,
T* norm_output,
const T* input,
const T* bias,
const T* residual,
const T* gamma,
const T* beta,
const float eps,
const int tokens,
const int hidden_dim,
cudaStream_t stream,
bool use_diff_of_squares,
const float* scale,
float* dynamic_scale,
int8_t* out_quant,
bool return_normed_output) {
dim3 grid(tokens);
dim3 block(min(hidden_dim, 1024));
// Make sure block.x is multiple of 32 for warp shuffle to work
block.x = 32 * ((block.x + 31) / 32);
constexpr size_t vec_size = 2;
const size_t shmem_size = hidden_dim > block.x ? (hidden_dim * sizeof(T)) : 0;
const bool use_vec_type = (hidden_dim % vec_size == 0)
&& (std::is_same<T, half>::value
#ifdef ENABLE_BF16
|| std::is_same<T, __nv_bfloat16>::value
#endif
);
if (use_vec_type) {
using Tp = typename packed_as<T, vec_size>::type;
dispatch_layernorm_output(reinterpret_cast<Tp*>(out),
reinterpret_cast<Tp*>(norm_output),
reinterpret_cast<const Tp*>(input),
reinterpret_cast<const Tp*>(bias),
reinterpret_cast<const Tp*>(residual),
reinterpret_cast<const Tp*>(gamma),
reinterpret_cast<const Tp*>(beta),
eps,
tokens,
hidden_dim,
scale,
dynamic_scale,
out_quant,
grid,
block,
shmem_size,
stream,
use_diff_of_squares,
true,
return_normed_output,
vec_size);
} else {
dispatch_layernorm_output(out,
norm_output,
input,
bias,
residual,
gamma,
beta,
eps,
tokens,
hidden_dim,
scale,
dynamic_scale,
out_quant,
grid,
block,
shmem_size,
stream,
use_diff_of_squares,
true,
return_normed_output,
1);
}
}
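// Unlike invokeGeneralLayerNorm above, this entry point sets shmem_size to 0
// whenever the row fits in one block (hidden_dim <= block.x), which routes the
// launch to generalLayerNormNoShmem with block.x / vec_size threads; larger
// rows fall back to the shared-memory kernel. The bias/residual add is fused:
// with return_normed_output = false, out receives input + bias + residual and
// norm_output its layernorm; with return_normed_output = true, out receives
// the normalized result as well.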
#define INSTANTIATE_GENERAL_LAYERNORM(T) \
template void invokeGeneralLayerNorm(T* out, \
T* normed_output, \
const T* input, \
const T* gamma, \
const T* beta, \
const float eps, \
const int tokens, \
const int hidden_dim, \
cudaStream_t stream, \
bool use_diff_of_squares, \
const float* scale, \
float* dynamic_scale, \
int8_t* out_quant, \
bool return_normed_output);
INSTANTIATE_GENERAL_LAYERNORM(float);
INSTANTIATE_GENERAL_LAYERNORM(half);
#ifdef ENABLE_BF16
INSTANTIATE_GENERAL_LAYERNORM(__nv_bfloat16);
#endif
#undef INSTANTIATE_GENERAL_LAYERNORM
#define INSTANTIATE_GENERAL_ADD_BIAS_RESIDUAL_LAYERNORM(T) \
template void invokeGeneralAddBiasResidualLayerNorm(T* out, \
T* norm_output, \
const T* input, \
const T* bias, \
const T* residual, \
const T* gamma, \
const T* beta, \
const float eps, \
const int tokens, \
const int hidden_dim, \
cudaStream_t stream, \
bool use_diff_of_squares, \
const float* scale, \
float* dynamic_scale, \
int8_t* out_quant, \
bool return_normed_output);
INSTANTIATE_GENERAL_ADD_BIAS_RESIDUAL_LAYERNORM(float);
INSTANTIATE_GENERAL_ADD_BIAS_RESIDUAL_LAYERNORM(half);
#ifdef ENABLE_BF16
INSTANTIATE_GENERAL_ADD_BIAS_RESIDUAL_LAYERNORM(__nv_bfloat16);
#endif
#undef INSTANTIATE_GENERAL_ADD_BIAS_RESIDUAL_LAYERNORM
} // namespace rtp_llm