void RMSNorm_isoutput()

in maga_transformer/cpp/devices/arm_impl/ArmLayerNormOp.cc [144:233]


/**
 * @brief RMS-normalizes one row of `n` floats with an optional fused residual/bias add (NEON).
 *
 * For every element d in [0, n):
 *   x[d]        = input[d] (+ residual[d]) (+ bias[d])
 *   norm_out[d] = x[d] * (1 / sqrt(mean(x^2) + eps)) * gamma[d] (+ beta[d])
 *
 * The row is processed 16 floats at a time with NEON; the remaining (n % 16) elements
 * are handled by scalar tail loops.
 *
 * @param n                  Number of elements in the row. Nothing is written when n <= 0.
 * @param before_norm_output Optional. When non-null and different from `norm_out`, receives x
 *                           (the pre-normalization sum). It may alias `input`, `residual` or
 *                           `bias`: once x has been stored the normalization pass reads it back
 *                           instead of re-adding residual/bias, so nothing is applied twice.
 * @param input              Row to normalize (read only by this function unless aliased by an output).
 * @param norm_out           Destination of the normalized row; may be the same buffer as `input`.
 * @param gamma              Per-element scale; dereferenced unconditionally, must not be null.
 * @param beta               Optional per-element shift added after scaling; may be null.
 * @param residual           Optional per-element term added to `input` before normalizing; may be null.
 * @param bias               Optional per-element term added to `input` before normalizing; may be null.
 * @param eps                Added to the mean square before the square root.
 */
void RMSNorm_isoutput(int n, float* before_norm_output, float* input, float* norm_out, const float* gamma,
               const float* beta, float* residual, float* bias, const double  eps){
    if (n <= 0) {
        return;  // empty row: nothing to write, and avoids the 0/0 mean below
    }

    // The pre-norm sum is only materialized when the caller asked for it in a buffer that
    // will not immediately be overwritten by the normalized result.
    const bool store_pre_norm = before_norm_output && before_norm_output != norm_out;

    // ---- Pass 1: x = input (+ residual) (+ bias); accumulate sum(x^2) -------------------
    float32x4_t square_sum_v[4];
    #pragma unroll
    for (int i = 0; i < 4; ++i) {
        square_sum_v[i] = vdupq_n_f32(0.0f);
    }

    int d = 0;
    for (; d <= n - 16; d += 16) {
        float32x4x4_t regs = vld1q_f32_x4(input + d);
        if (residual) {
            const float32x4x4_t residual_v = vld1q_f32_x4(residual + d);
            #pragma unroll
            for (int i = 0; i < 4; ++i) {
                regs.val[i] = vaddq_f32(regs.val[i], residual_v.val[i]);
            }
        }
        if (bias) {
            const float32x4x4_t bias_v = vld1q_f32_x4(bias + d);
            #pragma unroll
            for (int i = 0; i < 4; ++i) {
                regs.val[i] = vaddq_f32(regs.val[i], bias_v.val[i]);
            }
        }
        if (store_pre_norm) {
            vst1q_f32_x4(before_norm_output + d, regs);
        }
        #pragma unroll
        for (int i = 0; i < 4; ++i) {
            square_sum_v[i] = vaddq_f32(square_sum_v[i], vmulq_f32(regs.val[i], regs.val[i]));
        }
    }

    // Fold the four vector accumulators into one.
    square_sum_v[0] = vaddq_f32(square_sum_v[0], square_sum_v[1]);
    square_sum_v[2] = vaddq_f32(square_sum_v[2], square_sum_v[3]);
    square_sum_v[0] = vaddq_f32(square_sum_v[0], square_sum_v[2]);

    // Scalar tail of pass 1 (fewer than 16 elements left).
    float32_t square_sum = 0.0f;
    for (; d < n; ++d) {
        float val = input[d];
        if (residual) val += residual[d];
        if (bias)     val += bias[d];
        if (store_pre_norm) before_norm_output[d] = val;
        square_sum += val * val;
    }

    // Horizontal add of the vector accumulator, lane by lane.
    square_sum += vgetq_lane_f32(square_sum_v[0], 0);
    square_sum += vgetq_lane_f32(square_sum_v[0], 1);
    square_sum += vgetq_lane_f32(square_sum_v[0], 2);
    square_sum += vgetq_lane_f32(square_sum_v[0], 3);

    // Reciprocal RMS: 1 / sqrt(mean(x^2) + eps). `eps` is double, so the sum and the
    // square root are evaluated in double before narrowing back to float.
    float rms = square_sum / n;
    rms = 1.0f / std::sqrt(rms + eps);
    const float32x4_t rms_v = vdupq_n_f32(rms);

    // ---- Pass 2: norm_out = x * rms * gamma (+ beta) --------------------------------------
    // If x was stored in pass 1, read it back and skip the adds: the stored values are exactly
    // what the adds would recompute, and re-adding would apply residual/bias twice whenever
    // before_norm_output shares memory with input, residual or bias.
    const float* x_src            = store_pre_norm ? before_norm_output : input;
    const float* pending_residual = store_pre_norm ? nullptr : residual;
    const float* pending_bias     = store_pre_norm ? nullptr : bias;

    d = 0;
    for (; d <= n - 16; d += 16) {
        float32x4x4_t x_v = vld1q_f32_x4(x_src + d);
        if (pending_residual) {
            const float32x4x4_t residual_v = vld1q_f32_x4(pending_residual + d);
            #pragma unroll
            for (int i = 0; i < 4; ++i) {
                x_v.val[i] = vaddq_f32(x_v.val[i], residual_v.val[i]);
            }
        }
        if (pending_bias) {
            const float32x4x4_t bias_v = vld1q_f32_x4(pending_bias + d);
            #pragma unroll
            for (int i = 0; i < 4; ++i) {
                x_v.val[i] = vaddq_f32(x_v.val[i], bias_v.val[i]);
            }
        }
        const float32x4x4_t gamma_v = vld1q_f32_x4(gamma + d);
        #pragma unroll
        for (int i = 0; i < 4; ++i) {
            x_v.val[i] = vmulq_f32(x_v.val[i], rms_v);
            x_v.val[i] = vmulq_f32(x_v.val[i], gamma_v.val[i]);
        }
        if (beta) {
            const float32x4x4_t beta_v = vld1q_f32_x4(beta + d);
            #pragma unroll
            for (int i = 0; i < 4; ++i) {
                x_v.val[i] = vaddq_f32(x_v.val[i], beta_v.val[i]);
            }
        }
        vst1q_f32_x4(norm_out + d, x_v);
    }

    // Scalar tail of pass 2.
    for (; d < n; ++d) {
        float x = x_src[d];
        if (pending_residual) x += pending_residual[d];
        if (pending_bias)     x += pending_bias[d];
        norm_out[d] = x * rms * gamma[d];
        if (beta) norm_out[d] = norm_out[d] + beta[d];
    }
}