void RMSNorm_Nogamma_isoutput()

in maga_transformer/cpp/devices/arm_impl/ArmLayerNormOp.cc [235:316]


// RMSNorm without gamma scaling: optionally fuses a residual add and a bias add,
// optionally writes the pre-normalization sum to before_norm_output, then writes
// the normalized (and beta-shifted, if beta is given) result to norm_out.
void RMSNorm_Nogamma_isoutput(int n, float* before_norm_output, const float* input, float* norm_out,
               const float* beta, float* residual, float* bias, const double eps){
    // Four partial sum-of-squares accumulators, one per 128-bit NEON register.
    float32x4_t square_sum_v[4];
    #pragma unroll
    for (int i = 0; i < 4; ++i) {
        square_sum_v[i] = vdupq_n_f32(0.0f);
    }

    // Vectorized pass: accumulate the sum of squares of (input + residual + bias),
    // 16 floats per iteration.
    int d = 0;
    for (; d <= n - 16; d += 16) {
        float32x4x4_t regs = vld1q_f32_x4(input + d);
        float32x4x4_t regs_residual_bias;
        if (residual) {
            regs_residual_bias = vld1q_f32_x4(residual + d);
            #pragma unroll
            for (int i = 0; i < 4; ++i) {
                regs.val[i] = vaddq_f32(regs.val[i], regs_residual_bias.val[i]);
            }
        }
        if (bias) {
            regs_residual_bias = vld1q_f32_x4(bias + d);
            #pragma unroll
            for (int i = 0; i < 4; ++i) {
                regs.val[i] = vaddq_f32(regs.val[i], regs_residual_bias.val[i]);
            }
        }
        // Optionally expose the pre-normalization values.
        if (before_norm_output) vst1q_f32_x4(before_norm_output + d, regs);
        #pragma unroll
        for (int i = 0; i < 4; ++i) {
            square_sum_v[i] = vaddq_f32(square_sum_v[i], vmulq_f32(regs.val[i], regs.val[i]));
        }
    }
    // Scalar tail for the remaining (n % 16) elements.
    float32_t square_sum = 0.0f;
    for (; d < n; ++d) {
        float val = input[d];
        if (residual) val += residual[d];
        if (bias)     val += bias[d];
        if (before_norm_output) before_norm_output[d] = val;
        square_sum += val * val;
    }

    // Reduce the four vector accumulators into the scalar sum.
    square_sum_v[0] = vaddq_f32(square_sum_v[0], square_sum_v[1]);
    square_sum_v[2] = vaddq_f32(square_sum_v[2], square_sum_v[3]);
    square_sum_v[0] = vaddq_f32(square_sum_v[0], square_sum_v[2]);
    #pragma unroll
    for (int i = 0; i < 4; ++i) {
        square_sum += square_sum_v[0][i];
    }

    // Inverse RMS: 1 / sqrt(mean of squares + eps), broadcast to a vector.
    float rms = square_sum / n;
    rms = 1.0f / std::sqrt(rms + eps);
    float32x4_t rms_v = vdupq_n_f32(rms);

    // Normalization pass: recompute (input + residual + bias), scale by the
    // inverse RMS and, if given, add beta.
    d = 0;
    float32x4x4_t input_v;
    float32x4x4_t beta_v;
    for (; d <= n - 16; d += 16) {
        float32x4x4_t Residual;
        float32x4x4_t Bias;
        input_v = vld1q_f32_x4(input + d);
        if (residual) Residual = vld1q_f32_x4(residual + d);
        if (bias) Bias = vld1q_f32_x4(bias + d);
        if (beta) beta_v = vld1q_f32_x4(beta + d);
        #pragma unroll
        for (int i = 0; i < 4; ++i) {
            if (residual) input_v.val[i] = vaddq_f32(input_v.val[i], Residual.val[i]);
            if (bias) input_v.val[i] = vaddq_f32(input_v.val[i], Bias.val[i]);
            input_v.val[i] = vmulq_f32(input_v.val[i], rms_v);
            if (beta) input_v.val[i] = vaddq_f32(input_v.val[i], beta_v.val[i]);
        }
        vst1q_f32_x4(norm_out + d, input_v);
    }
    // Scalar tail of the normalization pass.
    float input_residual_bias;
    for (; d < n; ++d) {
        input_residual_bias = input[d];
        if (residual) input_residual_bias += residual[d];
        if (bias) input_residual_bias += bias[d];
        norm_out[d] = input_residual_bias * rms;
        if (beta) norm_out[d] += beta[d];
    }
}
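
For reference, the arithmetic the NEON kernel implements can be written as a plain scalar loop. The sketch below is hypothetical (the name rmsnorm_nogamma_reference does not exist in the repository); it only mirrors the optional residual/bias/beta paths of the function above and may be useful as a correctness baseline.

#include <cmath>

// Scalar reference: out[d] = (input[d] + residual[d] + bias[d]) * inv_rms + beta[d],
// with inv_rms = 1 / sqrt(mean of squares + eps). Null pointers skip the
// corresponding term, matching the optional paths of the vectorized kernel.
void rmsnorm_nogamma_reference(int n, float* before_norm_output, const float* input,
                               float* norm_out, const float* beta, const float* residual,
                               const float* bias, double eps) {
    float square_sum = 0.0f;
    for (int d = 0; d < n; ++d) {
        float val = input[d];
        if (residual) val += residual[d];
        if (bias)     val += bias[d];
        if (before_norm_output) before_norm_output[d] = val;
        square_sum += val * val;
    }
    const float inv_rms = 1.0f / std::sqrt(square_sum / n + static_cast<float>(eps));
    for (int d = 0; d < n; ++d) {
        float val = input[d];
        if (residual) val += residual[d];
        if (bias)     val += bias[d];
        norm_out[d] = val * inv_rms + (beta ? beta[d] : 0.0f);
    }
}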