in maga_transformer/cpp/devices/arm_impl/ArmLayerNormOp.cc [579:680]
// Fused LayerNorm over a single row of n floats:
//   x[d]        = input[d] (+ residual[d] if residual) (+ bias[d] if bias)
//   norm_out[d] = (x[d] - mean) * gamma[d] / sqrt(var + eps) (+ beta[d] if beta)
// Two passes: pass 1 accumulates sum and sum-of-squares (NEON, 16 floats per
// iteration, scalar tail for n % 16); pass 2 normalizes and stores.
// Preconditions: input, norm_out, gamma non-null; residual/bias/beta may be null.
// NOTE(review): input itself is not written; despite the name, the residual-added
// value is not stored anywhere except the normalized output.
void layerNorm_isoutput(int n,const float* input, float* norm_out, const float* gamma,
const float* beta, float* residual , float* bias, const double eps){
    if (n <= 0) {
        return; // nothing to do; also avoids division by zero below
    }
    // ---- pass 1: accumulate sum and sum of squares ----------------------
    float32x4_t sum_v[4];
    float32x4_t square_sum_v[4];
#pragma unroll
    for (int i = 0; i < 4; ++i) {
        sum_v[i]        = vdupq_n_f32(0.0f);
        square_sum_v[i] = vdupq_n_f32(0.0f);
    }
    int d = 0;
    for (; d <= n - 16; d += 16) {
        float32x4x4_t regs = vld1q_f32_x4(input + d);
        if (residual) {
            float32x4x4_t res_v = vld1q_f32_x4(residual + d);
#pragma unroll
            for (int i = 0; i < 4; ++i) {
                regs.val[i] = vaddq_f32(regs.val[i], res_v.val[i]);
            }
        }
        if (bias) {
            float32x4x4_t bias_v = vld1q_f32_x4(bias + d);
#pragma unroll
            for (int i = 0; i < 4; ++i) {
                regs.val[i] = vaddq_f32(regs.val[i], bias_v.val[i]);
            }
        }
#pragma unroll
        for (int i = 0; i < 4; ++i) {
            sum_v[i]        = vaddq_f32(sum_v[i], regs.val[i]);
            // fused multiply-add: square_sum += x * x
            square_sum_v[i] = vfmaq_f32(square_sum_v[i], regs.val[i], regs.val[i]);
        }
    }
    float32_t sum        = 0.0f;
    float32_t square_sum = 0.0f;
    for (; d < n; ++d) { // scalar tail (n % 16 elements)
        float val = input[d];
        if (residual) val += residual[d];
        if (bias)     val += bias[d];
        sum        += val;
        square_sum += val * val;
    }
    // Horizontal reduction of the four vector accumulators into the scalars.
    sum_v[0] = vaddq_f32(vaddq_f32(sum_v[0], sum_v[1]),
                         vaddq_f32(sum_v[2], sum_v[3]));
    sum += vaddvq_f32(sum_v[0]);
    square_sum_v[0] = vaddq_f32(vaddq_f32(square_sum_v[0], square_sum_v[1]),
                                vaddq_f32(square_sum_v[2], square_sum_v[3]));
    square_sum += vaddvq_f32(square_sum_v[0]);

    const float mean = sum / n;
    // Inverse standard deviation via Var(x) = E[x^2] - mean^2.
    const float inv_std =
        1.0f / std::sqrt(square_sum / n - mean * mean + static_cast<float>(eps));
    const float32x4_t mean_v    = vdupq_n_f32(mean);
    const float32x4_t inv_std_v = vdupq_n_f32(inv_std);

    // ---- pass 2: normalize and store ------------------------------------
    d = 0;
    for (; d <= n - 16; d += 16) {
        float32x4x4_t x_v = vld1q_f32_x4(input + d);
        // BUGFIX: the original added uninitialized Residual/Bias vectors when
        // residual/bias were null (UB). Add only when the pointer is present,
        // mirroring pass 1 and the scalar tail.
        if (residual) {
            float32x4x4_t res_v = vld1q_f32_x4(residual + d);
#pragma unroll
            for (int i = 0; i < 4; ++i) {
                x_v.val[i] = vaddq_f32(x_v.val[i], res_v.val[i]);
            }
        }
        if (bias) {
            float32x4x4_t bias_v = vld1q_f32_x4(bias + d);
#pragma unroll
            for (int i = 0; i < 4; ++i) {
                x_v.val[i] = vaddq_f32(x_v.val[i], bias_v.val[i]);
            }
        }
        float32x4x4_t gamma_v = vld1q_f32_x4(gamma + d);
#pragma unroll
        for (int i = 0; i < 4; ++i) {
            x_v.val[i] = vsubq_f32(x_v.val[i], mean_v);
            x_v.val[i] = vmulq_f32(x_v.val[i], gamma_v.val[i]);
            x_v.val[i] = vmulq_f32(x_v.val[i], inv_std_v);
        }
        if (beta) {
            float32x4x4_t beta_v = vld1q_f32_x4(beta + d);
#pragma unroll
            for (int i = 0; i < 4; ++i) {
                x_v.val[i] = vaddq_f32(x_v.val[i], beta_v.val[i]);
            }
        }
        vst1q_f32_x4(norm_out + d, x_v);
    }
    for (; d < n; ++d) { // scalar tail
        float val = input[d];
        if (residual) val += residual[d];
        if (bias)     val += bias[d];
        norm_out[d] = (val - mean) * gamma[d] * inv_std;
        if (beta) norm_out[d] += beta[d];
    }
}