in maga_transformer/cpp/devices/arm_impl/ArmLayerNormOp.cc [144:233]
// RMS-normalizes one row of `n` floats, optionally fusing a residual and a bias add.
//
// With x[d] = input[d] (+ residual[d] if residual) (+ bias[d] if bias):
//   norm_out[d] = x[d] * (1 / sqrt(sum(x^2) / n + eps)) * gamma[d]  (+ beta[d] if beta)
// When `before_norm_output` is non-null and differs from `norm_out`, x[d] is
// also stored to before_norm_output[d].
//
// Parameters:
//   n                  number of elements in the row; n <= 0 is a no-op.
//   before_norm_output optional destination for the un-normalized sum x (nullptr = skip).
//                      It may be the same buffer as `input`, `residual` or `bias`:
//                      the normalization pass reads x back from this buffer instead
//                      of re-adding operands that may already have been overwritten.
//                      If it equals `norm_out` it is not written at all.
//   input              row to normalize.
//   norm_out           destination for the normalized row; may be the same buffer as `input`.
//   gamma              per-element scale; required (not null-checked).
//   beta               optional per-element shift (nullptr = none).
//   residual, bias     optional addends (nullptr = none).
//   eps                added to the mean square before taking the square root.
void RMSNorm_isoutput(int n, float* before_norm_output, float* input, float* norm_out, const float* gamma,
                      const float* beta, float* residual, float* bias, const double eps) {
    if (n <= 0) {
        return;
    }
    // The pre-norm sum is only materialized when it cannot collide with norm_out.
    const bool store_pre_norm = before_norm_output && before_norm_output != norm_out;

    // ---- Pass 1: x = input + residual + bias, accumulate sum of squares ----
    float32x4_t square_sum_v[4];
#pragma unroll
    for (int i = 0; i < 4; ++i) {
        square_sum_v[i] = vdupq_n_f32(0.0f);
    }
    int d = 0;
    for (; d <= n - 16; d += 16) {
        float32x4x4_t regs = vld1q_f32_x4(input + d);
        if (residual) {
            const float32x4x4_t addend = vld1q_f32_x4(residual + d);
#pragma unroll
            for (int i = 0; i < 4; ++i) {
                regs.val[i] = vaddq_f32(regs.val[i], addend.val[i]);
            }
        }
        if (bias) {
            const float32x4x4_t addend = vld1q_f32_x4(bias + d);
#pragma unroll
            for (int i = 0; i < 4; ++i) {
                regs.val[i] = vaddq_f32(regs.val[i], addend.val[i]);
            }
        }
        if (store_pre_norm) {
            vst1q_f32_x4(before_norm_output + d, regs);
        }
#pragma unroll
        for (int i = 0; i < 4; ++i) {
            square_sum_v[i] = vaddq_f32(square_sum_v[i], vmulq_f32(regs.val[i], regs.val[i]));
        }
    }
    // Fold the four accumulators into one vector.
    square_sum_v[0] = vaddq_f32(square_sum_v[0], square_sum_v[1]);
    square_sum_v[2] = vaddq_f32(square_sum_v[2], square_sum_v[3]);
    square_sum_v[0] = vaddq_f32(square_sum_v[0], square_sum_v[2]);

    float32_t square_sum = 0.0f;
    for (; d < n; ++d) {  // scalar tail (n not a multiple of 16)
        float val = input[d];
        if (residual) val += residual[d];
        if (bias) val += bias[d];
        if (store_pre_norm) before_norm_output[d] = val;
        square_sum += val * val;
    }
#pragma unroll
    for (int i = 0; i < 4; ++i) {
        square_sum += square_sum_v[0][i];
    }

    const float mean_square = square_sum / n;
    const float rms = 1.0f / std::sqrt(mean_square + eps);  // inverse RMS
    const float32x4_t rms_v = vdupq_n_f32(rms);

    // ---- Pass 2: normalize ----
    // If x was stored, read it back: the operand buffers may have been
    // overwritten by that store (aliasing), and re-adding them would be wrong.
    const float* src = store_pre_norm ? before_norm_output : input;
    const bool add_operands = !store_pre_norm;

    d = 0;
    for (; d <= n - 16; d += 16) {
        float32x4x4_t x = vld1q_f32_x4(src + d);
        if (add_operands && residual) {
            const float32x4x4_t addend = vld1q_f32_x4(residual + d);
#pragma unroll
            for (int i = 0; i < 4; ++i) {
                x.val[i] = vaddq_f32(x.val[i], addend.val[i]);
            }
        }
        if (add_operands && bias) {
            const float32x4x4_t addend = vld1q_f32_x4(bias + d);
#pragma unroll
            for (int i = 0; i < 4; ++i) {
                x.val[i] = vaddq_f32(x.val[i], addend.val[i]);
            }
        }
        const float32x4x4_t gamma_v = vld1q_f32_x4(gamma + d);
#pragma unroll
        for (int i = 0; i < 4; ++i) {
            x.val[i] = vmulq_f32(x.val[i], rms_v);
            x.val[i] = vmulq_f32(x.val[i], gamma_v.val[i]);
        }
        if (beta) {
            const float32x4x4_t beta_v = vld1q_f32_x4(beta + d);
#pragma unroll
            for (int i = 0; i < 4; ++i) {
                x.val[i] = vaddq_f32(x.val[i], beta_v.val[i]);
            }
        }
        vst1q_f32_x4(norm_out + d, x);
    }
    for (; d < n; ++d) {  // scalar tail
        float val = src[d];
        if (add_operands && residual) val += residual[d];
        if (add_operands && bias) val += bias[d];
        float out = val * rms * gamma[d];
        if (beta) out = out + beta[d];
        norm_out[d] = out;
    }
}