in maga_transformer/cpp/devices/arm_impl/ArmAttentionOp.cc [189:406]
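/*
 * Computes self-attention for one batch entry on the CPU via ACL:
 * add QKV bias -> apply RoPE -> split fused QKV -> update/read KV cache ->
 * softmax(Q.K^T / sqrt(size_per_head)) . V -> write back to params.output.
 * `past_seq` is this batch's token offset into the packed input/output
 * buffers; `step` is 0 for the context pass and the decode position otherwise.
 */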
void ArmCpuDevice::runOneBatch(const AttentionModuleParams& params, size_t past_seq, int batch, size_t seq_len, size_t step) {
    auto datatype = params.input.type();
    auto head_num = params.configs.head_num;
    auto kv_head_num = params.configs.kv_head_num;
    auto size_per_head = params.configs.size_per_head;
    // The KV cache may legitimately be absent here; every cache access below is
    // guarded by params.common.kv_cache.has_value().
    std::chrono::steady_clock::time_point tStart, tEnd;
    std::chrono::microseconds diff;

    // Locate this batch's fused QKV activations in the packed input buffer;
    // they are split into separate Q/K/V tensors below.
    auto qkv = params.input.dataWithOffset(past_seq * (head_num + 2 * kv_head_num) * size_per_head);
    printBufferData(params.input, "qkv");

    tStart = std::chrono::steady_clock::now();
    if (params.configs.fuse_qkv_add_bias && params.weights.qkv_weight->bias) {
        if (datatype == DataType::TYPE_FP32) {
            // FP32 activations accept either an FP32 or an FP16 bias.
            auto bias_data_type = params.weights.qkv_weight->bias->type();
            if (bias_data_type == DataType::TYPE_FP32) {
                addQKVBias<float, float>(qkv, params.weights.qkv_weight->bias->data(), 1, seq_len, head_num, kv_head_num, size_per_head);
            } else if (bias_data_type == DataType::TYPE_FP16) {
                addQKVBias<float, __fp16>(qkv, params.weights.qkv_weight->bias->data(), 1, seq_len, head_num, kv_head_num, size_per_head);
            } else {
                throw OpException(OpErrorType::ERROR_UNIMPLEMENTED);
            }
        } else if (datatype == DataType::TYPE_FP16) {
            // FP16 activations assume an FP16 bias.
            addQKVBias<__fp16, __fp16>(qkv, params.weights.qkv_weight->bias->data(), 1, seq_len, head_num, kv_head_num, size_per_head);
        } else {
            throw OpException(OpErrorType::ERROR_UNIMPLEMENTED);
        }
    }
    tEnd = std::chrono::steady_clock::now();
    diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart);
    logTime(diff, 1);
    printBufferData(params.input, "biased qkv");
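    // Apply rotary position embedding to Q and K in place. Only RopeStyle::Base is
    // implemented here; `step` supplies the position offset for decode tokens.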
    tStart = std::chrono::steady_clock::now();
    if (params.configs.rope_config.style != RopeStyle::No) {
        if (params.configs.rope_config.style == RopeStyle::Base) {
            if (datatype == DataType::TYPE_FP32) {
                halfRopeQK<float>(qkv, 1, seq_len, head_num, kv_head_num, size_per_head, step,
                                  params.configs.rope_config.base, params.configs.rope_config.dim);
            } else if (datatype == DataType::TYPE_FP16) {
                halfRopeQK<__fp16>(qkv, 1, seq_len, head_num, kv_head_num, size_per_head, step,
                                   params.configs.rope_config.base, params.configs.rope_config.dim);
            } else {
                throw OpException(OpErrorType::ERROR_UNIMPLEMENTED);
            }
        } else {
            throw std::runtime_error("SelfAttention RoPE type is not supported");
        }
    }
    tEnd = std::chrono::steady_clock::now();
    diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart);
    logTime(diff, 2);
    printBufferData(params.input, "roped qkv");
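    // Split the fused QKV buffer into separate Q/K/V tensors with ACL's NESplit.
    // ACL TensorShape lists dimensions innermost-first, so axis 1 below is the
    // combined (head_num + 2 * kv_head_num) head dimension.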
    tStart = std::chrono::steady_clock::now();
    arm_compute::DataType acl_data_type = getAclDataType(datatype);
    arm_compute::NESplit split;
    arm_compute::Tensor src;
    arm_compute::TensorInfo src_info = arm_compute::TensorInfo(
        arm_compute::TensorShape(size_per_head, head_num + 2 * kv_head_num, seq_len, 1), 1, acl_data_type);
    src.allocator()->init(src_info);
    src.allocator()->import_memory(qkv);

    std::vector<arm_compute::Tensor> dsts{};
    std::vector<arm_compute::ITensor*> dsts_ptr;
    arm_compute::TensorInfo q_info, kv_info;
    arm_compute::Tensor q, k, v;
    q_info = arm_compute::TensorInfo(
        arm_compute::TensorShape(size_per_head, head_num, seq_len, 1), 1, acl_data_type);
    kv_info = arm_compute::TensorInfo(
        arm_compute::TensorShape(size_per_head, kv_head_num, seq_len, 1), 1, acl_data_type);

    auto q_input = allocateBuffer({datatype, {1, seq_len, head_num, size_per_head}, AllocationType::HOST}, {});
    auto k_input = allocateBuffer({datatype, {1, seq_len, kv_head_num, size_per_head}, AllocationType::HOST}, {});
    auto v_input = allocateBuffer({datatype, {1, seq_len, kv_head_num, size_per_head}, AllocationType::HOST}, {});

    q.allocator()->init(q_info);
    k.allocator()->init(kv_info);
    v.allocator()->init(kv_info);
    q.allocator()->import_memory(q_input->data());
    dsts.push_back(std::move(q));
    k.allocator()->import_memory(k_input->data());
    dsts.push_back(std::move(k));
    v.allocator()->import_memory(v_input->data());
    dsts.push_back(std::move(v));
    for (auto& dst : dsts) {
        dsts_ptr.emplace_back(&dst);
    }
    split.configure(&src, dsts_ptr, 1);
    split.run();
    tEnd = std::chrono::steady_clock::now();
    diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart);
    logTime(diff, 0);
    printBufferData(*q_input, "q_input");
    printBufferData(*k_input, "k_input");
    printBufferData(*v_input, "v_input");
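    // Append this chunk's K/V to the cache when one is attached.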
    tStart = std::chrono::steady_clock::now();
    if (params.common.kv_cache.has_value()) {
        updateKVCache(params, batch, step, k_input, v_input);
    }
    tEnd = std::chrono::steady_clock::now();
    diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart);
    logTime(diff, 8);
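    // step == 0 is the context pass (attend over this chunk only); step > 0 is a
    // decode pass that gathers the full step + 1 token K/V history from the cache.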
    BufferPtr k_in, k_out, v_in, v_out;
    if (step == 0) {
        /* Context. */
        k_in = k_input;
        v_in = v_input;
        k_out = allocateBuffer({datatype, {1, kv_head_num, seq_len, size_per_head}, AllocationType::HOST}, {});
        v_out = allocateBuffer({datatype, {1, kv_head_num, seq_len, size_per_head}, AllocationType::HOST}, {});
    } else {
        /* Decoder. Retrieve k/v cache data. */
        k_in = allocateBuffer({datatype, {1, step + 1, kv_head_num, size_per_head}, AllocationType::HOST}, {});
        v_in = allocateBuffer({datatype, {1, step + 1, kv_head_num, size_per_head}, AllocationType::HOST}, {});
        assemCache(params, batch, k_in, v_in);
        printBufferData(*k_in, "k_in");
        printBufferData(*v_in, "v_in");
        k_out = allocateBuffer({datatype, {1, kv_head_num, step + 1, size_per_head}, AllocationType::HOST}, {});
        v_out = allocateBuffer({datatype, {1, kv_head_num, step + 1, size_per_head}, AllocationType::HOST}, {});
    }
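    // Transpose [1, tokens, heads, size_per_head] -> [1, heads, tokens, size_per_head]
    // so each head is a contiguous matrix for the batched GEMMs below.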
    tStart = std::chrono::steady_clock::now();
    auto q_output = allocateBuffer({datatype, {1, head_num, seq_len, size_per_head}, AllocationType::HOST}, {});
    transposeDim12(std::move(q_input), q_output->data());
    transposeDim12(std::move(k_in), k_out->data());
    transposeDim12(std::move(v_in), v_out->data());
    tEnd = std::chrono::steady_clock::now();
    diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart);
    logTime(diff, 3);
    printBufferData(*q_output, "q_output");
    printBufferData(*k_out, "k_out");
    printBufferData(*v_out, "v_out");
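    // Grouped-query attention: head_num / kv_head_num query heads share each KV head,
    // so tile the K/V head rows up to head_num before the GEMMs.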
    if (kv_head_num != head_num) {
        /* repeat K/V */
        size_t len = (step == 0) ? seq_len : step + 1;
        auto k_repeat = allocateBuffer({datatype, {1, head_num, len, size_per_head}, AllocationType::HOST}, {});
        auto v_repeat = allocateBuffer({datatype, {1, head_num, len, size_per_head}, AllocationType::HOST}, {});
        auto n_rep = head_num / kv_head_num;
        const int N = kv_head_num;
        parallel_for(N, [&](int tid) {
            for (size_t i = 0; i < n_rep; i++) {
                memcpy(k_repeat->dataWithOffset((tid * n_rep + i) * len * size_per_head),
                       k_out->dataWithOffset(tid * len * size_per_head),
                       k_out->sizeBytes() / kv_head_num);
                memcpy(v_repeat->dataWithOffset((tid * n_rep + i) * len * size_per_head),
                       v_out->dataWithOffset(tid * len * size_per_head),
                       v_out->sizeBytes() / kv_head_num);
            }
        });
        k_out = std::move(k_repeat);
        v_out = std::move(v_repeat);
    }
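    // Attention scores: Q . K^T via ACL GEMM, with K transposed.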
    tStart = std::chrono::steady_clock::now();
    auto qk_output = gemm_acl({*q_output,
                               *k_out,
                               std::nullopt,
                               nullptr,
                               DataType::TYPE_INVALID,
                               TransposeOperation::NONE,
                               TransposeOperation::TRANSPOSE});
    tEnd = std::chrono::steady_clock::now();
    diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart);
    logTime(diff, 4);
    printBufferData(*qk_output, "qk_output");
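    // Scale scores by 1 / sqrt(size_per_head) and softmax; the context pass also
    // applies this batch's attention mask.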
    tStart = std::chrono::steady_clock::now();
    float scale = 1.0f / sqrtf(static_cast<float>(size_per_head));
    BufferPtr softmax_qk_output;
    if (seq_len == 1) {
        /* Decoder */
        softmax_qk_output = softmax({qk_output, std::nullopt, std::nullopt, scale});
    } else {
        /* Context */
        auto attention_mask = (*params.common.attention_mask).view(batch, 1);
        softmax_qk_output = softmax({qk_output, attention_mask, std::nullopt, scale});
    }
    tEnd = std::chrono::steady_clock::now();
    diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart);
    logTime(diff, 5);
    printBufferData(*softmax_qk_output, "softmax_qk_output");
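    // Weighted sum of values: softmax(Q.K^T) . V.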
    tStart = std::chrono::steady_clock::now();
    auto qkv_output = gemm_acl({*softmax_qk_output, *v_out});
    tEnd = std::chrono::steady_clock::now();
    diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart);
    logTime(diff, 6);
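    // Transpose back to [1, tokens, heads, size_per_head] and write the result into
    // the packed output at this batch's token offset.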
    tStart = std::chrono::steady_clock::now();
    transposeDim12(qkv_output, params.output.dataWithOffset(past_seq * head_num * size_per_head));
    tEnd = std::chrono::steady_clock::now();
    diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart);
    logTime(diff, 7);
}