in source/backend/opencl/execution/buffer/AttentionBufExecution.cpp [395:866]
/**
 * Resize path for long-sequence prefill attention (OpenCL buffer backend).
 *
 * Allocates the padded temporaries and builds/records this kernel sequence:
 *   rearrange_qkv -> rearrange_mask (only if mHasMask)
 *   -> for each of mQseqSplitNum Q slices:
 *        XgemmBatched (Q*K^T, scaled) -> softmax_inside -> trans_3d_buf -> XgemmBatched (P*V)
 *   -> qkv_transpose_output
 * When the score matrices for all heads would exceed ~32M elements, the padded Q sequence is split
 * into slices so that mTempQK / mTempSoftMax only need to hold one slice at a time.
 *
 * @param inputs  inputs[0] query, inputs[1] key, inputs[2] value; inputs[3] (mask) is read only when mHasMask.
 *                The query shape is read as [batch, seqLen, numHead, headDim]; key->shape()[2] is the KV head count.
 * @param outputs outputs[0] is written by the final qkv_transpose_output kernel.
 * @return NO_ERROR; setArg status codes are routed through MNN_CHECK_CL_SUCCESS.
 */
ErrorCode AttentionBufExecution::longPrefillResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs){
    auto query = inputs[0];
    auto key = inputs[1];
    auto value = inputs[2];
    auto runtime = mOpenCLBackend->getOpenCLRuntime();
    auto shape = query->shape();
    int batch = shape[0];
    int seqlen = shape[1];
    int numHead = shape[2];
    int kvNumHead = key->shape()[2];
    int headDim = shape[3];
    int group_size = numHead / kvNumHead;
    float scale = 1.0 / sqrt(headDim);
    mAlignQ = 32;
    mAlignKV = 32;
    mAlignHDK = 4;
    mAlignHDN = 32;

    // Padded extents shared by every temporary and kernel below. The KV extent follows mKvSeqlen,
    // because that is what the rearrange / mask / gemm / softmax / transpose kernels are dispatched
    // with; sizing the buffers from the same value keeps them at least as large as what is touched.
    const int seqLenPackQ = ROUND_UP(seqlen, mAlignQ);
    const int seqLenPackKV = ROUND_UP(mKvSeqlen, mAlignKV);
    const int headDimPackQK = ROUND_UP(headDim, mAlignHDK);
    const int headDimPackV = ROUND_UP(headDim, mAlignHDN);

    // Score-matrix size for all heads, in units of 1M elements.
    float useMemorySize = 1.0 * seqLenPackQ / 1024.0 * seqLenPackKV / 1024.0 * batch * numHead;
    // Start from "no split" every time: the member may still hold a value chosen by an earlier, larger resize.
    mQseqSplitNum = 1;
    // elementSize larger than 32M
    if(useMemorySize > 32.0) {
        mQseqSplitNum = useMemorySize >= 256.0 ? 8 : ((useMemorySize < 128.0) ? 2 : 4);
    }
    // trans_3d_buf below is dispatched with transDimW / 8 work-items (integer division, presumably one
    // 8-wide tile each), so each Q slice must be a multiple of 8. seqLenPackQ is a multiple of 32, so
    // only the 8-way split can break this; fall back to fewer slices in that case.
    while(mQseqSplitNum > 1 && (seqLenPackQ % mQseqSplitNum != 0 || (seqLenPackQ / mQseqSplitNum) % 8 != 0)) {
        mQseqSplitNum /= 2;
    }
    // Rows of Q handled by one slice.
    const int seqLenPieceQ = seqLenPackQ / mQseqSplitNum;

    mKernel_rearrange_vec.resize(1); mGwsRearrgVec.resize(1); mLwsRearrgVec.resize(1);
    mKernel_mask_vec.resize(1); mGwsMaskVec.resize(1); mLwsMaskVec.resize(1);
    mKernel_qk_vec.resize(mQseqSplitNum); mGwsQkVec.resize(mQseqSplitNum); mLwsQkVec.resize(mQseqSplitNum);
    mKernel_softmax_vec.resize(mQseqSplitNum); mGwsSoftMaxVec.resize(mQseqSplitNum); mLwsSoftMaxVec.resize(mQseqSplitNum);
    mKernel_trans_vec.resize(mQseqSplitNum); mGwsTransVec.resize(mQseqSplitNum); mLwsTransVec.resize(mQseqSplitNum);
    mKernel_qkv_vec.resize(mQseqSplitNum); mGwsQkvVec.resize(mQseqSplitNum); mLwsQkvVec.resize(mQseqSplitNum);
    mKernel_clip_vec.resize(1); mGwsClipVec.resize(1); mLwsClipVec.resize(1);

    // The full [seqQ, seqKV] score set can exceed INT_MAX elements before the split divides it down,
    // so form the product in size_t and only narrow the per-slice result.
    const int qkSliceElements = static_cast<int>(static_cast<size_t>(seqLenPackQ) * static_cast<size_t>(seqLenPackKV)
                                                 * static_cast<size_t>(batch) * static_cast<size_t>(numHead)
                                                 / static_cast<size_t>(mQseqSplitNum));
    mTempQ.reset(Tensor::createDevice<float>({seqLenPackQ * headDimPackQK * batch * numHead}));
    mTempK.reset(Tensor::createDevice<float>({seqLenPackKV * headDimPackQK * batch * numHead}));
    mTempV.reset(Tensor::createDevice<float>({seqLenPackKV * headDimPackV * batch * numHead}));
    if(mHasMask) {
        if(mIsAddMask) {
            mTempMask.reset(Tensor::createDevice<float>({seqLenPackQ * seqLenPackKV * batch}));
        } else {
            mTempMask.reset(Tensor::createDevice<uint32_t>({seqLenPackQ * seqLenPackKV * batch}));
        }
    }
    mTempQK.reset(Tensor::createDevice<float>({qkSliceElements}));
    mTempSoftMax.reset(Tensor::createDevice<float>({qkSliceElements}));
    mTempQKV.reset(Tensor::createDevice<float>({seqLenPackQ * headDimPackV * batch * numHead}));
    mOpenCLBackend->onAcquireBuffer(mTempQ.get(), Backend::DYNAMIC);
    mOpenCLBackend->onAcquireBuffer(mTempK.get(), Backend::DYNAMIC);
    mOpenCLBackend->onAcquireBuffer(mTempV.get(), Backend::DYNAMIC);
    if(mHasMask) {
        mOpenCLBackend->onAcquireBuffer(mTempMask.get(), Backend::DYNAMIC);
    }
    mOpenCLBackend->onAcquireBuffer(mTempQK.get(), Backend::DYNAMIC);
    mOpenCLBackend->onAcquireBuffer(mTempSoftMax.get(), Backend::DYNAMIC);
    mOpenCLBackend->onAcquireBuffer(mTempQKV.get(), Backend::DYNAMIC);
    mOpenCLBackend->onReleaseBuffer(mTempQ.get(), Backend::DYNAMIC);
    mOpenCLBackend->onReleaseBuffer(mTempK.get(), Backend::DYNAMIC);
    if(mHasMask) {
        mOpenCLBackend->onReleaseBuffer(mTempMask.get(), Backend::DYNAMIC);
    }
    mOpenCLBackend->onReleaseBuffer(mTempSoftMax.get(), Backend::DYNAMIC);
    mOpenCLBackend->onReleaseBuffer(mTempV.get(), Backend::DYNAMIC);
    mOpenCLBackend->onReleaseBuffer(mTempQK.get(), Backend::DYNAMIC);
    mOpenCLBackend->onReleaseBuffer(mTempQKV.get(), Backend::DYNAMIC);
    // query: [batch, seqLenQ, headNum, headDim] -> mTempQ: [batch*headNum, ROUND_UP(headDim, mAlignHDK), ROUND_UP(seqLenQ, mAlignQ)]
    // key: [batch, seqLenKV/4, headNum/group, headDim, seqLenKV_4] -> mTempK: [batch*headNum/group, ROUND_UP(headDim, mAlignHDK), ROUND_UP(seqLenKV, mAlignKV)]
    // value: [batch, seqLenKV/4, headNum/group, headDim, seqLenKV_4] -> mTempV: [batch*headNum/group, ROUND_UP(seqLenKV, mAlignKV), ROUND_UP(headDim, mAlignHDN)]
    // key & value -> pastKey & pastValue (copy, only when mNeedKvCache)
    int seq_idx = 0;
    // rearrange qkv
    {
        std::set<std::string> buildOption;
        if((headDim % 4) != 0){
            buildOption.emplace("-DHEADDIM_LEAVE");
        }
        // generate cache for every option
        {
            auto option = buildOption;
            auto kernel = runtime->buildKernel("attention_buf", "rearrange_qkv", option, mOpenCLBackend->getPrecision(), inputs[0], outputs[0]);
        }
        {
            auto option = buildOption;
            option.emplace("-DSEQLEN_LEAVE");
            auto kernel = runtime->buildKernel("attention_buf", "rearrange_qkv", option, mOpenCLBackend->getPrecision(), inputs[0], outputs[0]);
        }
        if((seqlen % 4) != 0){
            buildOption.emplace("-DSEQLEN_LEAVE");
        }
        if(mNeedKvCache) {
            buildOption.emplace("-DSAVE_KV");
        }
        int tile[4] = {mAlignQ, mAlignKV, mAlignHDK, mAlignHDN};
        int rearrangeShape[4] = {seqlen, mKvSeqlen, numHead, headDim};
        int param[4] = {group_size, batch, 0, 0};
        mKernel_rearrange_vec[seq_idx] = runtime->buildKernel("attention_buf", "rearrange_qkv", buildOption, mOpenCLBackend->getPrecision(), inputs[0], outputs[0]);
        auto maxWorkGroupSize = static_cast<uint32_t>(runtime->getMaxWorkGroupSize(mKernel_rearrange_vec[seq_idx]));
        mGwsRearrgVec[seq_idx] = {static_cast<uint32_t>(ALIMAX(UP_DIV(seqLenPackQ, 4), UP_DIV(seqLenPackKV, 4))),
                                  static_cast<uint32_t>(ALIMAX(UP_DIV(headDimPackQK, 4), UP_DIV(headDimPackV, 4))),
                                  static_cast<uint32_t>(batch*numHead)};
        uint32_t index = 0;
        cl_int ret = CL_SUCCESS;
        ret |= mKernel_rearrange_vec[seq_idx]->get().setArg(index++, mGwsRearrgVec[seq_idx][0]);   // 0
        ret |= mKernel_rearrange_vec[seq_idx]->get().setArg(index++, mGwsRearrgVec[seq_idx][1]);   // 1
        ret |= mKernel_rearrange_vec[seq_idx]->get().setArg(index++, mGwsRearrgVec[seq_idx][2]);   // 2
        ret |= mKernel_rearrange_vec[seq_idx]->get().setArg(index++, openCLBuffer(query));         // 3
        ret |= mKernel_rearrange_vec[seq_idx]->get().setArg(index++, openCLBuffer(key));           // 4
        ret |= mKernel_rearrange_vec[seq_idx]->get().setArg(index++, openCLBuffer(value));         // 5
        ret |= mKernel_rearrange_vec[seq_idx]->get().setArg(index++, openCLBuffer(mTempQ.get()));  // 6
        ret |= mKernel_rearrange_vec[seq_idx]->get().setArg(index++, openCLBuffer(mTempK.get()));  // 7
        ret |= mKernel_rearrange_vec[seq_idx]->get().setArg(index++, openCLBuffer(mTempV.get()));  // 8
        if(mNeedKvCache) {
            ret |= mKernel_rearrange_vec[seq_idx]->get().setArg(index++, *mKVCacheCLManager->key());   // 9
            ret |= mKernel_rearrange_vec[seq_idx]->get().setArg(index++, *mKVCacheCLManager->value()); // 10
        }
        ret |= mKernel_rearrange_vec[seq_idx]->get().setArg(index++, tile);             // 9  (11 with SAVE_KV)
        ret |= mKernel_rearrange_vec[seq_idx]->get().setArg(index++, rearrangeShape);   // 10 (12 with SAVE_KV)
        ret |= mKernel_rearrange_vec[seq_idx]->get().setArg(index++, param);            // 11 (13 with SAVE_KV)
        ret |= mKernel_rearrange_vec[seq_idx]->get().setArg(index++, mKeyValueMaxlen);  // 12 (14 with SAVE_KV)
        MNN_CHECK_CL_SUCCESS(ret, "setArg rearrange_qkv");
        mLwsRearrgVec[seq_idx] = localWS3DDefault(mGwsRearrgVec[seq_idx], maxWorkGroupSize, runtime, "rearrange_qkv", mKernel_rearrange_vec[seq_idx], mOpenCLBackend->getCLTuneLevel()).first;
        mGwsRearrgVec[seq_idx][0] = ROUND_UP(mGwsRearrgVec[seq_idx][0], std::max((uint32_t)1, mLwsRearrgVec[seq_idx][0]));
        mGwsRearrgVec[seq_idx][1] = ROUND_UP(mGwsRearrgVec[seq_idx][1], std::max((uint32_t)1, mLwsRearrgVec[seq_idx][1]));
        mGwsRearrgVec[seq_idx][2] = ROUND_UP(mGwsRearrgVec[seq_idx][2], std::max((uint32_t)1, mLwsRearrgVec[seq_idx][2]));
        // Record-update slots must match the setArg order above: the SAVE_KV build inserts the two cache
        // buffers at slots 9/10, which moves maxLenKV from slot 12 to slot 14.
        if(mNeedKvCache) {
            mRgUpdateInfo.update_kernel_args.push_back({0, 9, sizeof(cl_mem), &(*(mKVCacheCLManager->key()))()});
            mRgUpdateInfo.update_kernel_args.push_back({0, 10, sizeof(cl_mem), &(*(mKVCacheCLManager->value()))()});
            mRgUpdateInfo.update_kernel_args.push_back({0, 14, sizeof(mKeyValueMaxlen), &mKeyValueMaxlen});
        } else {
            mRgUpdateInfo.update_kernel_args.push_back({0, 12, sizeof(mKeyValueMaxlen), &mKeyValueMaxlen});
        }
        mOpRecordUpdateInfo.emplace_back(&mRgUpdateInfo);
        mOpenCLBackend->recordKernel3d(mKernel_rearrange_vec[seq_idx], mGwsRearrgVec[seq_idx], mLwsRearrgVec[seq_idx], &mRgUpdateInfo);
    }
    // mask rearrange: inputs[3] -> mTempMask (padded to [batch, seqLenPackQ, seqLenPackKV] elements)
    if(mHasMask)
    {
        std::set<std::string> buildOption;
        int maskShape[4] = {seqlen, mKvSeqlen, mAlignQ, mAlignKV};
        mKernel_mask_vec[seq_idx] = runtime->buildKernel("attention_buf", "rearrange_mask", buildOption, mOpenCLBackend->getPrecision(), inputs[0], outputs[0]);
        auto maxWorkGroupSize = static_cast<uint32_t>(runtime->getMaxWorkGroupSize(mKernel_mask_vec[seq_idx]));
        mGwsMaskVec[seq_idx] = {static_cast<uint32_t>(UP_DIV(seqLenPackQ, 4)),
                                static_cast<uint32_t>(UP_DIV(seqLenPackKV, 4)),
                                static_cast<uint32_t>(batch)};
        uint32_t index = 0;
        cl_int ret = CL_SUCCESS;
        ret |= mKernel_mask_vec[seq_idx]->get().setArg(index++, mGwsMaskVec[seq_idx][0]);
        ret |= mKernel_mask_vec[seq_idx]->get().setArg(index++, mGwsMaskVec[seq_idx][1]);
        ret |= mKernel_mask_vec[seq_idx]->get().setArg(index++, mGwsMaskVec[seq_idx][2]);
        ret |= mKernel_mask_vec[seq_idx]->get().setArg(index++, openCLBuffer(inputs[3]));
        ret |= mKernel_mask_vec[seq_idx]->get().setArg(index++, openCLBuffer(mTempMask.get()));
        ret |= mKernel_mask_vec[seq_idx]->get().setArg(index++, maskShape);
        MNN_CHECK_CL_SUCCESS(ret, "setArg rearrange_mask");
        mLwsMaskVec[seq_idx] = localWS3DDefault(mGwsMaskVec[seq_idx], maxWorkGroupSize, runtime, "rearrange_mask", mKernel_mask_vec[seq_idx], mOpenCLBackend->getCLTuneLevel()).first;
        mGwsMaskVec[seq_idx][0] = ROUND_UP(mGwsMaskVec[seq_idx][0], std::max((uint32_t)1, mLwsMaskVec[seq_idx][0]));
        mGwsMaskVec[seq_idx][1] = ROUND_UP(mGwsMaskVec[seq_idx][1], std::max((uint32_t)1, mLwsMaskVec[seq_idx][1]));
        mGwsMaskVec[seq_idx][2] = ROUND_UP(mGwsMaskVec[seq_idx][2], std::max((uint32_t)1, mLwsMaskVec[seq_idx][2]));
        mOpenCLBackend->recordKernel3d(mKernel_mask_vec[seq_idx], mGwsMaskVec[seq_idx], mLwsMaskVec[seq_idx]);
    }
    for(int seq_idx = 0; seq_idx < mQseqSplitNum; seq_idx++) {
        // qk matmul
        {
            // Q : [batch*headNum, ROUND_UP(headDim, mAlignHDK), ROUND_UP(seqLenQ, mAlignQ)]; this slice starts at column seq_idx * M -> [B, K, M]
            // K : [batch*headNum/group, ROUND_UP(headDim, mAlignHDK), ROUND_UP(seqLenKV, mAlignKV)] -> [B, K, N]
            // QK: [batch*headNum, ROUND_UP(seqLenQ, mAlignQ) / mQseqSplitNum, ROUND_UP(seqLenKV, mAlignKV)] -> [B, M, N]
            int loop = batch * numHead;
            int e_pack = seqLenPackQ;
            int e_pack_piece = seqLenPieceQ;
            int h_pack = seqLenPackKV;
            int l_pack = headDimPackQK;
            std::set<std::string> buildOptions;
            int biasType = 0;
            std::vector<cl::Buffer> bufferVec = {openCLBuffer(mTempQ.get()), openCLBuffer(mTempK.get()), openCLBuffer(mTempQK.get())};
            if(mHasMask) {
                bufferVec.emplace_back(openCLBuffer(mTempMask.get()));
            }
            if(mIsAddMask) {
                biasType = 2;
            } else if(mHasMask) {
                biasType = 5;// int value mask
            }
            uint32_t layout = 14; // 10 means mix-precision, 4 means layout
            auto param = getGemmParams({(uint32_t)e_pack_piece, (uint32_t)h_pack, (uint32_t)l_pack, layout, (uint32_t)loop, (uint32_t)(biasType + 10*(group_size-1))}, bufferVec, mOpenCLBackend->getOpenCLRuntime(), mOpenCLBackend->getPrecision(), mOpenCLBackend->getCLTuneLevel());
            int KWG=param[0], KWI=param[1], MDIMA=param[2], MDIMC=param[3], MWG=param[4], NDIMB=param[5], NDIMC=param[6], NWG=param[7], SA=param[8], SB=param[9], STRM=param[10], STRN=param[11], VWM=param[12], VWN=param[13];
            buildOptions.emplace("-DKWG=" + std::to_string(KWG));
            buildOptions.emplace("-DKWI=" + std::to_string(KWI));
            buildOptions.emplace("-DMDIMA=" + std::to_string(MDIMA));
            buildOptions.emplace("-DMDIMC=" + std::to_string(MDIMC));
            buildOptions.emplace("-DMWG=" + std::to_string(MWG));
            buildOptions.emplace("-DNDIMB=" + std::to_string(NDIMB));
            buildOptions.emplace("-DNDIMC=" + std::to_string(NDIMC));
            buildOptions.emplace("-DNWG=" + std::to_string(NWG));
            buildOptions.emplace("-DSA=" + std::to_string(SA));
            buildOptions.emplace("-DSB=" + std::to_string(SB));
            buildOptions.emplace("-DSTRM=" + std::to_string(STRM));
            buildOptions.emplace("-DSTRN=" + std::to_string(STRN));
            buildOptions.emplace("-DVWM=" + std::to_string(VWM));
            buildOptions.emplace("-DVWN=" + std::to_string(VWN));
            if(layout >= 4) {
                buildOptions.emplace("-DOUTPUTMN");
            }
            int tileM = MWG;
            int tileN = NWG;
            int localM = MDIMC;
            int localN = NDIMC;
            if(mOpenCLBackend->getOpenCLRuntime()->getGpuType() == GpuType::ADRENO) {
                buildOptions.emplace("-DUSE_CL_MAD=1");
                buildOptions.emplace("-DRELAX_WORKGROUP_SIZE=1");
            }
            buildOptions.emplace("-DONLY_HAVE_ALPHA");
            if(biasType >= 1) {
                buildOptions.emplace("-DBIAS_TYPE=" + std::to_string(biasType));
            }
            buildOptions.emplace("-DPRECISION_COMPUTE=float -DCONVERT_PRECISION_COMPUTE=convert_float");
            buildOptions.emplace("-DPRECISION_COMPUTE2=float2 -DCONVERT_PRECISION_COMPUTE2=convert_float2");
            buildOptions.emplace("-DPRECISION_COMPUTE4=float4 -DCONVERT_PRECISION_COMPUTE4=convert_float4");
            buildOptions.emplace("-DPRECISION_COMPUTE8=float8 -DCONVERT_PRECISION_COMPUTE8=convert_float8");
            buildOptions.emplace("-DPRECISION_COMPUTE16=float16 -DCONVERT_PRECISION_COMPUTE16=convert_float16");
            mKernel_qk_vec[seq_idx] = mOpenCLBackend->getOpenCLRuntime()->buildKernel("matmul_params_buf", "XgemmBatched", buildOptions, mOpenCLBackend->getPrecision());
            int out_per_thread_m = tileM / localM;
            int out_per_thread_n = tileN / localN;
            mGwsQkVec[seq_idx] = {static_cast<uint32_t>(e_pack_piece/out_per_thread_m), static_cast<uint32_t>(h_pack/out_per_thread_n), static_cast<uint32_t>(loop)};
            mLwsQkVec[seq_idx] = {static_cast<uint32_t>(localM), static_cast<uint32_t>(localN), 1};
            float alpha = scale;
            float beta = 0.0f;
            int batch_offset_a = e_pack * l_pack;
            int batch_offset_b = h_pack * l_pack;
            int batch_offset_c = e_pack_piece * h_pack;
            int batch_offset[4] = {batch_offset_a, batch_offset_b, batch_offset_c, 0};
            // Q is read from column e_pack_piece * seq_idx; the mask slice starts at row-block seq_idx.
            int base_ptr_offset[4] = {e_pack_piece * seq_idx, 0, 0, batch_offset_c * seq_idx};
            int stride[4] = {e_pack, h_pack, h_pack, h_pack};
            int group[4] = {1, group_size, 1, loop};
            int idx = 0;
            cl_int ret = CL_SUCCESS;
            ret |= mKernel_qk_vec[seq_idx]->get().setArg(idx++, static_cast<int>(e_pack_piece));
            ret |= mKernel_qk_vec[seq_idx]->get().setArg(idx++, static_cast<int>(h_pack));
            ret |= mKernel_qk_vec[seq_idx]->get().setArg(idx++, static_cast<int>(l_pack));
            ret |= mKernel_qk_vec[seq_idx]->get().setArg(idx++, alpha);
            ret |= mKernel_qk_vec[seq_idx]->get().setArg(idx++, beta);
            ret |= mKernel_qk_vec[seq_idx]->get().setArg(idx++, openCLBuffer(mTempQ.get()));
            ret |= mKernel_qk_vec[seq_idx]->get().setArg(idx++, openCLBuffer(mTempK.get()));
            if(mHasMask) {
                ret |= mKernel_qk_vec[seq_idx]->get().setArg(idx++, openCLBuffer(mTempMask.get()));
            }
            ret |= mKernel_qk_vec[seq_idx]->get().setArg(idx++, openCLBuffer(mTempQK.get()));
            ret |= mKernel_qk_vec[seq_idx]->get().setArg(idx++, batch_offset);
            ret |= mKernel_qk_vec[seq_idx]->get().setArg(idx++, base_ptr_offset);
            ret |= mKernel_qk_vec[seq_idx]->get().setArg(idx++, stride);
            ret |= mKernel_qk_vec[seq_idx]->get().setArg(idx++, group);
            MNN_CHECK_CL_SUCCESS(ret, "setArg Self-Attention batchmatmul qk Kernel");
            mOpenCLBackend->recordKernel3d(mKernel_qk_vec[seq_idx], mGwsQkVec[seq_idx], mLwsQkVec[seq_idx]);
        }
        // softmax
        {
            // QK:      [batch*numHead, ROUND_UP(seqLenQ, mAlignQ) / mQseqSplitNum, ROUND_UP(seqLenKV, mAlignKV)]
            // Softmax: [batch*numHead, ROUND_UP(seqLenQ, mAlignQ) / mQseqSplitNum, ROUND_UP(seqLenKV, mAlignKV)]
            // axis : 2 (last dim); mKvSeqlen is passed separately as the valid (unpadded) length.
            // All four ints are sent to the kernel, so the unused last entry is zero-initialised.
            int softmaxShape[4] = {batch*numHead, seqLenPieceQ, seqLenPackKV, 0};
            int localSize = 64;
            std::set<std::string> buildOption;
            buildOption.emplace("-DSOFTMAX_LOCAL_SIZE=" + std::to_string(localSize));
            mKernel_softmax_vec[seq_idx] = runtime->buildKernel("self_attention_buf", "softmax_inside", buildOption, mOpenCLBackend->getPrecision(), inputs[0], outputs[0]);
            mGwsSoftMaxVec[seq_idx] = {static_cast<uint32_t>(localSize), static_cast<uint32_t>(softmaxShape[1]), static_cast<uint32_t>(softmaxShape[0])};
            uint32_t index = 0;
            cl_int ret = CL_SUCCESS;
            ret |= mKernel_softmax_vec[seq_idx]->get().setArg(index++, mGwsSoftMaxVec[seq_idx][0]);
            ret |= mKernel_softmax_vec[seq_idx]->get().setArg(index++, mGwsSoftMaxVec[seq_idx][1]);
            ret |= mKernel_softmax_vec[seq_idx]->get().setArg(index++, mGwsSoftMaxVec[seq_idx][2]);
            ret |= mKernel_softmax_vec[seq_idx]->get().setArg(index++, openCLBuffer(mTempQK.get()));
            ret |= mKernel_softmax_vec[seq_idx]->get().setArg(index++, openCLBuffer(mTempSoftMax.get()));
            ret |= mKernel_softmax_vec[seq_idx]->get().setArg(index++, mKvSeqlen);
            ret |= mKernel_softmax_vec[seq_idx]->get().setArg(index++, softmaxShape);
            MNN_CHECK_CL_SUCCESS(ret, "setArg Attention softmax");
            mLwsSoftMaxVec[seq_idx] = {static_cast<uint32_t>(localSize), 1, 1};
            mOpenCLBackend->recordKernel3d(mKernel_softmax_vec[seq_idx], mGwsSoftMaxVec[seq_idx], mLwsSoftMaxVec[seq_idx]);
        }
        // transpose softmax result back into mTempQK
        {
            // Softmax: [batch*numHead, ROUND_UP(seqLenQ, mAlignQ) / mQseqSplitNum, ROUND_UP(seqLenKV, mAlignKV)]
            // Trans:   [batch*numHead, ROUND_UP(seqLenKV, mAlignKV), ROUND_UP(seqLenQ, mAlignQ) / mQseqSplitNum]
            int loop = batch * numHead;
            int transDimW = seqLenPieceQ;   // multiple of 8, guaranteed by the split selection above
            int transDimH = seqLenPackKV;   // multiple of mAlignKV (32)
            std::set<std::string> buildOptions;
            mKernel_trans_vec[seq_idx] = runtime->buildKernel("self_attention_buf", "trans_3d_buf", buildOptions, mOpenCLBackend->getPrecision(), inputs[0], outputs[0]);
            uint32_t maxWorkGroupSize = static_cast<uint32_t>(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(mKernel_trans_vec[seq_idx]));
            mGwsTransVec[seq_idx] = {static_cast<uint32_t>(transDimW) / 8, static_cast<uint32_t>(transDimH) / 8, static_cast<uint32_t>(loop)};
            uint32_t index = 0;
            cl_int ret = CL_SUCCESS;
            ret |= mKernel_trans_vec[seq_idx]->get().setArg(index++, mGwsTransVec[seq_idx][0]);
            ret |= mKernel_trans_vec[seq_idx]->get().setArg(index++, mGwsTransVec[seq_idx][1]);
            ret |= mKernel_trans_vec[seq_idx]->get().setArg(index++, mGwsTransVec[seq_idx][2]);
            ret |= mKernel_trans_vec[seq_idx]->get().setArg(index++, openCLBuffer(mTempSoftMax.get()));
            ret |= mKernel_trans_vec[seq_idx]->get().setArg(index++, openCLBuffer(mTempQK.get()));
            ret |= mKernel_trans_vec[seq_idx]->get().setArg(index++, loop);
            ret |= mKernel_trans_vec[seq_idx]->get().setArg(index++, transDimW);
            ret |= mKernel_trans_vec[seq_idx]->get().setArg(index++, transDimH);
            MNN_CHECK_CL_SUCCESS(ret, "setArg Attention transpose");
            mLwsTransVec[seq_idx] = localWS3DDefault(mGwsTransVec[seq_idx], maxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), "trans_3d_buf", mKernel_trans_vec[seq_idx], mOpenCLBackend->getCLTuneLevel()).first;
            mGwsTransVec[seq_idx][0] = ROUND_UP(mGwsTransVec[seq_idx][0], std::max((uint32_t)1, mLwsTransVec[seq_idx][0]));
            mGwsTransVec[seq_idx][1] = ROUND_UP(mGwsTransVec[seq_idx][1], std::max((uint32_t)1, mLwsTransVec[seq_idx][1]));
            mGwsTransVec[seq_idx][2] = ROUND_UP(mGwsTransVec[seq_idx][2], std::max((uint32_t)1, mLwsTransVec[seq_idx][2]));
            mOpenCLBackend->recordKernel3d(mKernel_trans_vec[seq_idx], mGwsTransVec[seq_idx], mLwsTransVec[seq_idx]);
        }
        // qk * value
        {
            // Trans: [batch*numHead, ROUND_UP(seqLenKV, mAlignKV), ROUND_UP(seqLenQ, mAlignQ) / mQseqSplitNum] -> [B, K, M]
            // V    : [batch*numHead/group, ROUND_UP(seqLenKV, mAlignKV), ROUND_UP(headDim, mAlignHDN)] -> [B, K, N]
            // QKV  : [batch*numHead, ROUND_UP(headDim, mAlignHDN), ROUND_UP(seqLenQ, mAlignQ)] -> [B, N, M];
            //        this slice is written starting at column e_pack_piece * seq_idx.
            int loop = batch * numHead;
            int e_pack = seqLenPackQ;
            int e_pack_piece = seqLenPieceQ;
            int l_pack = seqLenPackKV;
            int h_pack = headDimPackV;
            std::set<std::string> buildOptions;
            uint32_t layout = 0;
            auto param = getGemmParams({(uint32_t)e_pack_piece, (uint32_t)h_pack, (uint32_t)l_pack, layout, (uint32_t)loop, (uint32_t)0}, {openCLBuffer(mTempQK.get()), openCLBuffer(mTempV.get()), openCLBuffer(mTempQKV.get())}, mOpenCLBackend->getOpenCLRuntime(), mOpenCLBackend->getPrecision(), mOpenCLBackend->getCLTuneLevel());
            int KWG=param[0], KWI=param[1], MDIMA=param[2], MDIMC=param[3], MWG=param[4], NDIMB=param[5], NDIMC=param[6], NWG=param[7], SA=param[8], SB=param[9], STRM=param[10], STRN=param[11], VWM=param[12], VWN=param[13];
            buildOptions.emplace("-DKWG=" + std::to_string(KWG));
            buildOptions.emplace("-DKWI=" + std::to_string(KWI));
            buildOptions.emplace("-DMDIMA=" + std::to_string(MDIMA));
            buildOptions.emplace("-DMDIMC=" + std::to_string(MDIMC));
            buildOptions.emplace("-DMWG=" + std::to_string(MWG));
            buildOptions.emplace("-DNDIMB=" + std::to_string(NDIMB));
            buildOptions.emplace("-DNDIMC=" + std::to_string(NDIMC));
            buildOptions.emplace("-DNWG=" + std::to_string(NWG));
            buildOptions.emplace("-DSA=" + std::to_string(SA));
            buildOptions.emplace("-DSB=" + std::to_string(SB));
            buildOptions.emplace("-DSTRM=" + std::to_string(STRM));
            buildOptions.emplace("-DSTRN=" + std::to_string(STRN));
            buildOptions.emplace("-DVWM=" + std::to_string(VWM));
            buildOptions.emplace("-DVWN=" + std::to_string(VWN));
            if(layout >= 4) {
                buildOptions.emplace("-DOUTPUTMN");
            }
            int tileM = MWG;
            int tileN = NWG;
            int localM = MDIMC;
            int localN = NDIMC;
            if(mOpenCLBackend->getOpenCLRuntime()->getGpuType() == GpuType::ADRENO) {
                buildOptions.emplace("-DUSE_CL_MAD=1");
                buildOptions.emplace("-DRELAX_WORKGROUP_SIZE=1");
            }
            mKernel_qkv_vec[seq_idx] = mOpenCLBackend->getOpenCLRuntime()->buildKernel("matmul_params_buf", "XgemmBatched", buildOptions, mOpenCLBackend->getPrecision());
            int out_per_thread_m = tileM / localM;
            int out_per_thread_n = tileN / localN;
            mGwsQkvVec[seq_idx] = {static_cast<uint32_t>(e_pack_piece/out_per_thread_m), static_cast<uint32_t>(h_pack/out_per_thread_n), static_cast<uint32_t>(loop)};
            mLwsQkvVec[seq_idx] = {static_cast<uint32_t>(localM), static_cast<uint32_t>(localN), 1};
            float alpha = 1.0f;
            float beta = 0.0f;
            int batch_offset_a = e_pack_piece * l_pack;
            int batch_offset_b = h_pack * l_pack;
            int batch_offset_c = e_pack * h_pack;
            int batch_offset[4] = {batch_offset_a, batch_offset_b, batch_offset_c, 0};
            int base_ptr_offset[4] = {0, 0, e_pack_piece * seq_idx, 0};
            int stride[4] = {e_pack_piece, h_pack, e_pack, h_pack};
            int group[4] = {1, group_size, 1, loop};
            int idx = 0;
            cl_int ret = CL_SUCCESS;
            ret |= mKernel_qkv_vec[seq_idx]->get().setArg(idx++, static_cast<int>(e_pack_piece));
            ret |= mKernel_qkv_vec[seq_idx]->get().setArg(idx++, static_cast<int>(h_pack));
            ret |= mKernel_qkv_vec[seq_idx]->get().setArg(idx++, static_cast<int>(l_pack));
            ret |= mKernel_qkv_vec[seq_idx]->get().setArg(idx++, alpha);
            ret |= mKernel_qkv_vec[seq_idx]->get().setArg(idx++, beta);
            ret |= mKernel_qkv_vec[seq_idx]->get().setArg(idx++, openCLBuffer(mTempQK.get()));
            ret |= mKernel_qkv_vec[seq_idx]->get().setArg(idx++, openCLBuffer(mTempV.get()));
            ret |= mKernel_qkv_vec[seq_idx]->get().setArg(idx++, openCLBuffer(mTempQKV.get()));
            ret |= mKernel_qkv_vec[seq_idx]->get().setArg(idx++, batch_offset);
            ret |= mKernel_qkv_vec[seq_idx]->get().setArg(idx++, base_ptr_offset);
            ret |= mKernel_qkv_vec[seq_idx]->get().setArg(idx++, stride);
            ret |= mKernel_qkv_vec[seq_idx]->get().setArg(idx++, group);
            MNN_CHECK_CL_SUCCESS(ret, "setArg Self-Attention batchmatmul qkv Kernel");
            mOpenCLBackend->recordKernel3d(mKernel_qkv_vec[seq_idx], mGwsQkvVec[seq_idx], mLwsQkvVec[seq_idx]);
        }
    }
    seq_idx = 0;
    // transpose to output
    {
        // QKV : [batch*numHead, ROUND_UP(headDim, mAlignHDN), ROUND_UP(seqLenQ, mAlignQ)] -> [B, N, M]
        // output: [batch, seqLenQ/4, headNum, headDim, seqLenQ_4]
        std::set<std::string> buildOption;
        mKernel_clip_vec[seq_idx] = runtime->buildKernel("attention_buf", "qkv_transpose_output", buildOption, mOpenCLBackend->getPrecision(), inputs[0], outputs[0]);
        auto maxWorkGroupSize = static_cast<uint32_t>(runtime->getMaxWorkGroupSize(mKernel_clip_vec[seq_idx]));
        mGwsClipVec[seq_idx] = {static_cast<uint32_t>(UP_DIV(seqlen, 4)), static_cast<uint32_t>(UP_DIV(headDim, 4)), static_cast<uint32_t>(batch*numHead)};
        uint32_t index = 0;
        cl_int ret = CL_SUCCESS;
        ret |= mKernel_clip_vec[seq_idx]->get().setArg(index++, mGwsClipVec[seq_idx][0]);
        ret |= mKernel_clip_vec[seq_idx]->get().setArg(index++, mGwsClipVec[seq_idx][1]);
        ret |= mKernel_clip_vec[seq_idx]->get().setArg(index++, mGwsClipVec[seq_idx][2]);
        ret |= mKernel_clip_vec[seq_idx]->get().setArg(index++, openCLBuffer(mTempQKV.get()));
        ret |= mKernel_clip_vec[seq_idx]->get().setArg(index++, openCLBuffer(outputs[0]));
        ret |= mKernel_clip_vec[seq_idx]->get().setArg(index++, mAlignQ);
        ret |= mKernel_clip_vec[seq_idx]->get().setArg(index++, mAlignHDN);
        ret |= mKernel_clip_vec[seq_idx]->get().setArg(index++, seqlen);
        ret |= mKernel_clip_vec[seq_idx]->get().setArg(index++, numHead);
        ret |= mKernel_clip_vec[seq_idx]->get().setArg(index++, headDim);
        // Check the setArg results before tuning, consistent with the other kernels in this function.
        MNN_CHECK_CL_SUCCESS(ret, "setArg qkv_transpose_output");
        mLwsClipVec[seq_idx] = localWS3DDefault(mGwsClipVec[seq_idx], maxWorkGroupSize, runtime, "qkv_transpose_output", mKernel_clip_vec[seq_idx], mOpenCLBackend->getCLTuneLevel()).first;
        mGwsClipVec[seq_idx][0] = ROUND_UP(mGwsClipVec[seq_idx][0], std::max((uint32_t)1, mLwsClipVec[seq_idx][0]));
        mGwsClipVec[seq_idx][1] = ROUND_UP(mGwsClipVec[seq_idx][1], std::max((uint32_t)1, mLwsClipVec[seq_idx][1]));
        mGwsClipVec[seq_idx][2] = ROUND_UP(mGwsClipVec[seq_idx][2], std::max((uint32_t)1, mLwsClipVec[seq_idx][2]));
        mOpenCLBackend->recordKernel3d(mKernel_clip_vec[seq_idx], mGwsClipVec[seq_idx], mLwsClipVec[seq_idx]);
    }
    mOpenCLBackend->endRecord(mRecording);
    return NO_ERROR;
}