// maga_transformer/cpp/devices/arm_impl/ArmAttentionOp.cc
#include "maga_transformer/cpp/devices/arm_impl/ArmDevice.h"
#include "maga_transformer/cpp/devices/DeviceFactory.h"
#include "maga_transformer/cpp/core/allocator.h"
#include "maga_transformer/cpp/core/cpu_allocator.h"
#include "maga_transformer/cpp/devices/utils/DebugUtils.h"
#include <openblas/cblas.h>
#include <cfloat>
#include "kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla.h"
#include "kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p_interface.h"
#include "kai/ukernels/matmul/pack/kai_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon.h"
#include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p8x4_f32_neon.h"
namespace rtp_llm {
/* Transposes dims 1 and 2: input has shape [dim0, dim1, dim2, dim3], output has shape [dim0, dim2, dim1, dim3]. */
void transposeDim12(BufferPtr input, void* output) {
auto dim = input->shape();
auto elem_sz = input->typeSize();
for (int k = 0; k < dim[0]; k++) {
for (int i = 0; i < dim[2]; i++) {
for (int j = 0; j < dim[1]; j++) {
memcpy((char*)output + elem_sz * (k * dim[1] * dim[2] * dim[3] + (i * dim[1] + j) * dim[3]),
input->dataWithOffset(k * dim[1] * dim[2] * dim[3] + (j * dim[2] + i) * dim[3]),
elem_sz * dim[3]);
}
}
}
}
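/* Resolves the k/v cache block base addresses for a given batch and block index,
 * using the paged KV-cache block-id table in kv_cache.kv_cache_block_id. */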
void getCacheAddrFromIndex(const KvCacheInfo& kv_cache, size_t batch, size_t block_idx, void **k_addr, void **v_addr) {
const auto& kv_blocks_offset = *(kv_cache.kv_cache_block_id);
const auto& k_cache = *(kv_cache.k_cache_buffer);
const auto& v_cache = *(kv_cache.v_cache_buffer);
const auto max_blocks_per_batch = kv_blocks_offset.shape()[1];
size_t block_size = k_cache[0].sizeBytes();
int *index = (int *)kv_blocks_offset.data();
*k_addr = (char*)k_cache.data() + index[batch * max_blocks_per_batch + block_idx] * block_size;
*v_addr = (char*)v_cache.data() + index[batch * max_blocks_per_batch + block_idx] * block_size;
}
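/* Gathers the paged KV cache blocks of one batch into contiguous k/v buffers
 * (k_out/v_out have shape [1, kv_seq_len, kv_head_num, head_dim]). */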
void assemCache(const AttentionModuleParams& params, int batch, BufferPtr k_out, BufferPtr v_out) {
auto elem_sz = k_out->typeSize();
auto kv_seq_len = k_out->shape()[1];
auto head_num = k_out->shape()[2];
auto head_dim = k_out->shape()[3];
auto tokens_per_block = params.configs.tokens_per_block;
size_t blocks_per_batch = (kv_seq_len + tokens_per_block - 1) / tokens_per_block;
size_t copied_len = 0;
void *k_block_addr;
void *v_block_addr;
for (int i = 0; i < blocks_per_batch; i++) {
size_t len = std::min(tokens_per_block, kv_seq_len - copied_len);
getCacheAddrFromIndex(params.common.kv_cache.value(), batch, i, &k_block_addr, &v_block_addr);
memcpy(k_out->dataWithOffset(i * tokens_per_block * head_num * head_dim),
k_block_addr,
elem_sz * len * head_num * head_dim);
memcpy(v_out->dataWithOffset(i * tokens_per_block * head_num * head_dim),
v_block_addr,
elem_sz * len * head_num * head_dim);
copied_len += len;
}
}
/* Input 'qkv' is the fused q/k/v tensor: q has shape [batch, seq_len, num_heads, head_dim],
 * k and v each have shape [batch, seq_len, kv_num_heads, head_dim].
 * 'bias' has shape [(num_heads + 2 * kv_num_heads) * head_dim] and is added across the full fused width.
 */
template<typename Ti, typename Tb>
void addQKVBias(void* qkv, const void* bias,
int batch_sz, int seq_len, int num_heads, int kv_num_heads, int head_size) {
const int N = batch_sz * seq_len;
parallel_for(N, [&](int tid) {
Ti* qkv_input = (Ti*)qkv + tid * (num_heads + 2 * kv_num_heads) * head_size;
for (int i = 0; i < (num_heads + 2 * kv_num_heads) * head_size; i++) {
qkv_input[i] += ((Tb*)bias)[i];
}
});
}
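/* Writes one batch's freshly computed k/v (shape [1, seq_len, kv_head_num, head_dim])
 * into the paged KV cache, starting at token position 'step'. */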
void updateKVCache(const AttentionModuleParams& params, int batch, size_t step, BufferPtr k, BufferPtr v) {
size_t seq_len = k->shape()[1];
auto kv_head_num = params.configs.kv_head_num;
auto size_per_head = params.configs.size_per_head;
auto block_tokens = params.configs.tokens_per_block;
size_t block_num = (seq_len + block_tokens - 1) / block_tokens;
size_t block_offset = step / block_tokens;
auto elem_sz = params.input.typeSize();
size_t copied_len = 0;
void *k_block_addr;
void *v_block_addr;
for (int i = 0; i < block_num; i++) {
size_t len = std::min(block_tokens, seq_len - copied_len);
getCacheAddrFromIndex(params.common.kv_cache.value(), batch, i + block_offset, &k_block_addr, &v_block_addr);
memcpy((uint8_t*)k_block_addr + (step % block_tokens) * kv_head_num * size_per_head * elem_sz,
k->dataWithOffset(i * block_tokens * kv_head_num * size_per_head),
elem_sz * len * kv_head_num * size_per_head);
memcpy((uint8_t*)v_block_addr + (step % block_tokens) * kv_head_num * size_per_head * elem_sz,
v->dataWithOffset(i * block_tokens * kv_head_num * size_per_head),
elem_sz * len * kv_head_num * size_per_head);
copied_len += len;
}
}
/* Input 'qkv' is the fused q/k/v tensor: q has shape [batch, seq_len, num_heads, head_dim],
 * k and v each have shape [batch, seq_len, kv_num_heads, head_dim].
 * Half (rotate-half) RoPE is applied to q and k in place.
 * Pre-calculated cos/sin tables are used when available; otherwise values are computed on the fly.
 */
template<typename T>
void ArmCpuDevice::halfRopeQK(void *qkv, int batch, int seq_len, int num_heads, int kv_num_heads, int head_size, size_t step, size_t base, size_t embed_dim) {
size_t inv_freq_size = (embed_dim + 1) / 2;
auto &value = ropeCosSin[base];
int calced_seq = std::get<0>(value);
float *cur_cos = std::get<1>(value);
float *cur_sin = std::get<2>(value);
const int N = batch * seq_len;
parallel_for(N, [&](int tid) {
int j = tid % seq_len;
T* q_input = (T*)qkv + tid * (num_heads + 2 * kv_num_heads) * head_size;
T* k_input = (T*)qkv + tid * (num_heads + 2 * kv_num_heads) * head_size + num_heads * head_size;
size_t seq = (j == 0)? step : j;
for (int h = 0; h < num_heads; h++) {
for (int d = 0; d < inv_freq_size; d++) {
float fcr, fci;
if (seq < calced_seq) {
fcr = cur_cos[seq * inv_freq_size + d];
fci = cur_sin[seq * inv_freq_size + d];
} else {
float freq = 1.0f / powf(base, (float)(d * 2) / embed_dim);
float val = freq * seq;
fcr = cosf(val);
fci = sinf(val);
}
auto v0 = q_input + h * head_size + d;
auto v1 = q_input + h * head_size + d + inv_freq_size;
auto d0 = *v0;
auto d1 = *v1;
*v0 = d0 * fcr - d1 * fci;
*v1 = d0 * fci + d1 * fcr;
if (h < kv_num_heads) {
auto v2 = k_input + h * head_size + d;
auto v3 = k_input + h * head_size + d + inv_freq_size;
auto d2 = *v2;
auto d3 = *v3;
*v2 = d2 * fcr - d3 * fci;
*v3 = d2 * fci + d3 * fcr;
}
}
}
});
}
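/* Dump and reset the accumulated per-stage timing statistics.
 * The slot indices (see the logTime calls below) roughly map to: 0 split/setup, 1 bias add,
 * 2 RoPE, 3 transpose, 4 Q*K^T, 5 softmax, 6 score*V, 7 output transpose, 8 KV-cache update. */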
void ArmCpuDevice::printStat() {
for (int i = 0; i < sizeof(a_cnt_) / sizeof(uint64_t); i++) {
std::cout << "$$$ [" << i << "] time (us) - min: " << a_tmin_[i] << " ; max: " << a_tmax_[i] << " ; ave: " << a_tave_[i] / a_cnt_[i] << std::endl;
}
for (int i = 0; i < sizeof(a_cnt_) / sizeof(uint64_t); i++) {
a_tmin_[i] = 999999999;
a_tmax_[i] = 0;
a_tave_[i] = 0;
a_cnt_[i] = 0;
}
}
void ArmCpuDevice::logTime(std::chrono::microseconds diff, size_t index) {
if (diff.count() < a_tmin_[index])
a_tmin_[index] = diff.count();
if (diff.count() > a_tmax_[index])
a_tmax_[index] = diff.count();
a_tave_[index] += diff.count();
a_cnt_[index] += 1;
}
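/* Non-strided attention for one batch (used for the FP16 path): add QKV bias, apply RoPE,
 * split the fused QKV with ACL NESplit, update the paged KV cache, transpose to
 * [head, token, head_dim] layout, then compute Q*K^T -> masked softmax -> *V via gemm_acl. */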
void ArmCpuDevice::runOneBatch(const AttentionModuleParams& params, size_t past_seq, int batch, size_t seq_len, size_t step) {
auto datatype = params.input.type();
auto head_num = params.configs.head_num;
auto kv_head_num = params.configs.kv_head_num;
auto size_per_head = params.configs.size_per_head;
// if (!params.common.kv_cache.has_value()) {
// throw std::runtime_error("kv cache block pointers can not be null");
// }
std::chrono::steady_clock::time_point tStart, tEnd;
std::chrono::microseconds diff;
// qkv to q_output, k_output, v_output
auto qkv = params.input.dataWithOffset(past_seq * (head_num + 2 * kv_head_num) * size_per_head);
printBufferData(params.input, "qkv");
tStart = std::chrono::steady_clock::now();
//if (params.weights.qkv_weight->bias) {
if (params.configs.fuse_qkv_add_bias && params.weights.qkv_weight->bias) {
if (datatype == DataType::TYPE_FP32) {
auto bias_data_type = params.weights.qkv_weight->bias->type();
if (bias_data_type == DataType::TYPE_FP32) {
addQKVBias<float, float>(qkv, params.weights.qkv_weight->bias->data(), 1, seq_len, head_num, kv_head_num, size_per_head);
} else if (bias_data_type == DataType::TYPE_FP16) {
addQKVBias<float, __fp16>(qkv, params.weights.qkv_weight->bias->data(), 1, seq_len, head_num, kv_head_num, size_per_head);
} else {
throw OpException(OpErrorType::ERROR_UNIMPLEMENTED);
}
} else if (datatype == DataType::TYPE_FP16) {
addQKVBias<__fp16, __fp16>(qkv, params.weights.qkv_weight->bias->data(), 1, seq_len, head_num, kv_head_num, size_per_head);
} else {
throw OpException(OpErrorType::ERROR_UNIMPLEMENTED);
}
}
tEnd = std::chrono::steady_clock::now();
diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart);
logTime(diff, 1);
printBufferData(params.input, "biased qkv");
tStart = std::chrono::steady_clock::now();
if (params.configs.rope_config.style != RopeStyle::No) {
if (params.configs.rope_config.style == RopeStyle::Base) {
if (datatype == DataType::TYPE_FP32) {
halfRopeQK<float>(qkv, 1, seq_len, head_num, kv_head_num, size_per_head, step,
params.configs.rope_config.base, params.configs.rope_config.dim);
} else if (datatype == DataType::TYPE_FP16) {
halfRopeQK<__fp16>(qkv, 1, seq_len, head_num, kv_head_num, size_per_head, step,
params.configs.rope_config.base, params.configs.rope_config.dim);
} else {
throw OpException(OpErrorType::ERROR_UNIMPLEMENTED);
}
} else {
throw std::runtime_error("SelfAttention RoPE type is not supported");
}
}
tEnd = std::chrono::steady_clock::now();
diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart);
logTime(diff, 2);
printBufferData(params.input, "roped qkv");
tStart = std::chrono::steady_clock::now();
arm_compute::DataType acl_data_type = getAclDataType(datatype);
arm_compute::NESplit split;
arm_compute::Tensor src;
arm_compute::TensorInfo src_info = arm_compute::TensorInfo(
arm_compute::TensorShape(size_per_head, head_num + 2 * kv_head_num, seq_len, 1), 1, acl_data_type);
src.allocator()->init(src_info);
src.allocator()->import_memory(qkv);
std::vector<arm_compute::Tensor> dsts{};
std::vector<arm_compute::ITensor*> dsts_ptr;
arm_compute::TensorInfo q_info, kv_info;
arm_compute::Tensor q, k, v;
q_info = arm_compute::TensorInfo(
arm_compute::TensorShape(size_per_head, head_num, seq_len, 1), 1, acl_data_type);
kv_info = arm_compute::TensorInfo(
arm_compute::TensorShape(size_per_head, kv_head_num, seq_len, 1), 1, acl_data_type);
auto q_input = allocateBuffer({datatype, {1, seq_len, head_num, size_per_head}, AllocationType::HOST}, {});
auto k_input = allocateBuffer({datatype, {1, seq_len, kv_head_num, size_per_head}, AllocationType::HOST}, {});
auto v_input = allocateBuffer({datatype, {1, seq_len, kv_head_num, size_per_head}, AllocationType::HOST}, {});
q.allocator()->init(q_info);
k.allocator()->init(kv_info);
v.allocator()->init(kv_info);
q.allocator()->import_memory(q_input->data());
dsts.push_back(std::move(q));
k.allocator()->import_memory(k_input->data());
dsts.push_back(std::move(k));
v.allocator()->import_memory(v_input->data());
dsts.push_back(std::move(v));
for (auto& dst : dsts) {
dsts_ptr.emplace_back(&dst);
}
split.configure(&src, dsts_ptr, 1);
split.run();
tEnd = std::chrono::steady_clock::now();
diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart);
logTime(diff, 0);
printBufferData(*q_input, "q_input");
printBufferData(*k_input, "k_input");
printBufferData(*v_input, "v_input");
tStart = std::chrono::steady_clock::now();
if (params.common.kv_cache.has_value()) {
updateKVCache(params, batch, step, k_input, v_input);
}
tEnd = std::chrono::steady_clock::now();
diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart);
logTime(diff, 8);
BufferPtr k_in, k_out, v_in, v_out;
if (step == 0) {
/* Context. */
k_in = k_input;
v_in = v_input;
k_out = allocateBuffer({datatype, {1, kv_head_num, seq_len, size_per_head}, AllocationType::HOST}, {});
v_out = allocateBuffer({datatype, {1, kv_head_num, seq_len, size_per_head}, AllocationType::HOST}, {});
} else {
/* Decoder. Retrieve k/v cache data. */
k_in = allocateBuffer({datatype, {1, step + 1, kv_head_num, size_per_head}, AllocationType::HOST}, {});
v_in = allocateBuffer({datatype, {1, step + 1, kv_head_num, size_per_head}, AllocationType::HOST}, {});
assemCache(params, batch, k_in, v_in);
printBufferData(*k_in, "k_in");
printBufferData(*v_in, "v_in");
k_out = allocateBuffer({datatype, {1, kv_head_num, step + 1, size_per_head}, AllocationType::HOST}, {});
v_out = allocateBuffer({datatype, {1, kv_head_num, step + 1, size_per_head}, AllocationType::HOST}, {});
}
tStart = std::chrono::steady_clock::now();
auto q_output = allocateBuffer({datatype, {1, head_num, seq_len, size_per_head}, AllocationType::HOST}, {});
transposeDim12(std::move(q_input), q_output->data());
transposeDim12(std::move(k_in), k_out->data());
transposeDim12(std::move(v_in), v_out->data());
tEnd = std::chrono::steady_clock::now();
diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart);
logTime(diff, 3);
printBufferData(*q_output, "q_output");
printBufferData(*k_out, "k_out");
printBufferData(*v_out, "v_out");
if (kv_head_num != head_num) {
/* repeat K/V */
size_t len;
if (step == 0) {
len = seq_len;
} else {
len = step + 1;
}
auto k_repeat = allocateBuffer({datatype, {1, head_num, len, size_per_head}, AllocationType::HOST}, {});
auto v_repeat = allocateBuffer({datatype, {1, head_num, len, size_per_head}, AllocationType::HOST}, {});
auto n_rep = head_num / kv_head_num;
const int N = kv_head_num;
parallel_for(N, [&](int tid) {
for (int i = 0; i < n_rep; i++) {
memcpy(k_repeat->dataWithOffset((tid * n_rep + i) * len * size_per_head), k_out->dataWithOffset(tid * len * size_per_head), k_out->sizeBytes() / kv_head_num);
memcpy(v_repeat->dataWithOffset((tid * n_rep + i) * len * size_per_head), v_out->dataWithOffset(tid * len * size_per_head), v_out->sizeBytes() / kv_head_num);
}
});
k_out = std::move(k_repeat);
v_out = std::move(v_repeat);
}
tStart = std::chrono::steady_clock::now();
auto qk_output = gemm_acl({*q_output,
*k_out,
std::nullopt,
nullptr,
DataType::TYPE_INVALID,
TransposeOperation::NONE,
TransposeOperation::TRANSPOSE});
tEnd = std::chrono::steady_clock::now();
diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart);
logTime(diff, 4);
printBufferData(*qk_output, "qk_output");
tStart = std::chrono::steady_clock::now();
float scale = (1.0f / sqrtf(size_per_head * 1.0f));
BufferPtr softmax_qk_output;
if (seq_len == 1) {
/* Decoder */
softmax_qk_output = softmax({qk_output, std::nullopt, std::nullopt, scale});
} else {
/* Context */
RUNTIME_ASSERT_OP_ARG(params.common.attention_mask,
"attention_mask must be provided for default context attention implementation");
auto attention_mask = (*params.common.attention_mask).view(batch, 1);
softmax_qk_output = softmax({qk_output, attention_mask, std::nullopt, scale});
}
tEnd = std::chrono::steady_clock::now();
diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart);
logTime(diff, 5);
printBufferData(*softmax_qk_output, "softmax_qk_output");
tStart = std::chrono::steady_clock::now();
auto qkv_output = gemm_acl({*softmax_qk_output, *v_out});
tEnd = std::chrono::steady_clock::now();
diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart);
logTime(diff, 6);
tStart = std::chrono::steady_clock::now();
transposeDim12(qkv_output, params.output.dataWithOffset(past_seq * head_num * size_per_head));
tEnd = std::chrono::steady_clock::now();
diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart);
logTime(diff, 7);
}
/* Init the per-head base-address arrays for q, k and v within the fused buffer. */
template<typename T>
void getPerHeadArray(void *qkv, void *k_seen, void *v_seen, void **q_array, void **k_array, void **v_array,
int num_heads, int kv_num_heads, int head_size) {
const int N = 1 * num_heads;
parallel_for(N, [&](int tid) {
q_array[tid] = (T*)qkv + tid * head_size;
k_array[tid] = (T*)k_seen + tid * kv_num_heads / num_heads * head_size;
v_array[tid] = (T*)v_seen + tid * kv_num_heads / num_heads * head_size;
});
}
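/* Copies k/v rows for one batch straight out of the fused QKV buffer (row stride
 * (head_num + 2 * kv_head_num) * size_per_head) into the paged KV cache, one token at a time. */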
template<typename T>
void updateKVCacheStride(const AttentionModuleParams& params, void* input, int batch, size_t seq_len, size_t step) {
auto head_num = params.configs.head_num;
auto kv_head_num = params.configs.kv_head_num;
auto size_per_head = params.configs.size_per_head;
auto block_tokens = params.configs.tokens_per_block;
size_t block_num = (seq_len + block_tokens - 1) / block_tokens;
size_t block_offset = step / block_tokens;
auto elem_sz = params.input.typeSize();
size_t copied_len = 0;
void *k_block_addr;
void *v_block_addr;
for (int i = 0; i < block_num; i++) {
size_t len = std::min(block_tokens, seq_len - copied_len);
getCacheAddrFromIndex(params.common.kv_cache.value(), batch, i + block_offset, &k_block_addr, &v_block_addr);
T* k_input = (T*)input + (i * block_tokens) * (head_num + 2 * kv_head_num) * size_per_head + head_num * size_per_head;
T* v_input = (T*)input + (i * block_tokens) * (head_num + 2 * kv_head_num) * size_per_head + (head_num + kv_head_num) * size_per_head;
parallel_for(len, [&](int tid) {
memcpy((char*)k_block_addr + (step % block_tokens + tid) * elem_sz * kv_head_num * size_per_head,
k_input + tid * (head_num + 2 * kv_head_num) * size_per_head,
elem_sz * 1 * kv_head_num * size_per_head);
memcpy((char*)v_block_addr + (step % block_tokens + tid) * elem_sz * kv_head_num * size_per_head,
v_input + tid * (head_num + 2 * kv_head_num) * size_per_head,
elem_sz * 1 * kv_head_num * size_per_head);
});
copied_len += len;
}
}
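/* Batched variant of assemCache: gathers the paged KV cache of every batch into
 * contiguous k/v buffers of shape [batch, kv_seq_len, head_num, head_dim]. */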
void assemCacheArray(const AttentionModuleParams& params, BufferPtr k_out, BufferPtr v_out, size_t tokens_per_block) {
auto elem_sz = k_out->typeSize();
auto batch_size = k_out->shape()[0];
auto head_num = k_out->shape()[2];
auto kv_seq_len = k_out->shape()[1];
auto head_dim = k_out->shape()[3];
size_t blocks_per_batch = (kv_seq_len + tokens_per_block - 1) / tokens_per_block;
size_t copied_len;
void *k_block_addr;
void *v_block_addr;
for (int batch = 0; batch < batch_size; batch++) {
copied_len = 0;
for (int i = 0; i < blocks_per_batch; i++) {
size_t len = std::min(tokens_per_block, kv_seq_len - copied_len);
getCacheAddrFromIndex(params.common.kv_cache.value(), batch, i, &k_block_addr, &v_block_addr);
memcpy(k_out->dataWithOffset((batch * kv_seq_len + i * tokens_per_block) * head_num * head_dim),
k_block_addr,
elem_sz * len * head_num * head_dim);
memcpy(v_out->dataWithOffset((batch * kv_seq_len + i * tokens_per_block) * head_num * head_dim),
v_block_addr,
elem_sz * len * head_num * head_dim);
copied_len += len;
}
}
}
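/* Strided FP32 attention for one batch: q/k/v stay in the fused buffer and are addressed
 * per head via strides. Q*K^T and score*V run per head with OpenBLAS sgemm; when KleidiAI is
 * enabled, the context-phase score*V uses the bf16 packed matmul micro-kernels instead. */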
void ArmCpuDevice::runOneBatchStride(const AttentionModuleParams& params, size_t past_seq, int batch, size_t seq_len, size_t step) {
auto datatype = params.input.type();
auto head_num = params.configs.head_num;
auto kv_head_num = params.configs.kv_head_num;
auto size_per_head = params.configs.size_per_head;
std::chrono::steady_clock::time_point tStart, tEnd;
std::chrono::microseconds diff;
if (datatype != DataType::TYPE_FP32) {
throw OpException(OpErrorType::ERROR_UNIMPLEMENTED);
}
// if (!params.common.kv_cache.has_value()) {
// throw std::runtime_error("kv cache block pointers can not be null");
// }
// Access q/k/v in place through strides instead of splitting the fused buffer.
auto qkv = params.input.dataWithOffset(past_seq * (head_num + 2 * kv_head_num) * size_per_head);
printBufferData(params.input, "qkv");
tStart = std::chrono::steady_clock::now();
void *q_array[head_num], *k_array[head_num], *v_array[head_num];
tEnd = std::chrono::steady_clock::now();
diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart);
logTime(diff, 0);
tStart = std::chrono::steady_clock::now();
//if (params.weights.qkv_weight->bias) {
if (params.configs.fuse_qkv_add_bias && params.weights.qkv_weight->bias) {
auto bias_data_type = params.weights.qkv_weight->bias->type();
if (bias_data_type == DataType::TYPE_FP32) {
addQKVBias<float, float>(qkv, params.weights.qkv_weight->bias->data(), 1, seq_len, head_num, kv_head_num, size_per_head);
} else if (bias_data_type == DataType::TYPE_FP16) {
addQKVBias<float, __fp16>(qkv, params.weights.qkv_weight->bias->data(), 1, seq_len, head_num, kv_head_num, size_per_head);
} else {
throw OpException(OpErrorType::ERROR_UNIMPLEMENTED);
}
}
tEnd = std::chrono::steady_clock::now();
diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart);
logTime(diff, 1);
printBufferData(params.input, "biased qkv");
tStart = std::chrono::steady_clock::now();
if (params.configs.rope_config.style != RopeStyle::No) {
if (params.configs.rope_config.style == RopeStyle::Base) {
halfRopeQK<float>(qkv, 1, seq_len, head_num, kv_head_num, size_per_head, step,
params.configs.rope_config.base,
params.configs.rope_config.dim);
} else {
throw std::runtime_error("SelfAttention RoPE type is not supported");
}
}
tEnd = std::chrono::steady_clock::now();
diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart);
logTime(diff, 2);
printBufferData(params.input, "roped qkv");
tStart = std::chrono::steady_clock::now();
if (params.common.kv_cache.has_value()) {
updateKVCacheStride<float>(params, qkv, batch, seq_len, step);
}
tEnd = std::chrono::steady_clock::now();
diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart);
logTime(diff, 8);
BufferPtr k_buffer, v_buffer;
void *k_in, *v_in;
int stride_kv;
if (step == 0) {
/* Context */
k_in = (float *)qkv + head_num * size_per_head;
v_in = (float *)qkv + (head_num + kv_head_num) * size_per_head;
stride_kv = (head_num + 2 * kv_head_num) * size_per_head;
step = seq_len - 1; // Trick to unify context and decoder processes.
} else {
/* Decoder */
k_buffer = allocateBuffer({datatype, {1, step + 1, kv_head_num, size_per_head}, AllocationType::HOST}, {});
v_buffer = allocateBuffer({datatype, {1, step + 1, kv_head_num, size_per_head}, AllocationType::HOST}, {});
assemCache(params, batch, k_buffer, v_buffer);
printBufferData(*k_buffer, "k_buffer");
printBufferData(*v_buffer, "v_buffer");
k_in = k_buffer->data();
v_in = v_buffer->data();
stride_kv = kv_head_num * size_per_head;
}
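// Empty timed region: no transpose is needed on this path; slot 3 is logged only to stay
// comparable with runOneBatch.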
tStart = std::chrono::steady_clock::now();
tEnd = std::chrono::steady_clock::now();
diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart);
logTime(diff, 3);
tStart = std::chrono::steady_clock::now();
/* Re-init the per-head base-address arrays as per [batch, num_heads]. */
getPerHeadArray<float>(qkv, k_in, v_in, q_array, k_array, v_array,
head_num, kv_head_num, size_per_head);
auto qk_output = allocateBuffer({datatype, {1, head_num, seq_len, step + 1}, AllocationType::HOST}, {"qk_output"});
const int stride_q = (head_num + 2 * kv_head_num) * size_per_head;
//const int N = 1 * head_num;
//parallel_for(N, [&](int tid) {
const int MHA_HEADS = 1 * head_num;
parallel_for(MHA_HEADS, [&](int tid) {
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, seq_len, step + 1, size_per_head, 1.0,
(const float*)q_array[tid], stride_q,
(const float*)k_array[tid], stride_kv, 0.0,
(float*)qk_output->dataWithOffset(tid * seq_len * (step + 1)), step + 1);
});
printBufferData(*qk_output, "qk_output");
tEnd = std::chrono::steady_clock::now();
diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart);
logTime(diff, 4);
tStart = std::chrono::steady_clock::now();
float scale = (1.0f / sqrtf(size_per_head * 1.0f));
BufferPtr softmax_qk_output;
if (seq_len == 1) {
/* Decoder */
softmax_qk_output = softmax({qk_output, std::nullopt, std::nullopt, scale});
} else {
/* Context */
RUNTIME_ASSERT_OP_ARG(params.common.attention_mask,
"attention_mask must be provided for default context attention implementation");
auto attention_mask = (*params.common.attention_mask).view(batch, 1);
printBufferData(attention_mask, "attention_mask");
softmax_qk_output = softmax({qk_output, attention_mask, std::nullopt, scale});
}
tEnd = std::chrono::steady_clock::now();
diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart);
logTime(diff, 5);
printBufferData(*softmax_qk_output, "softmax_qk_output");
tStart = std::chrono::steady_clock::now();
//const int NN = 1 * head_num;
//parallel_for(NN, [&](int tid) {
// cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, seq_len, size_per_head, step + 1, 1.0,
// (const float*)softmax_qk_output->dataWithOffset(tid * seq_len * (step + 1)), step + 1,
// (const float*)v_array[tid], stride_kv, 0.0,
// (float*)params.output.dataWithOffset(past_seq * head_num * size_per_head + tid * size_per_head), head_num * size_per_head);
//});
if (!isKAIenabled) {
parallel_for(MHA_HEADS, [&](int tid) {
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, seq_len, size_per_head, step + 1, 1.0,
(const float*)softmax_qk_output->dataWithOffset(tid * seq_len * (step + 1)), step + 1,
(const float*)v_array[tid], stride_kv, 0.0,
(float*)params.output.dataWithOffset(past_seq * head_num * size_per_head + tid * size_per_head), head_num * size_per_head);
});
} else {
if (seq_len == 1) {
/* Decoder has higher performance with cblas for gemm. */
parallel_for(MHA_HEADS, [&](int tid) {
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, seq_len, size_per_head, step + 1, 1.0,
(const float*)softmax_qk_output->dataWithOffset(tid * seq_len * (step + 1)), step + 1,
(const float*)v_array[tid], stride_kv, 0.0,
(float*)params.output.dataWithOffset(past_seq * head_num * size_per_head + tid * size_per_head), head_num * size_per_head);
});
} else {
/* Context has higher performance with KleidiAI for gemm. */
const size_t bias_size = size_per_head;
float* bias = new float[bias_size];
memset(bias, 0, bias_size * sizeof(float));
const size_t M = seq_len;
const size_t N = size_per_head;
const size_t K = step + 1;
const size_t mr = kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla();
const size_t nr = kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla();
const size_t kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla();
const size_t sr = kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla();
const size_t lhs_packed_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p8x4_f32_neon(M, K, mr, kr, sr);
const size_t rhs_packed_size = kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon(N, K, nr, kr);
const size_t lhs_stride = K * sizeof(float);
const size_t rhs_stride = stride_kv * sizeof(float);
const size_t dst_stride_row = (head_num * size_per_head) * sizeof(float);
const size_t dst_stride_col = sizeof(float);
parallel_for(MHA_HEADS, [&](int tid) {
uint8_t *lhs_packed = new uint8_t[lhs_packed_size];
uint8_t *rhs_packed = new uint8_t[rhs_packed_size];
kai_run_lhs_quant_pack_bf16p8x4_f32_neon(M, K, mr, kr, sr, 0,
(const void*)softmax_qk_output->dataWithOffset(tid * M * K), lhs_stride, lhs_packed);
kai_run_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon(1,
N,
K,
nr,
kr,
sr,
rhs_stride,
(const void*)v_array[tid],
bias,
NULL,
rhs_packed,
0,
NULL);
kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla(M, N, K,
lhs_packed,
rhs_packed,
(void*)params.output.dataWithOffset(past_seq * head_num * size_per_head + tid * size_per_head), dst_stride_row, dst_stride_col,
-FLT_MAX, FLT_MAX);
delete[] rhs_packed;
delete[] lhs_packed;
});
delete[] bias;
}
}
tEnd = std::chrono::steady_clock::now();
diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart);
logTime(diff, 6);
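// Slot 7 (output transpose) is a no-op on this path since the GEMMs write directly into params.output.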
tStart = std::chrono::steady_clock::now();
tEnd = std::chrono::steady_clock::now();
diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart);
logTime(diff, 7);
/* Print profile data at the end of operator unit test. */
// if (a_cnt_[0] == 24)
// printStat();
}
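/* Context (prefill) attention: run each context batch sequentially, choosing the strided
 * FP32 path or the ACL-based path (FP16) according to the input data type. */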
AttentionModuleOutput ArmCpuDevice::contextAttention(const AttentionModuleParams& params) {
auto batch_size = params.common.context_batch_size;
auto decoder_batch = params.common.decoder_batch_size;
size_t past_seq = 0;
if (params.input.type() == DataType::TYPE_FP32) {
for (int batch = 0; batch < batch_size; batch++) {
size_t context_len = *static_cast<int*>(params.common.input_lengths->dataWithOffset(decoder_batch + batch));
runOneBatchStride(params, past_seq, batch, context_len, 0);
past_seq += context_len;
}
} else if (params.input.type() == DataType::TYPE_FP16) {
RTP_LLM_LOG_WARNING("Attention performance could be suboptimal with FP16 input. Try FP32 input.");
for (int batch = 0; batch < batch_size; batch++) {
size_t context_len = *static_cast<int*>(params.common.input_lengths->dataWithOffset(decoder_batch + batch));
runOneBatch(params, past_seq, batch, context_len, 0);
past_seq += context_len;
}
} else {
throw OpException(OpErrorType::ERROR_UNIMPLEMENTED);
}
}
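/* Decoder (incremental) self-attention: one new token per batch; 'step' is the number of
 * tokens already cached for that batch. */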
AttentionModuleOutput ArmCpuDevice::decoderSelfAttention(const AttentionModuleParams& params) {
auto batch_size = params.common.decoder_batch_size;
if (params.input.type() == DataType::TYPE_FP32) {
for (int batch = 0; batch < batch_size; batch++) {
size_t step = *static_cast<int*>(params.common.sequence_lengths->dataWithOffset(batch));
runOneBatchStride(params, batch, batch, 1, step);
}
} else if (params.input.type() == DataType::TYPE_FP16) {
for (int batch = 0; batch < batch_size; batch++) {
size_t step = *static_cast<int*>(params.common.sequence_lengths->dataWithOffset(batch));
runOneBatch(params, batch, batch, 1, step);
}
} else {
throw OpException(OpErrorType::ERROR_UNIMPLEMENTED);
}
}
} // namespace rtp_llm