// maga_transformer/cpp/kernels/alpha_layernorm_kernels.cu

#include "maga_transformer/cpp/cuda/cuda_type_utils.cuh" #include "maga_transformer/cpp/kernels/alpha_layernorm_kernels.h" #include "maga_transformer/cpp/cuda/reduce_kernel_utils.cuh" #if USING_ROCM #include "maga_transformer/cpp/rocm/hip_utils.h" #endif // wont't support new features namespace rtp_llm{ #if USING_ROCM using namespace rocm; #endif template<typename T, int N> __global__ void alphaAddBiasResidualPostLayerNorm( T* out, const T* input, const T* residual1, const T* bias, const T* gamma, const T* beta, T alpha, int m, int n) { __shared__ float s_mean; __shared__ float s_variance; float mean = 0.0f; float variance = 0.0f; float local_out_cache[N]; #pragma unroll N for (int idx = threadIdx.x, i = 0; idx < n && i < N; ++i) { float local_out = (float)(input[blockIdx.x * n + idx] + residual1[blockIdx.x * n + idx] * alpha + __ldg_func(&bias[idx])); mean += local_out; // save local_out to local_out_cache to save some recompute local_out_cache[i] = local_out; idx += blockDim.x; } mean = blockReduceSum<float>(mean); if (threadIdx.x == 0) { s_mean = mean / n; } __syncthreads(); #pragma unroll N for (int idx = threadIdx.x, i = 0; idx < n && i < N; ++i) { float local_out = local_out_cache[i]; variance += (local_out - s_mean) * (local_out - s_mean); idx += blockDim.x; } variance = blockReduceSum<float>(variance); if (threadIdx.x == 0) { s_variance = variance / n + 1e-6f; } __syncthreads(); #pragma unroll N for (int idx = threadIdx.x, i = 0; idx < n && i < N; ++i) { float local_out = local_out_cache[i]; out[blockIdx.x * n + idx] = (T)(((local_out - s_mean) * rsqrtf(s_variance)) * (float)(__ldg_func(&gamma[idx])) + (float)(__ldg_func(&beta[idx]))); idx += blockDim.x; } } template<typename T> __global__ void generalAlphaAddBiasResidualPostLayerNorm( T* out, const T* input, const T* residual1, const T* bias, const T* gamma, const T* beta, T alpha, int m, int n) { __shared__ float s_mean; __shared__ float s_variance; float mean = 0.0f; float variance = 0.0f; for (int idx = threadIdx.x; idx < n; idx += blockDim.x) { float local_out = (float)(input[blockIdx.x * n + idx] + residual1[blockIdx.x * n + idx] * alpha + __ldg_func(&bias[idx])); mean += local_out; // save local_out to out to save some recompute out[blockIdx.x * n + idx] = local_out; } mean = blockReduceSum<float>(mean); if (threadIdx.x == 0) { s_mean = mean / n; } __syncthreads(); for (int idx = threadIdx.x; idx < n; idx += blockDim.x) { float local_out = out[blockIdx.x * n + idx]; variance += (local_out - s_mean) * (local_out - s_mean); } variance = blockReduceSum<float>(variance); if (threadIdx.x == 0) { s_variance = variance / n + 1e-6f; } __syncthreads(); for (int idx = threadIdx.x; idx < n; idx += blockDim.x) { float local_out = out[blockIdx.x * n + idx]; out[blockIdx.x * n + idx] = (T)(((local_out - s_mean) * rsqrtf(s_variance)) * (float)(__ldg_func(&gamma[idx])) + (float)(__ldg_func(&beta[idx]))); } } template<> __global__ void generalAlphaAddBiasResidualPostLayerNorm(half* out, const half* input, const half* residual1, const half* bias, const half* gamma, const half* beta, half alpha, int m, int n) { __shared__ float s_mean; __shared__ float s_variance; float mean = 0.0f; float variance = 0.0f; half2 alpha2 = make_half2(alpha, alpha); half2* out_ptr = (half2*)out; const half2* input_ptr = (const half2*)input; const half2* residual_ptr = (const half2*)residual1; const half2* bias_ptr = (const half2*)bias; const half2* gamma_ptr = (const half2*)gamma; const half2* beta_ptr = (const half2*)beta; float local_out = 0.0f; for (int idx = 
template<>
__global__ void generalAlphaAddBiasResidualPostLayerNorm(half* out,
                                                         const half* input,
                                                         const half* residual1,
                                                         const half* bias,
                                                         const half* gamma,
                                                         const half* beta,
                                                         half alpha,
                                                         int m,
                                                         int n)
{
    __shared__ float s_mean;
    __shared__ float s_variance;
    float            mean     = 0.0f;
    float            variance = 0.0f;
    half2            alpha2   = make_half2(alpha, alpha);

    half2*       out_ptr      = (half2*)out;
    const half2* input_ptr    = (const half2*)input;
    const half2* residual_ptr = (const half2*)residual1;
    const half2* bias_ptr     = (const half2*)bias;
    const half2* gamma_ptr    = (const half2*)gamma;
    const half2* beta_ptr     = (const half2*)beta;

    float local_out = 0.0f;
    for (int idx = threadIdx.x; idx < n / 2; idx += blockDim.x) {
        int    id  = blockIdx.x * n / 2 + idx;
        half2  tmp = __hadd2(__hadd2(input_ptr[id], __hmul2(residual_ptr[id], alpha2)), __ldg_func(&bias_ptr[idx]));
        float2 local_out_fp2 = __half22float2(tmp);
        local_out += local_out_fp2.x;
        local_out += local_out_fp2.y;
        // save tmp to out_ptr to save some recomputation
        out_ptr[id] = tmp;
    }

    mean = blockReduceSum<float>(local_out);
    if (threadIdx.x == 0) {
        s_mean = mean / n;
    }
    __syncthreads();

    for (int idx = threadIdx.x; idx < n / 2; idx += blockDim.x) {
        int    id            = blockIdx.x * n / 2 + idx;
        float2 local_out_fp2 = __half22float2(out_ptr[id]);
        variance += (local_out_fp2.x - s_mean) * (local_out_fp2.x - s_mean);
        variance += (local_out_fp2.y - s_mean) * (local_out_fp2.y - s_mean);
    }
    variance = blockReduceSum<float>(variance);
    if (threadIdx.x == 0) {
        s_variance = rsqrtf(variance / n + 1e-6f);
    }
    __syncthreads();

    for (int idx = threadIdx.x; idx < n / 2; idx += blockDim.x) {
        int    id            = blockIdx.x * n / 2 + idx;
        float2 local_out_fp2 = __half22float2(out_ptr[id]);
        float2 gamma_val     = __half22float2(__ldg_func(&gamma_ptr[idx]));
        float2 beta_val      = __half22float2(__ldg_func(&beta_ptr[idx]));
        local_out_fp2.x      = (local_out_fp2.x - s_mean) * s_variance * gamma_val.x + beta_val.x;
        local_out_fp2.y      = (local_out_fp2.y - s_mean) * s_variance * gamma_val.y + beta_val.y;
        out_ptr[id]          = __float22half2_rn(local_out_fp2);
    }
}

template<typename T>
__global__ void alphaAddBiasResidualPostLayerNormV2(T* out,
                                                    const T* __restrict input,
                                                    const T* __restrict residual1,
                                                    const T* __restrict bias,
                                                    const T* __restrict gamma,
                                                    const T* __restrict beta,
                                                    T alpha,
                                                    int n)
{
    const int        ite = 4;
    const int        tid = threadIdx.x;
    const int        bid = blockIdx.x;
    __shared__ float s_mean;
    __shared__ float s_variance;
    float            mean     = 0.0f;
    float            variance = 0.0f;
    float            local_out[ite];

    float sum = 0.0f;
#pragma unroll
    for (int i = 0; i < ite; i++) {
        int col_id   = i * blockDim.x + tid;
        int id       = bid * n + col_id;
        local_out[i] = (float)(input[id] + __ldg_func(&residual1[id]) * alpha + __ldg_func(&bias[col_id]));
        sum += local_out[i];
    }

    mean = blockReduceSum<float>(sum);
    if (tid == 0) {
        s_mean = mean / n;
    }
    __syncthreads();

    float var = 0.0f;
#pragma unroll
    for (int i = 0; i < ite; i++) {
        float diff = local_out[i] - s_mean;
        var += diff * diff;
    }

    variance = blockReduceSum<float>(var);
    if (tid == 0) {
        s_variance = rsqrtf(variance / n + 1e-6f);
    }
    __syncthreads();

#pragma unroll
    for (int i = 0; i < ite; i++) {
        int col_id = i * blockDim.x + tid;
        int id     = bid * n + col_id;
        out[id]    = (T)((local_out[i] - s_mean) * s_variance * (float)__ldg_func(&gamma[col_id])
                      + (float)__ldg_func(&beta[col_id]));
    }
}
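// Added commentary (assumption inferred from the launch logic below): the V2 kernels unroll
// ite = 4 columns per thread and are launched with blockDim.x = n / 4 (scalar path) or
// n / 8 (half2 path), so they cover exactly n == 768 or n == 1024; the half2 path additionally
// requires n to be even. A hypothetical host-side guard mirroring the dispatch in
// invokeAlphaAddBiasResidualLayerNorm could look like:
//     inline bool canUseAlphaPostLayerNormV2Half(int m, int n) {
//         return m >= 512 && (n == 768 || n == 1024);
//     }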
template<>
__global__ void alphaAddBiasResidualPostLayerNormV2(half* out,
                                                    const half* __restrict input,
                                                    const half* __restrict residual1,
                                                    const half* __restrict bias,
                                                    const half* __restrict gamma,
                                                    const half* __restrict beta,
                                                    half alpha,
                                                    int n)
{
    const int        ite = 4;
    const int        tid = threadIdx.x;
    const int        bid = blockIdx.x;
    __shared__ float s_mean;
    __shared__ float s_variance;
    float            mean     = 0.0f;
    float            variance = 0.0f;
    half2            local_out_half2[ite];
    half2            alpha2 = make_half2(alpha, alpha);

    half2*       out_ptr      = (half2*)out;
    const half2* input_ptr    = (const half2*)input;
    const half2* residual_ptr = (const half2*)residual1;
    const half2* bias_ptr     = (const half2*)bias;
    const half2* gamma_ptr    = (const half2*)gamma;
    const half2* beta_ptr     = (const half2*)beta;

    // float sum = 0.0f;
    half2 sum = __float2half2_rn(0.0f);
#pragma unroll
    for (int i = 0; i < ite; i++) {
        int col_id         = i * blockDim.x + tid;
        int id             = bid * n / 2 + col_id;
        local_out_half2[i] = input_ptr[id] + __ldg_func(&residual_ptr[id]) * alpha2 + __ldg_func(&bias_ptr[col_id]);
        sum += local_out_half2[i];
    }

    mean = blockReduceSum<float>((float)(sum.x + sum.y));
    if (threadIdx.x == 0) {
        s_mean = mean / n;
    }
    __syncthreads();

    float var      = 0.0f;
    half2 s_mean_2 = __float2half2_rn(s_mean);
#pragma unroll
    for (int i = 0; i < ite; i++) {
        local_out_half2[i] = local_out_half2[i] - s_mean_2;
        float v1           = (float)local_out_half2[i].x;
        float v2           = (float)local_out_half2[i].y;
        var += v1 * v1 + v2 * v2;
    }

    variance = blockReduceSum<float>(var);
    if (threadIdx.x == 0) {
        s_variance = rsqrtf(variance / n + 1e-6f);
    }
    __syncthreads();

    half2 s_var_2 = __float2half2_rn(s_variance);
#pragma unroll
    for (int i = 0; i < ite; i++) {
        int col_id  = i * blockDim.x + tid;
        int id      = bid * n / 2 + col_id;
        out_ptr[id] = local_out_half2[i] * s_var_2 * __ldg_func(&gamma_ptr[col_id]) + __ldg_func(&beta_ptr[col_id]);
    }
}

template<typename T>
void invokeAlphaAddBiasResidualLayerNorm(T* out,
                                         const T* input,
                                         const T* residual1,
                                         const T* bias,
                                         const T* gamma,
                                         const T* beta,
                                         T alpha,
                                         int m,
                                         int n,
                                         cudaStream_t stream)
{
    dim3 grid(m);
    dim3 block(std::min(n, 1024));

    if (n == 768 || n == 1024) {
        alphaAddBiasResidualPostLayerNormV2<T>
            <<<grid, n / 4, 0, stream>>>(out, input, residual1, bias, gamma, beta, alpha, n);
    }
    else {
        block.x       = std::min(n, 1024);
        int num_trips = (n + block.x - 1) / block.x;
        if (num_trips == 1) {
            alphaAddBiasResidualPostLayerNorm<T, 1>
                <<<grid, block, 0, stream>>>(out, input, residual1, bias, gamma, beta, alpha, m, n);
        }
        else if (num_trips == 2) {
            alphaAddBiasResidualPostLayerNorm<T, 2>
                <<<grid, block, 0, stream>>>(out, input, residual1, bias, gamma, beta, alpha, m, n);
        }
        else {
            generalAlphaAddBiasResidualPostLayerNorm<T>
                <<<grid, block, 0, stream>>>(out, input, residual1, bias, gamma, beta, alpha, m, n);
        }
    }
}

template<>
void invokeAlphaAddBiasResidualLayerNorm(half* out,
                                         const half* input,
                                         const half* residual1,
                                         const half* bias,
                                         const half* gamma,
                                         const half* beta,
                                         half alpha,
                                         int m,
                                         int n,
                                         cudaStream_t stream)
{
    dim3 grid(m);
    dim3 block(std::min(n, 1024));

    if (m >= 512 && (n == 768 || n == 1024)) {
        alphaAddBiasResidualPostLayerNormV2<half>
            <<<grid, n / 8, 0, stream>>>(out, input, residual1, bias, gamma, beta, alpha, n);
    }
    else {
        block.x       = std::min(n, 1024);
        int num_trips = (n + block.x - 1) / block.x;
        if (num_trips == 1) {
            alphaAddBiasResidualPostLayerNorm<half, 1>
                <<<grid, block, 0, stream>>>(out, input, residual1, bias, gamma, beta, alpha, m, n);
        }
        else if (num_trips == 2) {
            alphaAddBiasResidualPostLayerNorm<half, 2>
                <<<grid, block, 0, stream>>>(out, input, residual1, bias, gamma, beta, alpha, m, n);
        }
        else {
            generalAlphaAddBiasResidualPostLayerNorm<half>
                <<<grid, block, 0, stream>>>(out, input, residual1, bias, gamma, beta, alpha, m, n);
        }
    }
}

#define INVOKE_ALPHA_ADD_BIAS_RES_LN(T)                                                                                \
    template void invokeAlphaAddBiasResidualLayerNorm(T* out,                                                          \
                                                      const T* input,                                                  \
                                                      const T* residual1,                                              \
                                                      const T* bias,                                                   \
                                                      const T* gamma,                                                  \
                                                      const T* beta,                                                   \
                                                      T alpha,                                                         \
                                                      int m,                                                           \
                                                      int n,                                                           \
                                                      cudaStream_t stream);
INVOKE_ALPHA_ADD_BIAS_RES_LN(float)
INVOKE_ALPHA_ADD_BIAS_RES_LN(half)
#ifdef ENABLE_BF16
INVOKE_ALPHA_ADD_BIAS_RES_LN(__nv_bfloat16)
#endif
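// Usage sketch (added for illustration, not part of the original file): launching the alpha
// add-bias-residual LayerNorm on fp16 device buffers. Buffer names and shapes are hypothetical.
#if 0
void runAlphaAddBiasResidualLayerNormExample(half*        d_out,       // [m, n]
                                             const half*  d_input,     // [m, n]
                                             const half*  d_residual,  // [m, n]
                                             const half*  d_bias,      // [n]
                                             const half*  d_gamma,     // [n]
                                             const half*  d_beta,      // [n]
                                             cudaStream_t stream) {
    const int  m     = 512;                 // rows (tokens); m >= 512 with n == 1024 takes the half2 V2 fast path
    const int  n     = 1024;                // hidden size
    const half alpha = __float2half(1.0f);  // scaling applied to the residual branch
    invokeAlphaAddBiasResidualLayerNorm(d_out, d_input, d_residual, d_bias, d_gamma, d_beta, alpha, m, n, stream);
}
#endif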
template<typename T, int N>
__global__ void addBiasResidualPostLayerNorm(
    T* out, const T* input, const T* bias, const T* gamma, const T* beta, const float layernorm_eps, int m, int n)
{
    __shared__ float s_mean;
    __shared__ float s_variance;
    float            mean     = 0.0f;
    float            variance = 0.0f;
    float            local_out_cache[N];

#pragma unroll N
    for (int idx = threadIdx.x, i = 0; idx < n && i < N; ++i) {
        float local_out = (float)(add(out[blockIdx.x * n + idx], input[blockIdx.x * n + idx], ldg(&bias[idx])));
        mean += local_out;
        // save local_out to local_out_cache to save some recompute
        local_out_cache[i] = local_out;
        idx += blockDim.x;
    }

    mean = blockReduceSum<float>(mean);
    if (threadIdx.x == 0) {
        s_mean = mean / n;
    }
    __syncthreads();

#pragma unroll N
    for (int idx = threadIdx.x, i = 0; idx < n && i < N; ++i) {
        float local_out = local_out_cache[i];
        variance += (local_out - s_mean) * (local_out - s_mean);
        idx += blockDim.x;
    }
    variance = blockReduceSum<float>(variance);
    if (threadIdx.x == 0) {
        s_variance = variance / n + layernorm_eps;
    }
    __syncthreads();

#pragma unroll N
    for (int idx = threadIdx.x, i = 0; idx < n && i < N; ++i) {
        float local_out = local_out_cache[i];
        out[blockIdx.x * n + idx] =
            (T)(((local_out - s_mean) * rsqrtf(s_variance)) * (float)(ldg(&gamma[idx])) + (float)(ldg(&beta[idx])));
        idx += blockDim.x;
    }
}

template<typename T>
__global__ void generalAddBiasResidualPostLayerNorm(
    T* out, const T* input, const T* bias, const T* gamma, const T* beta, const float layernorm_eps, int m, int n)
{
    using T2 = typename TypeConverter<T>::Type;

    __shared__ float s_mean;
    __shared__ float s_variance;
    float            mean     = 0.0f;
    float            variance = 0.0f;

    T2*       out_ptr   = (T2*)out;
    const T2* input_ptr = (const T2*)input;
    const T2* bias_ptr  = (const T2*)bias;
    const T2* gamma_ptr = (const T2*)gamma;
    const T2* beta_ptr  = (const T2*)beta;

    float local_out = 0.0f;
    for (int idx = threadIdx.x; idx < n / 2; idx += blockDim.x) {
        int    id            = blockIdx.x * n / 2 + idx;
        T2     tmp           = hadd2(hadd2(out_ptr[id], input_ptr[id]), ldg(&bias_ptr[idx]));
        float2 local_out_fp2 = cuda_cast<float2>(tmp);
        local_out += local_out_fp2.x;
        local_out += local_out_fp2.y;
        // save tmp to out_ptr to save some recomputation
        out_ptr[id] = tmp;
    }

    mean = blockReduceSum<float>(local_out);
    if (threadIdx.x == 0) {
        s_mean = mean / n;
    }
    __syncthreads();

    for (int idx = threadIdx.x; idx < n / 2; idx += blockDim.x) {
        int    id            = blockIdx.x * n / 2 + idx;
        float2 local_out_fp2 = cuda_cast<float2>(out_ptr[id]);
        variance += (local_out_fp2.x - s_mean) * (local_out_fp2.x - s_mean);
        variance += (local_out_fp2.y - s_mean) * (local_out_fp2.y - s_mean);
    }
    variance = blockReduceSum<float>(variance);
    if (threadIdx.x == 0) {
        s_variance = rsqrtf(variance / n + layernorm_eps);
    }
    __syncthreads();

    for (int idx = threadIdx.x; idx < n / 2; idx += blockDim.x) {
        int    id            = blockIdx.x * n / 2 + idx;
        float2 local_out_fp2 = cuda_cast<float2>(out_ptr[id]);
        float2 gamma_val     = cuda_cast<float2>(ldg(&gamma_ptr[idx]));
        float2 beta_val      = cuda_cast<float2>(ldg(&beta_ptr[idx]));
        local_out_fp2.x      = (local_out_fp2.x - s_mean) * s_variance * gamma_val.x + beta_val.x;
        local_out_fp2.y      = (local_out_fp2.y - s_mean) * s_variance * gamma_val.y + beta_val.y;
        out_ptr[id]          = cuda_cast<T2>(local_out_fp2);
    }
}
template<>
__global__ void generalAddBiasResidualPostLayerNorm(float* out,
                                                    const float* input,
                                                    const float* bias,
                                                    const float* gamma,
                                                    const float* beta,
                                                    const float layernorm_eps,
                                                    int m,
                                                    int n)
{
    __shared__ float s_mean;
    __shared__ float s_variance;
    float            mean     = 0.0f;
    float            variance = 0.0f;

    for (int idx = threadIdx.x; idx < n; idx += blockDim.x) {
        float local_out = (float)(out[blockIdx.x * n + idx] + input[blockIdx.x * n + idx] + __ldg_func(&bias[idx]));
        mean += local_out;
        // save local_out to out to save some recompute
        out[blockIdx.x * n + idx] = local_out;
    }

    mean = blockReduceSum<float>(mean);
    if (threadIdx.x == 0) {
        s_mean = mean / n;
    }
    __syncthreads();

    for (int idx = threadIdx.x; idx < n; idx += blockDim.x) {
        float local_out = out[blockIdx.x * n + idx];
        variance += (local_out - s_mean) * (local_out - s_mean);
    }
    variance = blockReduceSum<float>(variance);
    if (threadIdx.x == 0) {
        s_variance = rsqrtf(variance / n + layernorm_eps);
    }
    __syncthreads();

    for (int idx = threadIdx.x; idx < n; idx += blockDim.x) {
        float local_out = out[blockIdx.x * n + idx];
        out[blockIdx.x * n + idx] = (float)(((local_out - s_mean) * s_variance) * (float)(__ldg_func(&gamma[idx]))
                                            + (float)(__ldg_func(&beta[idx])));
    }
}

template<typename T>
__global__ void addBiasResidualPostLayerNormV2(T* out,
                                               const T* __restrict input,
                                               const T* __restrict bias,
                                               const T* __restrict gamma,
                                               const T* __restrict beta,
                                               const float layernorm_eps,
                                               int n)
{
    using T2 = typename TypeConverter<T>::Type;

    const int        ite = 4;
    const int        tid = threadIdx.x;
    const int        bid = blockIdx.x;
    __shared__ float s_mean;
    __shared__ float s_variance;
    float            mean     = 0.0f;
    float            variance = 0.0f;
    T2               local_out_half2[ite];

    T2*       out_ptr   = (T2*)out;
    const T2* input_ptr = (const T2*)input;
    const T2* bias_ptr  = (const T2*)bias;
    const T2* gamma_ptr = (const T2*)gamma;
    const T2* beta_ptr  = (const T2*)beta;

    // float sum = 0.0f;
    T2 sum = cuda_cast<T2>(0.0f);
#pragma unroll
    for (int i = 0; i < ite; i++) {
        int col_id         = i * blockDim.x + tid;
        int id             = bid * n / 2 + col_id;
        local_out_half2[i] = add(out_ptr[id], ldg(&input_ptr[id]), ldg(&bias_ptr[col_id]));
        sum                = add(sum, local_out_half2[i]);
    }

    mean = blockReduceSum<float>((float)(sum.x + sum.y));
    if (threadIdx.x == 0) {
        s_mean = mean / n;
    }
    __syncthreads();

    float var      = 0.0f;
    T2    s_mean_2 = cuda_cast<T2>(s_mean);
#pragma unroll
    for (int i = 0; i < ite; i++) {
        local_out_half2[i] = hsub2(local_out_half2[i], s_mean_2);
        float v1           = (float)local_out_half2[i].x;
        float v2           = (float)local_out_half2[i].y;
        var += v1 * v1 + v2 * v2;
    }

    variance = blockReduceSum<float>(var);
    if (tid == 0) {
        s_variance = rsqrtf(variance / n + layernorm_eps);
    }
    __syncthreads();

    T2 s_var_2 = cuda_cast<T2>(s_variance);
#pragma unroll
    for (int i = 0; i < ite; i++) {
        int col_id  = i * blockDim.x + tid;
        int id      = bid * n / 2 + col_id;
        out_ptr[id] = fma(local_out_half2[i], s_var_2, ldg(&gamma_ptr[col_id]), ldg(&beta_ptr[col_id]));
    }
}

template<>
__global__ void addBiasResidualPostLayerNormV2(float* out,
                                               const float* __restrict input,
                                               const float* __restrict bias,
                                               const float* __restrict gamma,
                                               const float* __restrict beta,
                                               const float layernorm_eps,
                                               int n)
{
    const int        ite = 4;
    const int        tid = threadIdx.x;
    const int        bid = blockIdx.x;
    __shared__ float s_mean;
    __shared__ float s_variance;
    float            mean     = 0.0f;
    float            variance = 0.0f;
    float            local_out[ite];

    float sum = 0.0f;
#pragma unroll
    for (int i = 0; i < ite; i++) {
        int col_id   = i * blockDim.x + tid;
        int id       = bid * n + col_id;
        local_out[i] = (float)(out[id] + __ldg_func(&input[id]) + __ldg_func(&bias[col_id]));
        sum += local_out[i];
    }

    mean = blockReduceSum<float>(sum);
    if (tid == 0) {
        s_mean = mean / n;
    }
    __syncthreads();

    float var = 0.0f;
#pragma unroll
    for (int i = 0; i < ite; i++) {
        float diff = local_out[i] - s_mean;
        var += diff * diff;
    }

    variance = blockReduceSum<float>(var);
    if (tid == 0) {
        s_variance = rsqrtf(variance / n + layernorm_eps);
    }
    __syncthreads();

#pragma unroll
    for (int i = 0; i < ite; i++) {
        int col_id = i * blockDim.x + tid;
        int id     = bid * n + col_id;
        out[id]    = (float)((local_out[i] - s_mean) * s_variance * (float)__ldg_func(&gamma[col_id])
                          + (float)__ldg_func(&beta[col_id]));
    }
}
template<typename T>
void invokeAddBiasResidualLayerNorm(T* out,
                                    const T* input,
                                    const T* bias,
                                    const T* gamma,
                                    const T* beta,
                                    const float layernorm_eps,
                                    int m,
                                    int n,
                                    cudaStream_t stream)
{
    dim3 grid(m);
    dim3 block(std::min(n, 1024));

    if (m >= 512 && (n == 768 || n == 1024)) {
        addBiasResidualPostLayerNormV2<T><<<grid, n / 8, 0, stream>>>(out, input, bias, gamma, beta, layernorm_eps, n);
    }
    else {
        block.x       = std::min(n, 1024);
        int num_trips = (n + block.x - 1) / block.x;
        if (num_trips == 1) {
            addBiasResidualPostLayerNorm<T, 1>
                <<<grid, block, 0, stream>>>(out, input, bias, gamma, beta, layernorm_eps, m, n);
        }
        else if (num_trips == 2) {
            addBiasResidualPostLayerNorm<T, 2>
                <<<grid, block, 0, stream>>>(out, input, bias, gamma, beta, layernorm_eps, m, n);
        }
        else {
            generalAddBiasResidualPostLayerNorm<T>
                <<<grid, block, 0, stream>>>(out, input, bias, gamma, beta, layernorm_eps, m, n);
        }
    }
}

template<>
void invokeAddBiasResidualLayerNorm(float* out,
                                    const float* input,
                                    const float* bias,
                                    const float* gamma,
                                    const float* beta,
                                    const float layernorm_eps,
                                    int m,
                                    int n,
                                    cudaStream_t stream)
{
    dim3 grid(m);
    dim3 block(std::min(n, 1024));

    if (n == 768 || n == 1024) {
        addBiasResidualPostLayerNormV2<float>
            <<<grid, n / 4, 0, stream>>>(out, input, bias, gamma, beta, layernorm_eps, n);
    }
    else {
        block.x       = std::min(n, 1024);
        int num_trips = (n + block.x - 1) / block.x;
        if (num_trips == 1) {
            addBiasResidualPostLayerNorm<float, 1>
                <<<grid, block, 0, stream>>>(out, input, bias, gamma, beta, layernorm_eps, m, n);
        }
        else if (num_trips == 2) {
            addBiasResidualPostLayerNorm<float, 2>
                <<<grid, block, 0, stream>>>(out, input, bias, gamma, beta, layernorm_eps, m, n);
        }
        else {
            generalAddBiasResidualPostLayerNorm<float>
                <<<grid, block, 0, stream>>>(out, input, bias, gamma, beta, layernorm_eps, m, n);
        }
    }
}

#define INVOKE_ADD_BIAS_RES_LN(T)                                                                                      \
    template void invokeAddBiasResidualLayerNorm(T* out,                                                               \
                                                 const T* input,                                                       \
                                                 const T* bias,                                                        \
                                                 const T* gamma,                                                       \
                                                 const T* beta,                                                        \
                                                 const float layernorm_eps,                                            \
                                                 int m,                                                                \
                                                 int n,                                                                \
                                                 cudaStream_t stream);
INVOKE_ADD_BIAS_RES_LN(float)
INVOKE_ADD_BIAS_RES_LN(half)
#ifdef ENABLE_BF16
INVOKE_ADD_BIAS_RES_LN(__nv_bfloat16)
#endif
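// Usage sketch (added for illustration): unlike the alpha variant above, invokeAddBiasResidualLayerNorm
// uses `out` as both the residual input and the destination -- each kernel reads out[...], adds
// input[...] and bias[...], and overwrites out[...] with the normalized result. Names below are
// hypothetical.
#if 0
void runAddBiasResidualLayerNormExample(float*       d_hidden,    // [m, n], holds the residual on entry
                                        const float* d_attn_out,  // [m, n], e.g. an attention/MLP output
                                        const float* d_bias,      // [n]
                                        const float* d_gamma,     // [n]
                                        const float* d_beta,      // [n]
                                        cudaStream_t stream) {
    const int   m             = 64;
    const int   n             = 768;  // n == 768 takes the float V2 fast path
    const float layernorm_eps = 1e-5f;
    invokeAddBiasResidualLayerNorm(d_hidden, d_attn_out, d_bias, d_gamma, d_beta, layernorm_eps, m, n, stream);
}
#endif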
template<typename T, bool DYNAMIC_SCALING = false>
__global__ void generalLayerNormWithPadding(const T* __restrict input,
                                            const T* __restrict gamma,
                                            const T* __restrict beta,
                                            T* normed_output,
                                            const float layernorm_eps,
                                            int m,
                                            int real_n,
                                            int padding_n,
                                            float* scale,
                                            float* dynamic_scale,
                                            const int int8_mode)
{
    const int tid = threadIdx.x;

    extern __shared__ __align__(sizeof(float)) char _shmem[];
    T* shmem = reinterpret_cast<T*>(_shmem);

    __shared__ float s_mean;
    __shared__ float s_variance;

    float mean     = 0.0f;
    float variance = 0.0f;

    using Int8_Packed_T  = typename packed_as<int8_t, num_elems<T>::value>::type;
    using Int32_Packed_T = typename packed_as<int32_t, num_elems<T>::value>::type;
    using Float_Packed_T = typename packed_as<float, num_elems<T>::value>::type;
    using Scalar_T       = typename packed_as<T, 1>::type;

    const Float_Packed_T scale_to_int = cuda_cast<Float_Packed_T>(int8_mode == 2 ? *scale : 0.0f);

    float local_sum = 0.0f;
    for (int i = tid; i < real_n; i += blockDim.x) {
        local_sum += (float)(ldg(&input[blockIdx.x * padding_n + i]));
    }

    mean = blockReduceSum(local_sum);
    if (threadIdx.x == 0) {
        s_mean = mean / real_n;
    }
    __syncthreads();

    float local_var_sum = 0.0f;
    for (int i = tid; i < real_n; i += blockDim.x) {
        float diff = (float)(ldg(&input[blockIdx.x * padding_n + i])) - s_mean;
        local_var_sum += diff * diff;
    }
    variance = blockReduceSum(local_var_sum);
    if (threadIdx.x == 0) {
        s_variance = rsqrtf(variance / real_n + layernorm_eps);
    }
    __syncthreads();

    Scalar_T abs_max(1e-6f);

    for (int i = tid; i < real_n; i += blockDim.x) {
        const int index    = blockIdx.x * padding_n + i;
        float     beta_val = (beta == nullptr) ? 0.0f : (float)ldg(&beta[i]);
        T         val      = (T)((((float)input[index] - s_mean) * s_variance) * (float)(ldg(&gamma[i])) + beta_val);

        if (DYNAMIC_SCALING) {
            abs_max  = cuda_max(cuda_max<Scalar_T, T>(cuda_abs(val)), abs_max);
            shmem[i] = val;
        }
        else if (int8_mode == 2) {
            reinterpret_cast<Int8_Packed_T*>(normed_output)[index] =
                cuda_cast<Int8_Packed_T>(cuda_cast<Float_Packed_T>(val) * scale_to_int);
        }
        else {
            normed_output[index] = val;
        }
    }

    if (DYNAMIC_SCALING) {
        float          abs_max_f = blockAllReduceMax(cuda_cast<float>(abs_max));
        const Scalar_T dynamic_per_token_scale(float(127. / abs_max_f));
        for (int i = tid; i < real_n; i += blockDim.x) {
            const int index = blockIdx.x * padding_n + i;
            reinterpret_cast<Int8_Packed_T*>(normed_output)[index] = cuda_cast<Int8_Packed_T>(
                cuda_cast<Float_Packed_T>(shmem[i]) * cuda_cast<Float_Packed_T>(dynamic_per_token_scale));
        }
        if (threadIdx.x == 0) {
            dynamic_scale[blockIdx.x] = (*scale * abs_max_f) / 127.f;
        }
    }
}

template<typename T>
void invokeGeneralLayerNormWithPadding(T* out,
                                       const T* input,
                                       const T* gamma,
                                       const T* beta,
                                       const float layernorm_eps,
                                       const int m,
                                       const int real_n,
                                       const int padding_n,
                                       float* scale,
                                       float* dynamic_scale,
                                       const int int8_mode,
                                       cudaStream_t stream,
                                       int opt_version)
{
    dim3       grid(m);
    const bool dynamic_quant = dynamic_scale != nullptr;
    dim3       block(min(real_n, 1024));

    /* For general cases, n is equal to hidden_units, e.g., 512/1024.
       Since we have warp shuffle inside the code, block.x % 32 should be 0. */
    block.x = 32 * ((block.x + 31) / 32);

    /* should pay attention to the rsqrt precision */
    if (dynamic_quant) {
        size_t maxbytes = real_n * sizeof(T);
        if (maxbytes >= (48 << 10)) {
#if USING_CUDA
            check_cuda_error(cudaFuncSetAttribute(
                generalLayerNormWithPadding<T, true>, cudaFuncAttributeMaxDynamicSharedMemorySize, maxbytes));
#endif
        }
        generalLayerNormWithPadding<T, true><<<grid, block, maxbytes, stream>>>(
            input, gamma, beta, out, layernorm_eps, m, real_n, padding_n, scale, dynamic_scale, int8_mode);
    }
    else {
        generalLayerNormWithPadding<T, false><<<grid, block, 0, stream>>>(
            input, gamma, beta, out, layernorm_eps, m, real_n, padding_n, scale, dynamic_scale, int8_mode);
    }
}

#define INVOKE_GENERAL_LN_WITH_PADDING(T)                                                                              \
    template void invokeGeneralLayerNormWithPadding(T* out,                                                            \
                                                    const T* input,                                                    \
                                                    const T* gamma,                                                    \
                                                    const T* beta,                                                     \
                                                    const float layernorm_eps,                                         \
                                                    const int m,                                                       \
                                                    const int real_n,                                                  \
                                                    const int padding_n,                                               \
                                                    float* scale,                                                      \
                                                    float* dynamic_scale,                                              \
                                                    const int int8_mode,                                               \
                                                    cudaStream_t stream,                                               \
                                                    int opt_version);
INVOKE_GENERAL_LN_WITH_PADDING(float)
INVOKE_GENERAL_LN_WITH_PADDING(half)
#ifdef ENABLE_BF16
INVOKE_GENERAL_LN_WITH_PADDING(__nv_bfloat16)
#endif

}  // namespace rtp_llm
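// Usage sketch (added for illustration, not part of the original file): invokeGeneralLayerNormWithPadding
// normalizes only the first real_n columns of each row while rows are stored with a stride of padding_n
// elements; passing a non-null dynamic_scale (or int8_mode == 2 with a valid scale) switches the output
// to int8. Names and sizes below are hypothetical.
#if 0
void runPaddedLayerNormExample(half* d_out, const half* d_in, const half* d_gamma, const half* d_beta,
                               cudaStream_t stream) {
    const int   m             = 32;    // rows (tokens)
    const int   real_n        = 1000;  // valid hidden width
    const int   padding_n     = 1024;  // allocated (padded) row stride
    const float layernorm_eps = 1e-5f;
    rtp_llm::invokeGeneralLayerNormWithPadding(d_out, d_in, d_gamma, d_beta, layernorm_eps, m, real_n, padding_n,
                                               /*scale=*/nullptr, /*dynamic_scale=*/nullptr, /*int8_mode=*/0,
                                               stream, /*opt_version=*/0);
}
#endif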