void layerNorm_Nogamma_isoutput()

in maga_transformer/cpp/devices/arm_impl/ArmLayerNormOp.cc [682:780]


/// Gamma-free LayerNorm over one row of `n` floats, with optional pre-norm
/// residual/bias addends and an optional post-norm shift.
///
/// Each pre-norm value is
///     x[d] = input[d] + residual[d] + bias[d]
/// where the residual / bias terms are skipped when their pointer is null. With
///     mean = sum(x) / n,   var = sum(x^2) / n - mean^2   (clamped at 0),
/// the output is
///     norm_out[d] = (x[d] - mean) / sqrt(var + eps) + beta[d]
/// with the beta term skipped when `beta` is null.
///
/// @param n         number of elements in the row; n <= 0 is a no-op.
/// @param input     n input values.
/// @param norm_out  receives n normalized values.
/// @param beta      optional post-normalization shift (nullable).
/// @param residual  optional pre-normalization addend (nullable); only read here.
/// @param bias      optional pre-normalization addend (nullable); only read here.
/// @param eps       added to the variance before taking the square root.
void layerNorm_Nogamma_isoutput(int n, const float* input, float* norm_out,
                                const float* beta, float* residual, float* bias, const double eps) {
    if (n <= 0) {
        return;  // nothing to normalize; also avoids 0/0 in the statistics below
    }

    // x[d..d+15]: input plus whichever optional addends are present. Absent
    // operands are never touched, so no uninitialized register can leak into
    // the result.
    auto load16 = [&](int d) {
        float32x4x4_t x = vld1q_f32_x4(input + d);
        if (residual) {
            const float32x4x4_t r = vld1q_f32_x4(residual + d);
            #pragma unroll
            for (int i = 0; i < 4; ++i) {
                x.val[i] = vaddq_f32(x.val[i], r.val[i]);
            }
        }
        if (bias) {
            const float32x4x4_t b = vld1q_f32_x4(bias + d);
            #pragma unroll
            for (int i = 0; i < 4; ++i) {
                x.val[i] = vaddq_f32(x.val[i], b.val[i]);
            }
        }
        return x;
    };
    // Scalar counterpart of load16 for the tail that does not fill 16 lanes.
    auto load1 = [&](int d) {
        float x = input[d];
        if (residual) x += residual[d];
        if (bias) x += bias[d];
        return x;
    };

    // Pass 1: accumulate sum(x) and sum(x^2), four vector accumulators each.
    float32x4_t sum_v[4];
    float32x4_t square_sum_v[4];
    #pragma unroll
    for (int i = 0; i < 4; ++i) {
        sum_v[i] = vdupq_n_f32(0.0f);
        square_sum_v[i] = vdupq_n_f32(0.0f);
    }
    int d = 0;
    for (; d <= n - 16; d += 16) {
        const float32x4x4_t x = load16(d);
        #pragma unroll
        for (int i = 0; i < 4; ++i) {
            sum_v[i] = vaddq_f32(sum_v[i], x.val[i]);
            square_sum_v[i] = vaddq_f32(square_sum_v[i], vmulq_f32(x.val[i], x.val[i]));
        }
    }
    float sum = 0.0f;
    float square_sum = 0.0f;
    for (; d < n; ++d) {
        const float x = load1(d);
        sum += x;
        square_sum += x * x;
    }
    // Fold the four accumulators pairwise, then add the lanes into the scalar totals.
    sum_v[0] = vaddq_f32(vaddq_f32(sum_v[0], sum_v[1]), vaddq_f32(sum_v[2], sum_v[3]));
    square_sum_v[0] = vaddq_f32(vaddq_f32(square_sum_v[0], square_sum_v[1]),
                                vaddq_f32(square_sum_v[2], square_sum_v[3]));
    #pragma unroll
    for (int i = 0; i < 4; ++i) {
        sum += sum_v[0][i];
        square_sum += square_sum_v[0][i];
    }

    const float mean = sum / n;
    // E[x^2] - E[x]^2 suffers from cancellation and can come out slightly
    // negative for near-constant rows; clamp so sqrt never sees a negative value.
    float variance = square_sum / n - mean * mean;
    if (variance < 0.0f) {
        variance = 0.0f;
    }
    const float inv_std = static_cast<float>(1.0 / std::sqrt(variance + eps));
    const float32x4_t mean_v = vdupq_n_f32(mean);
    const float32x4_t inv_std_v = vdupq_n_f32(inv_std);

    // Pass 2: normalize, then apply the optional beta shift.
    d = 0;
    for (; d <= n - 16; d += 16) {
        float32x4x4_t x = load16(d);
        #pragma unroll
        for (int i = 0; i < 4; ++i) {
            x.val[i] = vmulq_f32(vsubq_f32(x.val[i], mean_v), inv_std_v);
        }
        if (beta) {
            const float32x4x4_t b = vld1q_f32_x4(beta + d);
            #pragma unroll
            for (int i = 0; i < 4; ++i) {
                x.val[i] = vaddq_f32(x.val[i], b.val[i]);
            }
        }
        vst1q_f32_x4(norm_out + d, x);
    }
    for (; d < n; ++d) {
        float out = (load1(d) - mean) * inv_std;
        if (beta) out += beta[d];
        norm_out[d] = out;
    }
}