in maga_transformer/cpp/devices/arm_impl/ArmAttentionOp.cc [481:703]
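// Attention for a single batch entry over the strided (fused) QKV layout.
// `past_seq` is this batch's token offset into the flattened input, `seq_len`
// the number of new tokens, and `step` the number of previously cached tokens
// (0 selects the context/prefill path).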
void ArmCpuDevice::runOneBatchStride(const AttentionModuleParams& params, size_t past_seq, int batch, size_t seq_len, size_t step) {
auto datatype = params.input.type();
auto head_num = params.configs.head_num;
auto kv_head_num = params.configs.kv_head_num;
auto size_per_head = params.configs.size_per_head;
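// Per-stage wall-clock timing, reported through logTime(bucket_id).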
std::chrono::steady_clock::time_point tStart, tEnd;
std::chrono::microseconds diff;
if (datatype != DataType::TYPE_FP32) {
throw OpException(OpErrorType::ERROR_UNIMPLEMENTED);
}
// Locate this batch's fused QKV block by stride instead of splitting it into separate q/k/v buffers.
auto qkv = params.input.dataWithOffset(past_seq * (head_num + 2 * kv_head_num) * size_per_head);
printBufferData(params.input, "qkv");
tStart = std::chrono::steady_clock::now();
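// Per-head base pointers into Q, K and V, populated by getPerHeadArray() below.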
void *q_array[head_num], *k_array[head_num], *v_array[head_num];
tEnd = std::chrono::steady_clock::now();
diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart);
logTime(diff, 0);
tStart = std::chrono::steady_clock::now();
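// Optionally add the fused QKV bias in place; the bias may be stored in fp32 or fp16 while activations stay fp32.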
if (params.configs.fuse_qkv_add_bias && params.weights.qkv_weight->bias) {
auto bias_data_type = params.weights.qkv_weight->bias->type();
if (bias_data_type == DataType::TYPE_FP32) {
addQKVBias<float, float>(qkv, params.weights.qkv_weight->bias->data(), 1, seq_len, head_num, kv_head_num, size_per_head);
} else if (bias_data_type == DataType::TYPE_FP16) {
addQKVBias<float, __fp16>(qkv, params.weights.qkv_weight->bias->data(), 1, seq_len, head_num, kv_head_num, size_per_head);
} else {
throw OpException(OpErrorType::ERROR_UNIMPLEMENTED);
}
}
tEnd = std::chrono::steady_clock::now();
diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart);
logTime(diff, 1);
printBufferData(params.input, "biased qkv");
tStart = std::chrono::steady_clock::now();
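// Apply rotary position embedding; only RopeStyle::Base is supported on this path.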
if (params.configs.rope_config.style != RopeStyle::No) {
if (params.configs.rope_config.style == RopeStyle::Base) {
halfRopeQK<float>(qkv, 1, seq_len, head_num, kv_head_num, size_per_head, step,
params.configs.rope_config.base,
params.configs.rope_config.dim);
} else {
throw std::runtime_error("SelfAttention RoPE type is not supported");
}
}
tEnd = std::chrono::steady_clock::now();
diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart);
logTime(diff, 2);
printBufferData(params.input, "roped qkv");
tStart = std::chrono::steady_clock::now();
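// Persist this step's K/V into the KV cache so later steps can attend to it.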
if (params.common.kv_cache.has_value()) {
updateKVCacheStride<float>(params, qkv, batch, seq_len, step);
}
tEnd = std::chrono::steady_clock::now();
diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart);
logTime(diff, 8);
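// Select K/V inputs: context (step == 0) reads K and V directly from the strided
// QKV buffer, while decode gathers the cached history into contiguous buffers.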
BufferPtr k_buffer, v_buffer;
void *k_in, *v_in;
int stride_kv;
if (step == 0) {
/* Context */
k_in = (float *)qkv + head_num * size_per_head;
v_in = (float *)qkv + (head_num + kv_head_num) * size_per_head;
stride_kv = (head_num + 2 * kv_head_num) * size_per_head;
step = seq_len - 1; // Trick to unify context and decoder processes.
} else {
/* Decoder */
k_buffer = allocateBuffer({datatype, {1, step + 1, kv_head_num, size_per_head}, AllocationType::HOST}, {});
v_buffer = allocateBuffer({datatype, {1, step + 1, kv_head_num, size_per_head}, AllocationType::HOST}, {});
assemCache(params, batch, k_buffer, v_buffer);
printBufferData(*k_buffer, "k_buffer");
printBufferData(*v_buffer, "v_buffer");
k_in = k_buffer->data();
v_in = v_buffer->data();
stride_kv = kv_head_num * size_per_head;
}
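// NOTE: bucket 3 currently brackets no work; the K/V selection above is untimed.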
tStart = std::chrono::steady_clock::now();
tEnd = std::chrono::steady_clock::now();
diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart);
logTime(diff, 3);
tStart = std::chrono::steady_clock::now();
/* Re-init the per-head pointer arrays as per [batch, num_heads]. */
getPerHeadArray<float>(qkv, k_in, v_in, q_array, k_array, v_array,
head_num, kv_head_num, size_per_head);
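// Q*K^T per head: [seq_len, size_per_head] x [step + 1, size_per_head]^T ->
// [seq_len, step + 1]; the row strides skip over the other heads in the fused layout.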
auto qk_output = allocateBuffer({datatype, {1, head_num, seq_len, step + 1}, AllocationType::HOST}, {"qk_output"});
const int stride_q = (head_num + 2 * kv_head_num) * size_per_head;
const int MHA_HEADS = 1 * head_num;
parallel_for(MHA_HEADS, [&](int tid) {
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, seq_len, step + 1, size_per_head, 1.0,
(const float*)q_array[tid], stride_q,
(const float*)k_array[tid], stride_kv, 0.0,
(float*)qk_output->dataWithOffset(tid * seq_len * (step + 1)), step + 1);
});
printBufferData(*qk_output, "qk_output");
tEnd = std::chrono::steady_clock::now();
diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart);
logTime(diff, 4);
tStart = std::chrono::steady_clock::now();
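// Scaled dot-product attention: scores are scaled by 1/sqrt(size_per_head) inside softmax.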
const float scale = 1.0f / sqrtf(static_cast<float>(size_per_head));
BufferPtr softmax_qk_output;
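// Decode attends over the full history without a mask; context applies the attention mask within softmax.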
if (seq_len == 1) {
/* Decoder */
softmax_qk_output = softmax({qk_output, std::nullopt, std::nullopt, scale});
} else {
/* Context */
RUNTIME_ASSERT_OP_ARG(params.common.attention_mask,
"attention_mask must be provided for default context attention implementation");
auto attention_mask = (*params.common.attention_mask).view(batch, 1);
printBufferData(attention_mask, "attention_mask");
softmax_qk_output = softmax({qk_output, attention_mask, std::nullopt, scale});
}
tEnd = std::chrono::steady_clock::now();
diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart);
logTime(diff, 5);
printBufferData(*softmax_qk_output, "softmax_qk_output");
tStart = std::chrono::steady_clock::now();
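// Attention output: probs @ V. Use cblas everywhere unless KleidiAI is enabled,
// in which case the context path switches to bf16-packed KleidiAI kernels.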
if (!isKAIenabled) {
parallel_for(MHA_HEADS, [&](int tid) {
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, seq_len, size_per_head, step + 1, 1.0,
(const float*)softmax_qk_output->dataWithOffset(tid * seq_len * (step + 1)), step + 1,
(const float*)v_array[tid], stride_kv, 0.0,
(float*)params.output.dataWithOffset(past_seq * head_num * size_per_head + tid * size_per_head), head_num * size_per_head);
});
} else {
if (seq_len == 1) {
/* Decode (seq_len == 1) gets better GEMM performance from cblas. */
parallel_for(MHA_HEADS, [&](int tid) {
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, seq_len, size_per_head, step + 1, 1.0,
(const float*)softmax_qk_output->dataWithOffset(tid * seq_len * (step + 1)), step + 1,
(const float*)v_array[tid], stride_kv, 0.0,
(float*)params.output.dataWithOffset(past_seq * head_num * size_per_head + tid * size_per_head), head_num * size_per_head);
});
} else {
/* Context gets better GEMM performance from the KleidiAI bf16 kernels. */
const size_t bias_size = size_per_head;
// The RHS pack kernel requires a bias vector; attention has none, so pass zeros.
float* bias = new float[bias_size]();
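// GEMM shape per head: M x K probs (seq_len x (step + 1)) times K x N values
// ((step + 1) x size_per_head). KleidiAI packs the fp32 operands into bf16 tiles
// (mr/nr/kr/sr are the kernel's blocking factors) before running the MMLA matmul.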
const size_t M = seq_len;
const size_t N = size_per_head;
const size_t K = step + 1;
const size_t mr = kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla();
const size_t nr = kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla();
const size_t kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla();
const size_t sr = kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla();
const size_t lhs_packed_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p8x4_f32_neon(M, K, mr, kr, sr);
const size_t rhs_packed_size = kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon(N, K, nr, kr);
const size_t lhs_stride = K * sizeof(float);
const size_t rhs_stride = stride_kv * sizeof(float);
const size_t dst_stride_row = (head_num * size_per_head) * sizeof(float);
const size_t dst_stride_col = sizeof(float);
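// Each head packs its own LHS (probs) and RHS (V) tiles into thread-local scratch.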
parallel_for(MHA_HEADS, [&](int tid) {
uint8_t *lhs_packed = new uint8_t[lhs_packed_size];
uint8_t *rhs_packed = new uint8_t[rhs_packed_size];
kai_run_lhs_quant_pack_bf16p8x4_f32_neon(M, K, mr, kr, sr, 0,
(const void*)softmax_qk_output->dataWithOffset(tid * M * K), lhs_stride, lhs_packed);
kai_run_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon(1,
N,
K,
nr,
kr,
sr,
rhs_stride,
(const void*)v_array[tid],
bias,
NULL,
rhs_packed,
0,
NULL);
kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla(M, N, K,
lhs_packed,
rhs_packed,
(void*)params.output.dataWithOffset(past_seq * head_num * size_per_head + tid * size_per_head), dst_stride_row, dst_stride_col,
-FLT_MAX, FLT_MAX);
delete[] rhs_packed;
delete[] lhs_packed;
});
delete[] bias;
}
}
tEnd = std::chrono::steady_clock::now();
diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart);
logTime(diff, 6);
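// NOTE: bucket 7 likewise brackets no work.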
tStart = std::chrono::steady_clock::now();
tEnd = std::chrono::steady_clock::now();
diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart);
logTime(diff, 7);
/* Print profile data at the end of operator unit test. */
// if (a_cnt_[0] == 24)
// printStat();
}