// maga_transformer/cpp/devices/arm_impl/ArmLayerNormOp.cc

#include "maga_transformer/cpp/devices/arm_impl/ArmDevice.h"
#include "maga_transformer/cpp/devices/DeviceFactory.h"
#include "maga_transformer/cpp/core/allocator.h"
#include "maga_transformer/cpp/core/cpu_allocator.h"
#include "maga_transformer/cpp/devices/utils/DebugUtils.h"
#include <algorithm> // std::all_of, std::min
#include <cmath>     // std::sqrt
#include <cstring>
#include <vector>
#include <arm_neon.h>
#include <omp.h>

namespace rtp_llm {

template<typename T>
void add_residual_bias(void* norm_out, const void* input, void* residual, int m, int n) {
    for (int i = 0; i < m; i++) {
        for (int j = 0; j < n; j++) {
            ((T*)norm_out)[i * n + j] = ((const T*)input)[i * n + j] + ((T*)residual)[i * n + j];
        }
    }
}

void add_residual_bias_float(float* norm_out, float* input, float* residual, float* bias, int n) {
    int d = 0;
    for (; d <= n - 16; d += 16) {
        float32x4x4_t regs = vld1q_f32_x4(input + d);
        float32x4x4_t regs_residual_bias;
        if (residual) {
            regs_residual_bias = vld1q_f32_x4(residual + d);
#pragma unroll
            for (int i = 0; i < 4; ++i) {
                regs.val[i] = vaddq_f32(regs.val[i], regs_residual_bias.val[i]);
            }
        }
        if (bias) {
            regs_residual_bias = vld1q_f32_x4(bias + d);
#pragma unroll
            for (int i = 0; i < 4; ++i) {
                regs.val[i] = vaddq_f32(regs.val[i], regs_residual_bias.val[i]);
            }
        }
        vst1q_f32_x4(norm_out + d, regs);
    }
    for (; d < n; ++d) {
        float val = input[d];
        if (residual) val += residual[d];
        if (bias) val += bias[d];
        norm_out[d] = val;
    }
}

void add_residual_bias_fp16(__fp16* norm_out, __fp16* input, __fp16* residual, __fp16* bias, int n) {
    int d = 0;
    for (; d <= n - 32; d += 32) {
        float16x8x4_t regs = vld1q_f16_x4(input + d);
        float16x8x4_t regs_residual_bias;
        if (residual) {
            regs_residual_bias = vld1q_f16_x4(residual + d);
#pragma unroll
            for (int i = 0; i < 4; ++i) {
                regs.val[i] = vaddq_f16(regs.val[i], regs_residual_bias.val[i]);
            }
        }
        if (bias) {
            regs_residual_bias = vld1q_f16_x4(bias + d);
#pragma unroll
            for (int i = 0; i < 4; ++i) {
                regs.val[i] = vaddq_f16(regs.val[i], regs_residual_bias.val[i]);
            }
        }
        vst1q_f16_x4(norm_out + d, regs);
    }
    for (; d < n; ++d) {
        float val = input[d];
        if (residual) val += residual[d];
        if (bias) val += bias[d];
        norm_out[d] = val;
    }
}
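// Illustrative sketch (not called anywhere in this file): the loop structure
// every NEON helper here follows -- a 16-element (fp32) or 32-element (fp16)
// vector body, then a scalar tail for the remaining elements.
[[maybe_unused]] static void scale_inplace_example(float* x, int n, float s) {
    float32x4_t s_v = vdupq_n_f32(s);
    int d = 0;
    for (; d <= n - 16; d += 16) { // vector body: 4 registers x 4 lanes per step
        float32x4x4_t v = vld1q_f32_x4(x + d);
        for (int i = 0; i < 4; ++i) {
            v.val[i] = vmulq_f32(v.val[i], s_v);
        }
        vst1q_f32_x4(x + d, v);
    }
    for (; d < n; ++d) { // scalar tail: remaining n % 16 elements
        x[d] *= s;
    }
}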
void convert_fp16_to_float(const __fp16* input, float* output, int length) {
    int d = 0;
    for (; d <= length - 32; d += 32) {
        // Load 32 fp16 values
        float16x8_t fp16_vec0 = vld1q_f16(&input[d]);
        float16x8_t fp16_vec1 = vld1q_f16(&input[d + 8]);
        float16x8_t fp16_vec2 = vld1q_f16(&input[d + 16]);
        float16x8_t fp16_vec3 = vld1q_f16(&input[d + 24]);
        // Convert to float32
        float32x4_t float_vec0_low  = vcvt_f32_f16(vget_low_f16(fp16_vec0));
        float32x4_t float_vec0_high = vcvt_f32_f16(vget_high_f16(fp16_vec0));
        float32x4_t float_vec1_low  = vcvt_f32_f16(vget_low_f16(fp16_vec1));
        float32x4_t float_vec1_high = vcvt_f32_f16(vget_high_f16(fp16_vec1));
        float32x4_t float_vec2_low  = vcvt_f32_f16(vget_low_f16(fp16_vec2));
        float32x4_t float_vec2_high = vcvt_f32_f16(vget_high_f16(fp16_vec2));
        float32x4_t float_vec3_low  = vcvt_f32_f16(vget_low_f16(fp16_vec3));
        float32x4_t float_vec3_high = vcvt_f32_f16(vget_high_f16(fp16_vec3));
        // Store results
        vst1q_f32(&output[d],      float_vec0_low);
        vst1q_f32(&output[d + 4],  float_vec0_high);
        vst1q_f32(&output[d + 8],  float_vec1_low);
        vst1q_f32(&output[d + 12], float_vec1_high);
        vst1q_f32(&output[d + 16], float_vec2_low);
        vst1q_f32(&output[d + 20], float_vec2_high);
        vst1q_f32(&output[d + 24], float_vec3_low);
        vst1q_f32(&output[d + 28], float_vec3_high);
    }
    for (; d < length; ++d) {
        output[d] = static_cast<float>(input[d]);
    }
}

void convert_float_to_fp16(const float* input, __fp16* output, int length) {
    int d = 0;
    for (; d <= length - 32; d += 32) {
        // Load 32 floats as two groups of 16
        float32x4x4_t vec_float_low  = vld1q_f32_x4(input + d);
        float32x4x4_t vec_float_high = vld1q_f32_x4(input + d + 16);
        // Narrow each 4-lane group to fp16
        float16x4_t vec_fp16_low1  = vcvt_f16_f32(vec_float_low.val[0]);
        float16x4_t vec_fp16_high1 = vcvt_f16_f32(vec_float_low.val[1]);
        float16x4_t vec_fp16_low2  = vcvt_f16_f32(vec_float_low.val[2]);
        float16x4_t vec_fp16_high2 = vcvt_f16_f32(vec_float_low.val[3]);
        float16x4_t vec_fp16_low3  = vcvt_f16_f32(vec_float_high.val[0]);
        float16x4_t vec_fp16_high3 = vcvt_f16_f32(vec_float_high.val[1]);
        float16x4_t vec_fp16_low4  = vcvt_f16_f32(vec_float_high.val[2]);
        float16x4_t vec_fp16_high4 = vcvt_f16_f32(vec_float_high.val[3]);
        // Recombine into four 8-lane fp16 vectors and store
        float16x8_t result_low1  = vcombine_f16(vec_fp16_low1, vec_fp16_high1);
        float16x8_t result_high1 = vcombine_f16(vec_fp16_low2, vec_fp16_high2);
        float16x8_t result_low2  = vcombine_f16(vec_fp16_low3, vec_fp16_high3);
        float16x8_t result_high2 = vcombine_f16(vec_fp16_low4, vec_fp16_high4);
        vst1q_f16(output + d,      result_low1);
        vst1q_f16(output + d + 8,  result_high1);
        vst1q_f16(output + d + 16, result_low2);
        vst1q_f16(output + d + 24, result_high2);
    }
    for (; d < length; ++d) {
        output[d] = static_cast<__fp16>(input[d]);
    }
}
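// Illustrative sketch (the helper name and tolerance are made up for this
// example): round-tripping through the two converters above loses precision to
// fp16's ~10-bit mantissa, which is why the kernels accumulate in fp32.
[[maybe_unused]] static bool fp16_roundtrip_close(const float* src, int n) {
    std::vector<__fp16> half(n);
    std::vector<float>  back(n);
    convert_float_to_fp16(src, half.data(), n);        // fp32 -> fp16 (rounds)
    convert_fp16_to_float(half.data(), back.data(), n); // fp16 -> fp32
    for (int i = 0; i < n; ++i) {
        // relative tolerance of 2^-10 over the fp16 normal range
        if (std::fabs(back[i] - src[i]) > std::fabs(src[i]) / 1024.0f + 1e-6f) return false;
    }
    return true;
}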
// Shared worker for the RMSNorm_* entry points below. gamma, beta, residual,
// bias, and before_norm_output are all optional (pass nullptr to skip).
// Computes h = input (+ residual) (+ bias), optionally stores h to
// before_norm_output, then norm_out = h / sqrt(mean(h^2) + eps) (* gamma) (+ beta).
static void RMSNormWorker(int n, float* before_norm_output, const float* input, float* norm_out,
                          const float* gamma, const float* beta, const float* residual, const float* bias,
                          const double eps) {
    const bool store_unnormed = before_norm_output && before_norm_output != norm_out;
    float32x4_t square_sum_v[4];
#pragma unroll
    for (int i = 0; i < 4; ++i) {
        square_sum_v[i] = vdupq_n_f32(0.0f);
    }
    // pass 1: fused residual/bias add and sum of squares
    int d = 0;
    for (; d <= n - 16; d += 16) {
        float32x4x4_t regs = vld1q_f32_x4(input + d);
        float32x4x4_t tmp;
        if (residual) {
            tmp = vld1q_f32_x4(residual + d);
#pragma unroll
            for (int i = 0; i < 4; ++i) {
                regs.val[i] = vaddq_f32(regs.val[i], tmp.val[i]);
            }
        }
        if (bias) {
            tmp = vld1q_f32_x4(bias + d);
#pragma unroll
            for (int i = 0; i < 4; ++i) {
                regs.val[i] = vaddq_f32(regs.val[i], tmp.val[i]);
            }
        }
        if (store_unnormed) vst1q_f32_x4(before_norm_output + d, regs);
#pragma unroll
        for (int i = 0; i < 4; ++i) {
            square_sum_v[i] = vaddq_f32(square_sum_v[i], vmulq_f32(regs.val[i], regs.val[i]));
        }
    }
    float32_t square_sum = 0.0f;
    for (; d < n; ++d) { // scalar tail
        float val = input[d];
        if (residual) val += residual[d];
        if (bias) val += bias[d];
        if (store_unnormed) before_norm_output[d] = val;
        square_sum += val * val;
    }
    // pairwise horizontal reduction of the four accumulators
    square_sum_v[0] = vaddq_f32(square_sum_v[0], square_sum_v[1]);
    square_sum_v[2] = vaddq_f32(square_sum_v[2], square_sum_v[3]);
    square_sum_v[0] = vaddq_f32(square_sum_v[0], square_sum_v[2]);
#pragma unroll
    for (int i = 0; i < 4; ++i) {
        square_sum += square_sum_v[0][i];
    }
    float rms = square_sum / n;
    rms = 1.0f / std::sqrt(rms + eps);
    float32x4_t rms_v = vdupq_n_f32(rms);
    // pass 2: normalization
    d = 0;
    for (; d <= n - 16; d += 16) {
        float32x4x4_t input_v = vld1q_f32_x4(input + d);
        float32x4x4_t residual_v, bias_v, gamma_v, beta_v;
        if (residual) residual_v = vld1q_f32_x4(residual + d);
        if (bias) bias_v = vld1q_f32_x4(bias + d);
        if (gamma) gamma_v = vld1q_f32_x4(gamma + d);
        if (beta) beta_v = vld1q_f32_x4(beta + d);
#pragma unroll
        for (int i = 0; i < 4; ++i) {
            if (residual) input_v.val[i] = vaddq_f32(input_v.val[i], residual_v.val[i]);
            if (bias) input_v.val[i] = vaddq_f32(input_v.val[i], bias_v.val[i]);
            input_v.val[i] = vmulq_f32(input_v.val[i], rms_v);
            if (gamma) input_v.val[i] = vmulq_f32(input_v.val[i], gamma_v.val[i]);
            if (beta) input_v.val[i] = vaddq_f32(input_v.val[i], beta_v.val[i]);
        }
        vst1q_f32_x4(norm_out + d, input_v);
    }
    for (; d < n; ++d) { // scalar tail
        float val = input[d];
        if (residual) val += residual[d];
        if (bias) val += bias[d];
        val *= rms;
        if (gamma) val *= gamma[d];
        if (beta) val += beta[d];
        norm_out[d] = val;
    }
}
// The original entry points, kept with their signatures; each is a thin
// forwarder to the shared worker above (the per-variant copies differed only
// in which optional pointers they honored).
void RMSNorm_isoutput(int n, float* before_norm_output, float* input, float* norm_out, const float* gamma,
                      const float* beta, float* residual, float* bias, const double eps) {
    RMSNormWorker(n, before_norm_output, input, norm_out, gamma, beta, residual, bias, eps);
}

void RMSNorm_Nogamma_isoutput(int n, float* before_norm_output, const float* input, float* norm_out,
                              const float* beta, float* residual, float* bias, const double eps) {
    RMSNormWorker(n, before_norm_output, input, norm_out, nullptr, beta, residual, bias, eps);
}

void RMSNorm(int n, const float* input, float* norm_out, const float* gamma, const float* beta, const double eps) {
    RMSNormWorker(n, nullptr, input, norm_out, gamma, beta, nullptr, nullptr, eps);
}

void RMSNorm_Nogamma(int n, const float* input, float* norm_out, const float* beta, const double eps) {
    RMSNormWorker(n, nullptr, input, norm_out, nullptr, beta, nullptr, nullptr, eps);
}
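// Illustrative scalar reference of what RMSNormWorker computes (handy as a
// test oracle; not called by this file). All pointers except x and y may be null.
[[maybe_unused]] static void rmsnorm_reference(int n, const float* x, const float* residual, const float* bias,
                                               float* pre_norm, float* y, const float* gamma, const float* beta,
                                               double eps) {
    double square_sum = 0.0;
    for (int d = 0; d < n; ++d) {
        float h = x[d];
        if (residual) h += residual[d];
        if (bias) h += bias[d];
        if (pre_norm) pre_norm[d] = h; // the un-normalized sum used by pre-norm blocks
        square_sum += (double)h * h;
        y[d] = h; // stash h; the second pass scales it
    }
    float inv_rms = 1.0f / std::sqrt((float)(square_sum / n) + (float)eps);
    for (int d = 0; d < n; ++d) {
        float v = y[d] * inv_rms;
        if (gamma) v *= gamma[d];
        if (beta) v += beta[d];
        y[d] = v;
    }
}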
// Shared worker for the layerNorm_* entry points below. gamma, beta, residual,
// bias, and before_norm_output are all optional (pass nullptr to skip).
// Computes h = input (+ residual) (+ bias), optionally stores h to
// before_norm_output, then norm_out = (h - mean) / sqrt(var + eps) (* gamma)
// (+ beta), with var = E[h^2] - E[h]^2.
static void layerNormWorker(int n, const float* input, float* norm_out, const float* gamma, const float* beta,
                            const float* residual, const float* bias, float* before_norm_output, const double eps) {
    float32x4_t sum_v[4];
    float32x4_t square_sum_v[4];
#pragma unroll
    for (int i = 0; i < 4; ++i) {
        sum_v[i] = vdupq_n_f32(0.0f);
        square_sum_v[i] = vdupq_n_f32(0.0f);
    }
    // pass 1: fused residual/bias add, sum and sum of squares
    int d = 0;
    for (; d <= n - 16; d += 16) {
        float32x4x4_t regs = vld1q_f32_x4(input + d);
        float32x4x4_t tmp;
        if (residual) {
            tmp = vld1q_f32_x4(residual + d);
#pragma unroll
            for (int i = 0; i < 4; ++i) {
                regs.val[i] = vaddq_f32(regs.val[i], tmp.val[i]);
            }
        }
        if (bias) {
            tmp = vld1q_f32_x4(bias + d);
#pragma unroll
            for (int i = 0; i < 4; ++i) {
                regs.val[i] = vaddq_f32(regs.val[i], tmp.val[i]);
            }
        }
        if (before_norm_output) vst1q_f32_x4(before_norm_output + d, regs);
#pragma unroll
        for (int i = 0; i < 4; ++i) {
            sum_v[i] = vaddq_f32(sum_v[i], regs.val[i]);
            square_sum_v[i] = vaddq_f32(square_sum_v[i], vmulq_f32(regs.val[i], regs.val[i]));
        }
    }
    float32_t sum = 0.0f;
    float32_t square_sum = 0.0f;
    for (; d < n; ++d) { // scalar tail
        float val = input[d];
        if (residual) val += residual[d];
        if (bias) val += bias[d];
        if (before_norm_output) before_norm_output[d] = val;
        sum += val;
        square_sum += val * val;
    }
    // pairwise horizontal reduction of the accumulators
    sum_v[0] = vaddq_f32(sum_v[0], sum_v[1]);
    sum_v[2] = vaddq_f32(sum_v[2], sum_v[3]);
    sum_v[0] = vaddq_f32(sum_v[0], sum_v[2]);
    square_sum_v[0] = vaddq_f32(square_sum_v[0], square_sum_v[1]);
    square_sum_v[2] = vaddq_f32(square_sum_v[2], square_sum_v[3]);
    square_sum_v[0] = vaddq_f32(square_sum_v[0], square_sum_v[2]);
#pragma unroll
    for (int i = 0; i < 4; ++i) {
        sum += sum_v[0][i];
        square_sum += square_sum_v[0][i];
    }
    float mean = sum / n;
    float variance = square_sum / n;
    variance = 1.0f / std::sqrt(variance - mean * mean + eps); // reciprocal standard deviation
    float32x4_t mean_v = vdupq_n_f32(mean);
    float32x4_t variance_v = vdupq_n_f32(variance);
    // pass 2: normalization; residual/bias/gamma/beta registers are read only
    // when the corresponding pointer is non-null
    d = 0;
    for (; d <= n - 16; d += 16) {
        float32x4x4_t input_v = vld1q_f32_x4(input + d);
        float32x4x4_t residual_v, bias_v, gamma_v, beta_v;
        if (residual) residual_v = vld1q_f32_x4(residual + d);
        if (bias) bias_v = vld1q_f32_x4(bias + d);
        if (gamma) gamma_v = vld1q_f32_x4(gamma + d);
        if (beta) beta_v = vld1q_f32_x4(beta + d);
#pragma unroll
        for (int i = 0; i < 4; ++i) {
            if (residual) input_v.val[i] = vaddq_f32(input_v.val[i], residual_v.val[i]);
            if (bias) input_v.val[i] = vaddq_f32(input_v.val[i], bias_v.val[i]);
            input_v.val[i] = vsubq_f32(input_v.val[i], mean_v);
            if (gamma) input_v.val[i] = vmulq_f32(input_v.val[i], gamma_v.val[i]);
            input_v.val[i] = vmulq_f32(input_v.val[i], variance_v);
            if (beta) input_v.val[i] = vaddq_f32(input_v.val[i], beta_v.val[i]);
        }
        vst1q_f32_x4(norm_out + d, input_v);
    }
    for (; d < n; ++d) { // scalar tail
        float val = input[d];
        if (residual) val += residual[d];
        if (bias) val += bias[d];
        val -= mean;
        if (gamma) val *= gamma[d];
        val *= variance;
        if (beta) val += beta[d];
        norm_out[d] = val;
    }
}
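// Illustrative scalar reference for layerNormWorker (test oracle only):
// y = (x - mean(x)) / sqrt(var(x) + eps) * gamma + beta, with the same
// E[x^2] - E[x]^2 variance formulation the NEON code uses.
[[maybe_unused]] static void layernorm_reference(int n, const float* x, float* y, const float* gamma,
                                                 const float* beta, double eps) {
    double sum = 0.0, square_sum = 0.0;
    for (int d = 0; d < n; ++d) {
        sum += x[d];
        square_sum += (double)x[d] * x[d];
    }
    float mean = (float)(sum / n);
    float inv_std = 1.0f / std::sqrt((float)(square_sum / n) - mean * mean + (float)eps);
    for (int d = 0; d < n; ++d) {
        float v = (x[d] - mean) * inv_std;
        if (gamma) v *= gamma[d];
        if (beta) v += beta[d];
        y[d] = v;
    }
}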
// The original layerNorm entry points, kept with their signatures; each
// forwards to the shared worker above.
void layerNorm(int n, const float* input, float* norm_out, const float* gamma, const float* beta, const double eps) {
    layerNormWorker(n, input, norm_out, gamma, beta, nullptr, nullptr, nullptr, eps);
}

void layerNorm_Nogamma(int n, const float* input, float* norm_out, const float* beta, const double eps) {
    layerNormWorker(n, input, norm_out, nullptr, beta, nullptr, nullptr, nullptr, eps);
}

void layerNorm_isoutput(int n, const float* input, float* norm_out, const float* gamma, const float* beta,
                        float* residual, float* bias, const double eps) {
    layerNormWorker(n, input, norm_out, gamma, beta, residual, bias, nullptr, eps);
}

void layerNorm_Nogamma_isoutput(int n, const float* input, float* norm_out, const float* beta, float* residual,
                                float* bias, const double eps) {
    layerNormWorker(n, input, norm_out, nullptr, beta, residual, bias, nullptr, eps);
}

void layerNorm_isoutput_unnormedout(int n, const float* input, float* norm_out, const float* gamma, const float* beta,
                                    float* residual, float* bias, float* before_norm_output, const double eps) {
    layerNormWorker(n, input, norm_out, gamma, beta, residual, bias, before_norm_output, eps);
}

void layerNorm_Nogamma_isoutput_unnormedout(int n, const float* input, float* norm_out, const float* beta,
                                            float* residual, float* bias, float* before_norm_output,
                                            const double eps) {
    layerNormWorker(n, input, norm_out, nullptr, beta, residual, bias, before_norm_output, eps);
}
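// Illustrative sketch of the horizontal reduction idiom both workers use:
// pairwise vaddq_f32 on the four accumulators, then a per-lane sum (on AArch64,
// vaddvq_f32 performs the final lane sum in a single instruction).
[[maybe_unused]] static float horizontal_sum_16_example(const float* x) {
    float32x4x4_t v = vld1q_f32_x4(x); // 16 contiguous floats
    float32x4_t s01 = vaddq_f32(v.val[0], v.val[1]);
    float32x4_t s23 = vaddq_f32(v.val[2], v.val[3]);
    return vaddvq_f32(vaddq_f32(s01, s23)); // sum of all 16 lanes
}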
// FP16 layer norm accumulates unacceptable cumulative error, so only the fp32
// kernels above are used; fp16 tensors are converted to fp32 around them.
LayernormOutput ArmCpuDevice::layernorm(const LayernormParams& params) {
    BufferPtr input = params.input;
    BufferPtr norm_output = input;
    const auto& weights = params.norm_weight;
    // before_norm_output is used for pre-norm; pre-norm is currently not implemented.
    void* gamma = weights ? weights->get().gamma.get()->data() : nullptr;
    void* beta = (weights && weights->get().beta) ? weights->get().beta.get()->data() : nullptr;
    const auto eps = params.eps;
    void* before_norm_output = params.before_norm_output ? params.before_norm_output->data() : nullptr;
    void* residual = params.residual1 ? params.residual1->get().data() : nullptr;
    void* bias = params.bias.has_value() ? params.bias->get().data() : nullptr;
    bool is_output = (params.residual1.has_value() || params.bias.has_value());
    // omp_get_num_threads() returns 1 outside a parallel region; the loops below
    // need the pool size.
    int numThreads = omp_get_max_threads();
    const auto norm_type = params.norm_type;
    int m = input->shape()[0];
    int n = input->shape()[1];
    const auto data_type = input->type();

    if (!params.is_inplace && params.qscheme == QScheme::NoQuantize) {
        norm_output = allocateBufferLike(*params.input);
    } else if (params.qscheme == QScheme::Qint8PerToken) {
        throw OpException(OpErrorType::ERROR_UNIMPLEMENTED);
    }

    // fp32 activations may arrive with fp16 weights; such weights are converted
    // to fp32 once, up front.
    int convert_gamma = 0;
    int convert_beta = 0;
    int convert_bias = 0;
    if (data_type == DataType::TYPE_FP32) {
        if (gamma && weights->get().gamma.get()->type() == DataType::TYPE_FP16) convert_gamma = 1;
        if (beta && weights->get().beta.get()->type() == DataType::TYPE_FP16) convert_beta = 1;
        if (bias && params.bias->get().type() == DataType::TYPE_FP16) convert_bias = 1;
    }

    if (norm_type == NormType::layernorm && data_type == DataType::TYPE_FP32) {
        std::vector<float> gamma_storage, beta_storage, bias_storage;
        float* gamma_f = (float*)gamma;
        float* beta_f = (float*)beta;
        float* bias_f = (float*)bias;
        if (convert_gamma) {
            gamma_storage.resize(n);
            convert_fp16_to_float((__fp16*)gamma, gamma_storage.data(), n);
            gamma_f = gamma_storage.data();
        }
        if (convert_beta) {
            beta_storage.resize(n);
            convert_fp16_to_float((__fp16*)beta, beta_storage.data(), n);
            beta_f = beta_storage.data();
        }
        if (convert_bias) {
            bias_storage.resize(n);
            convert_fp16_to_float((__fp16*)bias, bias_storage.data(), n);
            bias_f = bias_storage.data();
        }
        // Skip the gamma multiply when gamma is absent or all ones.
        bool gamma_is_ones = !gamma_f || std::all_of(gamma_f, gamma_f + n, [](float v) { return v == 1.0f; });
        float* gamma_arg = gamma_is_ones ? nullptr : gamma_f;

        // for BERT
        // before_norm_output | params.return_normed_output | bias/residual exist
        //  .  .  F : layernorm(input)               -> normed_output
        //  F  .  T : layernorm(input+bias+residual) -> normed_output
        //  T  T  T : layernorm(input+bias+residual) -> before_norm_output and normed_output
        //  T  F  T : (input+bias+residual)          -> before_norm_output,
        //            layernorm(input+bias+residual) -> normed_output
        if (!is_output) { // .  .  F
#pragma omp parallel for num_threads(std::min(m, numThreads)) if (m >= 2)
            for (int i = 0; i < m; i++) {
                layerNorm(n, (float*)input->data() + i * n, (float*)norm_output->data() + i * n, gamma_arg, beta_f,
                          eps);
            }
        } else if (!before_norm_output) { // F  .  T
#pragma omp parallel for num_threads(std::min(m, numThreads)) if (m >= 2)
            for (int i = 0; i < m; i++) {
                layerNorm_isoutput(n, (float*)input->data() + i * n, (float*)norm_output->data() + i * n, gamma_arg,
                                   beta_f, residual ? (float*)residual + i * n : nullptr, bias_f, eps);
            }
        } else if (params.return_normed_output) { // T  T  T
#pragma omp parallel for num_threads(std::min(m, numThreads)) if (m >= 2)
            for (int i = 0; i < m; i++) {
                layerNorm_isoutput(n, (float*)input->data() + i * n, (float*)norm_output->data() + i * n, gamma_arg,
                                   beta_f, residual ? (float*)residual + i * n : nullptr, bias_f, eps);
                if (before_norm_output != norm_output->data())
                    std::memcpy((float*)before_norm_output + i * n, (float*)norm_output->data() + i * n,
                                n * sizeof(float));
            }
        } else { // T  F  T
#pragma omp parallel for num_threads(std::min(m, numThreads)) if (m >= 2)
            for (int i = 0; i < m; i++) {
                layerNorm_isoutput_unnormedout(n, (float*)input->data() + i * n, (float*)norm_output->data() + i * n,
                                               gamma_arg, beta_f, residual ? (float*)residual + i * n : nullptr,
                                               bias_f, (float*)before_norm_output + i * n, eps);
            }
        }
        return LayernormOutput({norm_output, params.before_norm_output});
    }

    if (norm_type == NormType::rmsnorm) {
        if (!weights.has_value()) { // in this case, norm_output = input + residual (+ bias)
            if (data_type == DataType::TYPE_FP32) {
#pragma omp parallel for num_threads(std::min(m, numThreads)) if (m >= 2)
                for (int i = 0; i < m; i++) {
                    add_residual_bias_float((float*)norm_output->data() + i * n, (float*)input->data() + i * n,
                                            residual ? (float*)residual + i * n : nullptr, (float*)bias, n);
                }
            } else if (data_type == DataType::TYPE_FP16) {
#pragma omp parallel for num_threads(std::min(m, numThreads)) if (m >= 2)
                for (int i = 0; i < m; i++) {
                    add_residual_bias_fp16((__fp16*)norm_output->data() + i * n, (__fp16*)input->data() + i * n,
                                           residual ? (__fp16*)residual + i * n : nullptr, (__fp16*)bias, n);
                }
            } else {
                throw OpException(OpErrorType::ERROR_UNIMPLEMENTED);
            }
            return LayernormOutput({norm_output, params.before_norm_output});
        }
        if (data_type != DataType::TYPE_FP32 && data_type != DataType::TYPE_FP16) {
            throw OpException(OpErrorType::ERROR_UNIMPLEMENTED);
        }
        // Skip the gamma multiply when gamma is absent or all ones.
        bool gamma_is_ones;
        if (data_type == DataType::TYPE_FP32) {
            gamma_is_ones = !gamma || std::all_of((float*)gamma, (float*)gamma + n, [](float v) { return v == 1.0f; });
        } else {
            gamma_is_ones =
                !gamma || std::all_of((__fp16*)gamma, (__fp16*)gamma + n, [](__fp16 v) { return v == 1.0f; });
        }
        // fp16 accumulates unacceptable error, so fp16 weights and rows are
        // converted to fp32, normalized with the fp32 kernels, and converted back.
        std::vector<float> gamma_storage, beta_storage, bias_storage;
        float* gamma_f = (float*)gamma;
        float* beta_f = (float*)beta;
        float* bias_f = (float*)bias;
        if (data_type == DataType::TYPE_FP16) {
            if (gamma) {
                gamma_storage.resize(n);
                convert_fp16_to_float((__fp16*)gamma, gamma_storage.data(), n);
                gamma_f = gamma_storage.data();
            }
            if (beta) {
                beta_storage.resize(n);
                convert_fp16_to_float((__fp16*)beta, beta_storage.data(), n);
                beta_f = beta_storage.data();
            }
            if (bias) { // bias is a length-n vector shared by all rows
                bias_storage.resize(n);
                convert_fp16_to_float((__fp16*)bias, bias_storage.data(), n);
                bias_f = bias_storage.data();
            }
        }
        float* gamma_arg = gamma_is_ones ? nullptr : gamma_f;
        const bool write_unnormed = before_norm_output && before_norm_output != norm_output->data();

        if (data_type == DataType::TYPE_FP16) {
#pragma omp parallel for num_threads(std::min(m, numThreads)) if (m >= 2)
            for (int i = 0; i < m; i++) {
                // convert the row to fp32, normalize, convert back
                std::vector<float> input_converted(n), output_converted(n), residual_converted, before_converted;
                convert_fp16_to_float((__fp16*)input->data() + i * n, input_converted.data(), n);
                if (residual) {
                    residual_converted.resize(n);
                    convert_fp16_to_float((__fp16*)residual + i * n, residual_converted.data(), n);
                }
                if (write_unnormed) before_converted.resize(n);
                RMSNorm_isoutput(n, write_unnormed ? before_converted.data() : nullptr, input_converted.data(),
                                 output_converted.data(), gamma_arg, beta_f,
                                 residual ? residual_converted.data() : nullptr, bias_f, eps);
                convert_float_to_fp16(output_converted.data(), (__fp16*)norm_output->data() + i * n, n);
                if (write_unnormed)
                    convert_float_to_fp16(before_converted.data(), (__fp16*)before_norm_output + i * n, n);
            }
        } else { // TYPE_FP32
#pragma omp parallel for num_threads(std::min(m, numThreads)) if (m >= 2)
            for (int i = 0; i < m; i++) {
                RMSNorm_isoutput(n, write_unnormed ? (float*)before_norm_output + i * n : nullptr,
                                 (float*)input->data() + i * n, (float*)norm_output->data() + i * n, gamma_arg,
                                 beta_f, residual ? (float*)residual + i * n : nullptr, bias_f, eps);
            }
        }
        return LayernormOutput({norm_output, params.before_norm_output});
    }

    throw OpException(OpErrorType::ERROR_UNIMPLEMENTED);
}

} // namespace rtp_llm
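// A minimal usage sketch (hypothetical: the factory call, the LayernormParams
// construction, and the output field names below are assumptions for
// illustration, not APIs defined in this file):
//
//   auto device = DeviceFactory::getDefaultDevice();
//   LayernormParams params(input /* [m, n] fp32 or fp16 buffer */, /* ... */);
//   LayernormOutput out = static_cast<ArmCpuDevice*>(device)->layernorm(params);
//   // out.output: normalized rows; out.before_norm_output: optional pre-norm sum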