in source/backend/opencl/execution/buffer/AttentionBufExecution.cpp [1151:1437]
/**
 * Prepare the single-step (decode) attention path: build, bind and record the OpenCL kernels
 * that let the query attend over the key/value history.
 *
 * Kernels are recorded in this order (each through mOpenCLBackend->recordKernel2d/3d):
 *   1. rearrange_k      : inputs[1] (key)   -> keyBuffer (given mPastKvSeqlen, mKeyValueMaxlen)
 *   2. matmul_qk_decode : query x keyBuffer -> mTempQK, given scale = 1/sqrt(headDim)
 *   3. softmax_in1_buf  : mTempQK -> mTempSoftMax, with outside = numHead, length = mKvSeqlen
 *   4. rearrange_v      : inputs[2] (value) -> valueBuffer
 *   5. matmul_qkv_decode_b4 / _b8 : mTempSoftMax x valueBuffer -> outputs[0]; the variant and
 *      loop-unroll option are chosen by the lowest cost reported by localWS2DDefault.
 *
 * keyBuffer/valueBuffer are the mKVCacheCLManager buffers when mNeedKvCache is set. Otherwise
 * they are temporary DYNAMIC tensors (mTempK/mTempV) sized from the current seqlen.
 *
 * When mNeedKvCache is set, each kernel also registers {0, argIndex, size, address} entries in
 * its *UpdateInfo member. That member is appended to mOpRecordUpdateInfo and passed to the record
 * call. The addresses point at members (mPastKvSeqlen, mKvSeqlen, mKeyValueMaxlen,
 * mGlobalWorkSizeQk0, mQkGlobal_size) and at cl_mem handles of the cache / deferred buffers.
 * This is presumably so the recorded kernels can be re-issued with refreshed values on later decode
 * steps without re-running this function. TODO confirm against OpenCLBackend.
 *
 * @param inputs  [0] query, shape read as (batch, seqlen, numHead, headDim);
 *                [1] key, shape[2] read as kvNumHead; [2] value;
 *                [3] optional attention mask (not bound to any kernel below).
 * @param outputs [0] attention result, written by the qkv kernel.
 * @return NO_ERROR. setArg status codes are passed to MNN_CHECK_CL_SUCCESS rather than returned.
 */
ErrorCode AttentionBufExecution::decodeResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs){
    auto runtime = mOpenCLBackend->getOpenCLRuntime();
    auto query = inputs[0];
    auto key = inputs[1];
    auto value = inputs[2];
    auto shape = query->shape();
    int batch = shape[0];
    int seqlen = shape[1];
    int numHead = shape[2];
    int kvNumHead = key->shape()[2];
    int headDim = shape[3];
    // Number of query heads per key/value head; forwarded to the qk/qkv kernels as NUMHEAD_GROUP_SIZE.
    int group_size = numHead / kvNumHead;
    float scale = 1.0 / sqrt(headDim);
    // The attention mask (inputs[3], when mHasMask) is not read here: none of the decode kernels
    // below take it as an argument. NOTE(review): confirm masking is not required for decode.
    cl::Buffer keyBuffer, valueBuffer;
    if(mNeedKvCache) {
        keyBuffer = *mKVCacheCLManager->key();
        valueBuffer = *mKVCacheCLManager->value();
    } else {
        // Scratch key/value storage for this step only. Sized with numHead (not kvNumHead).
        mTempK.reset(Tensor::createDevice<float>({ROUND_UP(seqlen, 4) * ROUND_UP(headDim, 4) * numHead * batch}));
        mTempV.reset(Tensor::createDevice<float>({ROUND_UP(seqlen, 4) * ROUND_UP(headDim, 4) * numHead * batch}));
        // Acquire then immediately release (presumably the backend's resize-time memory-planning
        // idiom: the memory stays valid for this op but may be reused afterwards).
        mOpenCLBackend->onAcquireBuffer(mTempK.get(), Backend::DYNAMIC);
        mOpenCLBackend->onAcquireBuffer(mTempV.get(), Backend::DYNAMIC);
        mOpenCLBackend->onReleaseBuffer(mTempV.get(), Backend::DYNAMIC);
        mOpenCLBackend->onReleaseBuffer(mTempK.get(), Backend::DYNAMIC);
        keyBuffer = openCLBuffer(mTempK.get());
        valueBuffer = openCLBuffer(mTempV.get());
    }
    // QK scores and softmax probabilities: mDecodeTmpMaxlen * numHead floats each.
    // NOTE(review): batch is not part of this size — confirm decode always runs with batch == 1.
    mTempQK.reset(Tensor::createDevice<float>({mDecodeTmpMaxlen * numHead}));
    mTempSoftMax.reset(Tensor::createDevice<float>({mDecodeTmpMaxlen * numHead}));
    // DYNAMIC_IN_EXECUTION buffers are bound via openCLDeferBuffer below (presumably resolved to a
    // real cl_mem at execution time).
    mOpenCLBackend->onAcquireBuffer(mTempQK.get(), Backend::DYNAMIC_IN_EXECUTION);
    mOpenCLBackend->onAcquireBuffer(mTempSoftMax.get(), Backend::DYNAMIC_IN_EXECUTION);
    mOpenCLBackend->onReleaseBuffer(mTempQK.get(), Backend::DYNAMIC_IN_EXECUTION);
    mOpenCLBackend->onReleaseBuffer(mTempSoftMax.get(), Backend::DYNAMIC_IN_EXECUTION);
    {
        // rearrange key: inputs[1] -> keyBuffer
        std::set<std::string> buildOption;
        mKernel_rearrange = runtime->buildKernel("attention_buf", "rearrange_k", buildOption, mOpenCLBackend->getPrecision(), inputs[0], outputs[0]);
        auto maxWorkGroupSize = static_cast<uint32_t>(runtime->getMaxWorkGroupSize(mKernel_rearrange));
        mGlobalWorkSizeRearrg = {static_cast<uint32_t>(1), \
                                 static_cast<uint32_t>(UP_DIV(headDim, 4)), \
                                 static_cast<uint32_t>(kvNumHead * batch)};
        uint32_t index = 0;
        cl_int ret = CL_SUCCESS;
        ret |= mKernel_rearrange->get().setArg(index++, mGlobalWorkSizeRearrg[0]);
        ret |= mKernel_rearrange->get().setArg(index++, mGlobalWorkSizeRearrg[1]);
        ret |= mKernel_rearrange->get().setArg(index++, mGlobalWorkSizeRearrg[2]);
        ret |= mKernel_rearrange->get().setArg(index++, openCLBuffer(key));
        ret |= mKernel_rearrange->get().setArg(index++, keyBuffer);
        ret |= mKernel_rearrange->get().setArg(index++, mPastKvSeqlen);
        ret |= mKernel_rearrange->get().setArg(index++, mKeyValueMaxlen);
        ret |= mKernel_rearrange->get().setArg(index++, seqlen);
        ret |= mKernel_rearrange->get().setArg(index++, kvNumHead);
        ret |= mKernel_rearrange->get().setArg(index++, numHead);
        ret |= mKernel_rearrange->get().setArg(index++, headDim);
        MNN_CHECK_CL_SUCCESS(ret, "setArg rearrange_k");
        mLocalWorkSizeRearrg = localWS3DDefault(mGlobalWorkSizeRearrg, maxWorkGroupSize, runtime, "rearrange_k", mKernel_rearrange, mOpenCLBackend->getCLTuneLevel()).first;
        // Global sizes are rounded up to a multiple of the chosen local size.
        mGlobalWorkSizeRearrg[0] = ROUND_UP(mGlobalWorkSizeRearrg[0], std::max((uint32_t)1, mLocalWorkSizeRearrg[0]));
        mGlobalWorkSizeRearrg[1] = ROUND_UP(mGlobalWorkSizeRearrg[1], std::max((uint32_t)1, mLocalWorkSizeRearrg[1]));
        mGlobalWorkSizeRearrg[2] = ROUND_UP(mGlobalWorkSizeRearrg[2], std::max((uint32_t)1, mLocalWorkSizeRearrg[2]));
        if(mNeedKvCache) {
            // Arg 4: cache key buffer, arg 5: mPastKvSeqlen, arg 6: mKeyValueMaxlen.
            // NOTE(review): these addresses stay valid only while mKVCacheCLManager keeps the same
            // cl::Buffer objects (updated in place rather than replaced).
            mRgUpdateInfo.update_kernel_args.push_back({0, 4, sizeof(cl_mem), &(*(mKVCacheCLManager->key()))()});
            mRgUpdateInfo.update_kernel_args.push_back({0, 5, sizeof(mPastKvSeqlen), &mPastKvSeqlen});
            mRgUpdateInfo.update_kernel_args.push_back({0, 6, sizeof(mKeyValueMaxlen), &mKeyValueMaxlen});
            mOpRecordUpdateInfo.emplace_back(&mRgUpdateInfo);
            mOpenCLBackend->recordKernel3d(mKernel_rearrange, mGlobalWorkSizeRearrg, mLocalWorkSizeRearrg, &mRgUpdateInfo);
        } else {
            mOpenCLBackend->recordKernel3d(mKernel_rearrange, mGlobalWorkSizeRearrg, mLocalWorkSizeRearrg);
        }
    }
    {
        // matmul qk: query x keyBuffer -> mTempQK
        std::set<std::string> buildOption;
        buildOption.emplace("-DNUMHEAD_GROUP_SIZE=" + std::to_string(group_size));
        mKernel_qk = runtime->buildKernel("attention_buf", "matmul_qk_decode", buildOption, mOpenCLBackend->getPrecision(), inputs[0], outputs[0]);
        mGlobalWorkSizeQk = {static_cast<uint32_t>(UP_DIV(mKvSeqlen, 4)), static_cast<uint32_t>(numHead)};
        auto maxWorkGroupSize = static_cast<uint32_t>(runtime->getMaxWorkGroupSize(mKernel_qk));
        uint32_t index = 0;
        cl_int ret = CL_SUCCESS;
        ret |= mKernel_qk->get().setArg(index++, mGlobalWorkSizeQk[0]);
        ret |= mKernel_qk->get().setArg(index++, mGlobalWorkSizeQk[1]);
        ret |= mKernel_qk->get().setArg(index++, openCLBuffer(query));
        ret |= mKernel_qk->get().setArg(index++, keyBuffer);
        ret |= mKernel_qk->get().setArg(index++, openCLDeferBuffer(mTempQK.get()));
        ret |= mKernel_qk->get().setArg(index++, scale);
        ret |= mKernel_qk->get().setArg(index++, mKvSeqlen);
        ret |= mKernel_qk->get().setArg(index++, mKeyValueMaxlen);
        ret |= mKernel_qk->get().setArg(index++, numHead);
        ret |= mKernel_qk->get().setArg(index++, headDim);
        MNN_CHECK_CL_SUCCESS(ret, "setArg matmul_qk_decode");
        mLocalWorkSizeQk = localWS2DDefault(mGlobalWorkSizeQk, maxWorkGroupSize, runtime, "matmul_qk_decode", mKernel_qk, mOpenCLBackend->getCLTuneLevel()).first;
        mGlobalWorkSizeQk[0] = ROUND_UP(mGlobalWorkSizeQk[0], std::max((uint32_t)1, mLocalWorkSizeQk[0]));
        mGlobalWorkSizeQk[1] = ROUND_UP(mGlobalWorkSizeQk[1], std::max((uint32_t)1, mLocalWorkSizeQk[1]));
        if(mNeedKvCache) {
            // This kernel's global size depends on mKvSeqlen, so besides its args (0: gws[0],
            // 3: cache key, 4: mTempQK, 6: mKvSeqlen, 7: mKeyValueMaxlen) the global size
            // (mQkGlobal_size) is registered for update as well.
            mQkUpdateInfo.update_kernel_args.push_back({0, 0, sizeof(mGlobalWorkSizeQk0), &mGlobalWorkSizeQk0});
            mQkUpdateInfo.update_kernel_args.push_back({0, 3, sizeof(cl_mem), &(*(mKVCacheCLManager->key()))()});
            mQkUpdateInfo.update_kernel_args.push_back({0, 4, sizeof(cl_mem), &openCLDeferBuffer(mTempQK.get())()});
            mQkUpdateInfo.update_kernel_args.push_back({0, 6, sizeof(mKvSeqlen), &mKvSeqlen});
            mQkUpdateInfo.update_kernel_args.push_back({0, 7, sizeof(mKeyValueMaxlen), &mKeyValueMaxlen});
            mQkGlobal_size[0] = mGlobalWorkSizeQk[0];
            mQkGlobal_size[1] = mGlobalWorkSizeQk[1];
            mQkUpdateInfo.update_global_size.push_back({0, mQkGlobal_size});
            mOpRecordUpdateInfo.emplace_back(&mQkUpdateInfo);
            mOpenCLBackend->recordKernel2d(mKernel_qk, mGlobalWorkSizeQk, mLocalWorkSizeQk, &mQkUpdateInfo);
        } else {
            mOpenCLBackend->recordKernel2d(mKernel_qk, mGlobalWorkSizeQk, mLocalWorkSizeQk);
        }
    }
    {
        // softmax: mTempQK -> mTempSoftMax, one row of mKvSeqlen values per head (outside = numHead)
        int inside = 1;
        int outside = numHead;
        int localSize = 64;
        std::set<std::string> buildOption;
        buildOption.emplace("-DSOFTMAX_LOCAL_SIZE=" + std::to_string(localSize));
        mKernel_softmax = runtime->buildKernel("softmax_buf", "softmax_in1_buf", buildOption, mOpenCLBackend->getPrecision());
        mGlobalWorkSizeSoftMax = {static_cast<uint32_t>(localSize), static_cast<uint32_t>(inside), static_cast<uint32_t>(outside)};
        auto maxWorkGroupSize = static_cast<uint32_t>(runtime->getMaxWorkGroupSize(mKernel_softmax));
        uint32_t index = 0;
        cl_int ret = CL_SUCCESS;
        ret |= mKernel_softmax->get().setArg(index++, mGlobalWorkSizeSoftMax[0]);
        ret |= mKernel_softmax->get().setArg(index++, mGlobalWorkSizeSoftMax[1]);
        ret |= mKernel_softmax->get().setArg(index++, mGlobalWorkSizeSoftMax[2]);
        ret |= mKernel_softmax->get().setArg(index++, openCLDeferBuffer(mTempQK.get()));
        ret |= mKernel_softmax->get().setArg(index++, openCLDeferBuffer(mTempSoftMax.get()));
        ret |= mKernel_softmax->get().setArg(index++, inside);
        ret |= mKernel_softmax->get().setArg(index++, outside);
        ret |= mKernel_softmax->get().setArg(index++, mKvSeqlen);
        MNN_CHECK_CL_SUCCESS(ret, "setArg softmax");
        mLocalWorkSizeSoftMax = {static_cast<uint32_t>(localSize), 1, 1};
        // Fallback tuning path; only reachable if localSize above is changed to 1.
        if(localSize == 1){
            mLocalWorkSizeSoftMax = localWS3DDefault(mGlobalWorkSizeSoftMax, maxWorkGroupSize, runtime, "softmax", mKernel_softmax, mOpenCLBackend->getCLTuneLevel()).first;
        }
        mGlobalWorkSizeSoftMax[0] = ROUND_UP(mGlobalWorkSizeSoftMax[0], std::max((uint32_t)1, mLocalWorkSizeSoftMax[0]));
        mGlobalWorkSizeSoftMax[1] = ROUND_UP(mGlobalWorkSizeSoftMax[1], std::max((uint32_t)1, mLocalWorkSizeSoftMax[1]));
        mGlobalWorkSizeSoftMax[2] = ROUND_UP(mGlobalWorkSizeSoftMax[2], std::max((uint32_t)1, mLocalWorkSizeSoftMax[2]));
        if(mNeedKvCache) {
            // Arg 3: mTempQK, arg 4: mTempSoftMax, arg 7: mKvSeqlen.
            mSoftMaxUpdateInfo.update_kernel_args.push_back({0, 3, sizeof(cl_mem), &openCLDeferBuffer(mTempQK.get())()});
            mSoftMaxUpdateInfo.update_kernel_args.push_back({0, 4, sizeof(cl_mem), &openCLDeferBuffer(mTempSoftMax.get())()});
            mSoftMaxUpdateInfo.update_kernel_args.push_back({0, 7, sizeof(mKvSeqlen), &mKvSeqlen});
            mOpRecordUpdateInfo.emplace_back(&mSoftMaxUpdateInfo);
            mOpenCLBackend->recordKernel3d(mKernel_softmax, mGlobalWorkSizeSoftMax, mLocalWorkSizeSoftMax, &mSoftMaxUpdateInfo);
        } else {
            mOpenCLBackend->recordKernel3d(mKernel_softmax, mGlobalWorkSizeSoftMax, mLocalWorkSizeSoftMax);
        }
    }
    {
        // rearrange value: inputs[2] -> valueBuffer
        std::set<std::string> buildOption;
        mKernel_rearrangeV = runtime->buildKernel("attention_buf", "rearrange_v", buildOption, mOpenCLBackend->getPrecision(), inputs[0], outputs[0]);
        auto maxWorkGroupSize = static_cast<uint32_t>(runtime->getMaxWorkGroupSize(mKernel_rearrangeV));
        mGlobalWorkSizeRearrgV = {static_cast<uint32_t>(UP_DIV(headDim, 4)), \
                                  static_cast<uint32_t>(1), \
                                  static_cast<uint32_t>(kvNumHead * batch)};
        uint32_t index = 0;
        cl_int ret = CL_SUCCESS;
        ret |= mKernel_rearrangeV->get().setArg(index++, mGlobalWorkSizeRearrgV[0]);
        ret |= mKernel_rearrangeV->get().setArg(index++, mGlobalWorkSizeRearrgV[1]);
        ret |= mKernel_rearrangeV->get().setArg(index++, mGlobalWorkSizeRearrgV[2]);
        ret |= mKernel_rearrangeV->get().setArg(index++, openCLBuffer(value));
        ret |= mKernel_rearrangeV->get().setArg(index++, valueBuffer);
        ret |= mKernel_rearrangeV->get().setArg(index++, mPastKvSeqlen);
        ret |= mKernel_rearrangeV->get().setArg(index++, mKeyValueMaxlen);
        ret |= mKernel_rearrangeV->get().setArg(index++, seqlen);
        ret |= mKernel_rearrangeV->get().setArg(index++, kvNumHead);
        ret |= mKernel_rearrangeV->get().setArg(index++, headDim);
        MNN_CHECK_CL_SUCCESS(ret, "setArg rearrange_v");
        mLocalWorkSizeRearrgV = localWS3DDefault(mGlobalWorkSizeRearrgV, maxWorkGroupSize, runtime, "rearrange_v", mKernel_rearrangeV, mOpenCLBackend->getCLTuneLevel()).first;
        mGlobalWorkSizeRearrgV[0] = ROUND_UP(mGlobalWorkSizeRearrgV[0], std::max((uint32_t)1, mLocalWorkSizeRearrgV[0]));
        mGlobalWorkSizeRearrgV[1] = ROUND_UP(mGlobalWorkSizeRearrgV[1], std::max((uint32_t)1, mLocalWorkSizeRearrgV[1]));
        mGlobalWorkSizeRearrgV[2] = ROUND_UP(mGlobalWorkSizeRearrgV[2], std::max((uint32_t)1, mLocalWorkSizeRearrgV[2]));
        if(mNeedKvCache) {
            // Arg 4: cache value buffer, arg 5: mPastKvSeqlen, arg 6: mKeyValueMaxlen.
            mRgVUpdateInfo.update_kernel_args.push_back({0, 4, sizeof(cl_mem), &(*(mKVCacheCLManager->value()))()});
            mRgVUpdateInfo.update_kernel_args.push_back({0, 5, sizeof(mPastKvSeqlen), &mPastKvSeqlen});
            mRgVUpdateInfo.update_kernel_args.push_back({0, 6, sizeof(mKeyValueMaxlen), &mKeyValueMaxlen});
            mOpRecordUpdateInfo.emplace_back(&mRgVUpdateInfo);
            mOpenCLBackend->recordKernel3d(mKernel_rearrangeV, mGlobalWorkSizeRearrgV, mLocalWorkSizeRearrgV, &mRgVUpdateInfo);
        } else {
            mOpenCLBackend->recordKernel3d(mKernel_rearrangeV, mGlobalWorkSizeRearrgV, mLocalWorkSizeRearrgV);
        }
    }
    // qk * value: mTempSoftMax x valueBuffer -> outputs[0]
    {
        std::set<std::string> buildOption;
        buildOption.emplace("-DNUMHEAD_GROUP_SIZE=" + std::to_string(group_size));
        // Candidates: kernel variant i (itemC[i] output elements per work-item along headDim)
        // crossed with unroll option j, i.e. total_kernel * total_kernel candidates in all.
        const int total_kernel = 2;
        const std::string kernelName[total_kernel] = {"matmul_qkv_decode_b4", "matmul_qkv_decode_b8"};
        const std::string unroll[total_kernel] = {"-DLOOP_UNROLL_4", "-DLOOP_UNROLL_8"};
        const int itemC[total_kernel] = {4, 8};
        const int actual_kernel = 2;
        std::shared_ptr<KernelWrap> kernel[total_kernel * total_kernel];
        std::vector<uint32_t> globalWorkSize[total_kernel * total_kernel];
        std::pair<int, int> min_cost(INT_MAX, 0);//(min_time, min_index)
        for (int i = 0; i < actual_kernel; i++) {
            for(int j = 0; j < actual_kernel; j++){
                // Flattened candidate index; decoded with / and % total_kernel after the loop.
                int knl_idx = i * total_kernel + j;
                auto option = buildOption;
                option.emplace(unroll[j]);
                kernel[knl_idx] = runtime->buildKernel("attention_buf", kernelName[i], option, mOpenCLBackend->getPrecision());
                uint32_t maxWorkGroupSize = static_cast<uint32_t>(runtime->getMaxWorkGroupSize(kernel[knl_idx]));
                globalWorkSize[knl_idx] = {static_cast<uint32_t>(UP_DIV(headDim, itemC[i])), static_cast<uint32_t>(numHead)};
                uint32_t index = 0;
                cl_int ret = CL_SUCCESS;
                ret |= kernel[knl_idx]->get().setArg(index++, globalWorkSize[knl_idx][0]);
                ret |= kernel[knl_idx]->get().setArg(index++, globalWorkSize[knl_idx][1]);
                ret |= kernel[knl_idx]->get().setArg(index++, openCLDeferBuffer(mTempSoftMax.get()));
                ret |= kernel[knl_idx]->get().setArg(index++, valueBuffer);
                ret |= kernel[knl_idx]->get().setArg(index++, openCLBuffer(outputs[0]));
                ret |= kernel[knl_idx]->get().setArg(index++, mKvSeqlen);
                ret |= kernel[knl_idx]->get().setArg(index++, mKeyValueMaxlen);
                ret |= kernel[knl_idx]->get().setArg(index++, numHead);
                ret |= kernel[knl_idx]->get().setArg(index++, kvNumHead);
                ret |= kernel[knl_idx]->get().setArg(index++, headDim);
                MNN_CHECK_CL_SUCCESS(ret, "setArg matmul_qkv_decode");
                // Kept as pair<..., int> on purpose: the cost is compared against min_cost.first (int).
                std::pair<std::vector<uint32_t>, int> retTune;
                retTune = localWS2DDefault(globalWorkSize[knl_idx], maxWorkGroupSize, runtime, kernelName[i] + unroll[j], kernel[knl_idx], mOpenCLBackend->getCLTuneLevel());
                if(min_cost.first > retTune.second) {
                    min_cost.first = retTune.second;
                    min_cost.second = knl_idx;
                    mLocalWorkSizeQkv = {retTune.first[0], retTune.first[1]};
                }
            }
        }
        // Decode the winning flattened index back into (variant, unroll) — the inverse of
        // knl_idx = i * total_kernel + j above.
        int min_index = min_cost.second / total_kernel;
        int min_index_unroll = min_cost.second % total_kernel;
        mGlobalWorkSizeQkv = {globalWorkSize[min_cost.second][0], globalWorkSize[min_cost.second][1]};
        buildOption.emplace(unroll[min_index_unroll]);
        // NOTE(review): the tuning candidates above were built without inputs[0]/outputs[0],
        // while the kernel actually used is built with them; confirm both produce the same variant.
        mKernel_qkv = runtime->buildKernel("attention_buf", kernelName[min_index], buildOption, mOpenCLBackend->getPrecision(), inputs[0], outputs[0]);
        uint32_t index = 0;
        cl_int ret = CL_SUCCESS;
        ret |= mKernel_qkv->get().setArg(index++, mGlobalWorkSizeQkv[0]);
        ret |= mKernel_qkv->get().setArg(index++, mGlobalWorkSizeQkv[1]);
        ret |= mKernel_qkv->get().setArg(index++, openCLDeferBuffer(mTempSoftMax.get()));
        ret |= mKernel_qkv->get().setArg(index++, valueBuffer);
        ret |= mKernel_qkv->get().setArg(index++, openCLBuffer(outputs[0]));
        ret |= mKernel_qkv->get().setArg(index++, mKvSeqlen);
        ret |= mKernel_qkv->get().setArg(index++, mKeyValueMaxlen);
        ret |= mKernel_qkv->get().setArg(index++, numHead);
        ret |= mKernel_qkv->get().setArg(index++, kvNumHead);
        ret |= mKernel_qkv->get().setArg(index++, headDim);
        MNN_CHECK_CL_SUCCESS(ret, "setArg matmul_qkv_decode");
        mGlobalWorkSizeQkv[0] = ROUND_UP(mGlobalWorkSizeQkv[0], std::max((uint32_t)1, mLocalWorkSizeQkv[0]));
        mGlobalWorkSizeQkv[1] = ROUND_UP(mGlobalWorkSizeQkv[1], std::max((uint32_t)1, mLocalWorkSizeQkv[1]));
        if(mNeedKvCache) {
            // Arg 2: mTempSoftMax, arg 3: cache value buffer, arg 5: mKvSeqlen, arg 6: mKeyValueMaxlen.
            mQkvUpdateInfo.update_kernel_args.push_back({0, 2, sizeof(cl_mem), &openCLDeferBuffer(mTempSoftMax.get())()});
            mQkvUpdateInfo.update_kernel_args.push_back({0, 3, sizeof(cl_mem), &(*(mKVCacheCLManager->value()))()});
            mQkvUpdateInfo.update_kernel_args.push_back({0, 5, sizeof(mKvSeqlen), &mKvSeqlen});
            mQkvUpdateInfo.update_kernel_args.push_back({0, 6, sizeof(mKeyValueMaxlen), &mKeyValueMaxlen});
            mOpRecordUpdateInfo.emplace_back(&mQkvUpdateInfo);
            mOpenCLBackend->recordKernel2d(mKernel_qkv, mGlobalWorkSizeQkv, mLocalWorkSizeQkv, &mQkvUpdateInfo);
        } else {
            mOpenCLBackend->recordKernel2d(mKernel_qkv, mGlobalWorkSizeQkv, mLocalWorkSizeQkv);
        }
    }
    mOpenCLBackend->endRecord(mRecording);
    return NO_ERROR;
}