ErrorCode AttentionBufExecution::decodeResize()

in source/backend/opencl/execution/buffer/AttentionBufExecution.cpp [1151:1437]


// Builds and records the OpenCL kernel sequence used by the decode path.
//
// Kernels are recorded in this order:
//   1. rearrange_k        - copies the incoming key into keyBuffer
//   2. matmul_qk_decode   - query x key, written to mTempQK
//   3. softmax_in1_buf    - softmax of mTempQK, written to mTempSoftMax
//   4. rearrange_v        - copies the incoming value into valueBuffer
//   5. matmul_qkv_decode  - softmax result x value, written to outputs[0]
//                           (cheapest of four tuned b4/b8 x unroll4/unroll8 variants)
//
// When mNeedKvCache is set, every stage also registers the arguments that have to
// be re-bound later (cache buffers, deferred temporaries, sequence lengths) in its
// own *UpdateInfo member and appends that object to mOpRecordUpdateInfo.
//
// inputs : inputs[0] = query, inputs[1] = key, inputs[2] = value.
//          The query shape is read as {batch, seqlen, numHead, headDim}; the third
//          dimension of the key shape is read as the number of KV heads.
// outputs: outputs[0] receives the result of the last kernel.
// return : this function itself only ever returns NO_ERROR; setArg results are
//          handed to MNN_CHECK_CL_SUCCESS and do not influence the return value here.
ErrorCode AttentionBufExecution::decodeResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs){

    auto runtime = mOpenCLBackend->getOpenCLRuntime();
    auto query = inputs[0];
    auto key = inputs[1];
    auto value = inputs[2];
    auto shape = query->shape();

    int batch = shape[0];
    int seqlen = shape[1];
    int numHead = shape[2];
    int kvNumHead = key->shape()[2];
    int headDim = shape[3];
    // Number of query heads sharing one KV head; passed to the matmul kernels as a build option.
    int group_size = numHead / kvNumHead;
    float scale = 1.0 / sqrt(headDim);

    // Pick the key/value storage the kernels read from and write to.
    cl::Buffer keyBuffer, valueBuffer;
    if(mNeedKvCache) {
        keyBuffer = *mKVCacheCLManager->key();
        valueBuffer = *mKVCacheCLManager->value();
    } else {
        mTempK.reset(Tensor::createDevice<float>({ROUND_UP(seqlen, 4) * ROUND_UP(headDim, 4) * numHead * batch}));
        mTempV.reset(Tensor::createDevice<float>({ROUND_UP(seqlen, 4) * ROUND_UP(headDim, 4) * numHead * batch}));
        // Acquire both, then release both.
        // NOTE(review): presumably this lets the DYNAMIC allocator reuse the memory after this op - confirm.
        mOpenCLBackend->onAcquireBuffer(mTempK.get(), Backend::DYNAMIC);
        mOpenCLBackend->onAcquireBuffer(mTempV.get(), Backend::DYNAMIC);
        mOpenCLBackend->onReleaseBuffer(mTempV.get(), Backend::DYNAMIC);
        mOpenCLBackend->onReleaseBuffer(mTempK.get(), Backend::DYNAMIC);
        keyBuffer = openCLBuffer(mTempK.get());
        valueBuffer = openCLBuffer(mTempV.get());
    }

    // Scratch tensors for the QK product and its softmax, sized by mDecodeTmpMaxlen * numHead.
    // They are bound through openCLDeferBuffer() below instead of openCLBuffer().
    mTempQK.reset(Tensor::createDevice<float>({mDecodeTmpMaxlen * numHead}));
    mTempSoftMax.reset(Tensor::createDevice<float>({mDecodeTmpMaxlen * numHead}));
    mOpenCLBackend->onAcquireBuffer(mTempQK.get(), Backend::DYNAMIC_IN_EXECUTION);
    mOpenCLBackend->onAcquireBuffer(mTempSoftMax.get(), Backend::DYNAMIC_IN_EXECUTION);
    mOpenCLBackend->onReleaseBuffer(mTempQK.get(), Backend::DYNAMIC_IN_EXECUTION);
    mOpenCLBackend->onReleaseBuffer(mTempSoftMax.get(), Backend::DYNAMIC_IN_EXECUTION);
    {
        // rearrange key
        std::set<std::string> buildOption;

        mKernel_rearrange = runtime->buildKernel("attention_buf", "rearrange_k", buildOption, mOpenCLBackend->getPrecision(), inputs[0], outputs[0]);
        auto maxWorkGroupSize  = static_cast<uint32_t>(runtime->getMaxWorkGroupSize(mKernel_rearrange));

        mGlobalWorkSizeRearrg = {static_cast<uint32_t>(1),
                                static_cast<uint32_t>(UP_DIV(headDim, 4)),
                                static_cast<uint32_t>(kvNumHead * batch)};

        // The kernel receives the logical (unpadded) global sizes as its first three arguments.
        uint32_t index = 0;
        cl_int ret = CL_SUCCESS;
        ret |= mKernel_rearrange->get().setArg(index++, mGlobalWorkSizeRearrg[0]);
        ret |= mKernel_rearrange->get().setArg(index++, mGlobalWorkSizeRearrg[1]);
        ret |= mKernel_rearrange->get().setArg(index++, mGlobalWorkSizeRearrg[2]);
        ret |= mKernel_rearrange->get().setArg(index++, openCLBuffer(key));
        ret |= mKernel_rearrange->get().setArg(index++, keyBuffer);
        ret |= mKernel_rearrange->get().setArg(index++, mPastKvSeqlen);
        ret |= mKernel_rearrange->get().setArg(index++, mKeyValueMaxlen);
        ret |= mKernel_rearrange->get().setArg(index++, seqlen);
        ret |= mKernel_rearrange->get().setArg(index++, kvNumHead);
        ret |= mKernel_rearrange->get().setArg(index++, numHead);
        ret |= mKernel_rearrange->get().setArg(index++, headDim);

        MNN_CHECK_CL_SUCCESS(ret, "setArg rearrange_k");
        mLocalWorkSizeRearrg = localWS3DDefault(mGlobalWorkSizeRearrg, maxWorkGroupSize, runtime, "rearrange_k", mKernel_rearrange, mOpenCLBackend->getCLTuneLevel()).first;
        // Pad every global dimension up to a multiple of the chosen local size (treating 0 as 1).
        mGlobalWorkSizeRearrg[0] = ROUND_UP(mGlobalWorkSizeRearrg[0], std::max((uint32_t)1, mLocalWorkSizeRearrg[0]));
        mGlobalWorkSizeRearrg[1] = ROUND_UP(mGlobalWorkSizeRearrg[1], std::max((uint32_t)1, mLocalWorkSizeRearrg[1]));
        mGlobalWorkSizeRearrg[2] = ROUND_UP(mGlobalWorkSizeRearrg[2], std::max((uint32_t)1, mLocalWorkSizeRearrg[2]));
        if(mNeedKvCache) {
            // The second field is the kernel argument index as set above:
            // 4 = keyBuffer, 5 = mPastKvSeqlen, 6 = mKeyValueMaxlen.
            mRgUpdateInfo.update_kernel_args.push_back({0, 4, sizeof(cl_mem), &(*(mKVCacheCLManager->key()))()});
            mRgUpdateInfo.update_kernel_args.push_back({0, 5, sizeof(mPastKvSeqlen), &mPastKvSeqlen});
            mRgUpdateInfo.update_kernel_args.push_back({0, 6, sizeof(mKeyValueMaxlen), &mKeyValueMaxlen});
            mOpRecordUpdateInfo.emplace_back(&mRgUpdateInfo);
            mOpenCLBackend->recordKernel3d(mKernel_rearrange, mGlobalWorkSizeRearrg, mLocalWorkSizeRearrg, &mRgUpdateInfo);
        } else {
            mOpenCLBackend->recordKernel3d(mKernel_rearrange, mGlobalWorkSizeRearrg, mLocalWorkSizeRearrg);
        }
    }
    {
        // matmul qk
        std::set<std::string> buildOption;
        buildOption.emplace("-DNUMHEAD_GROUP_SIZE=" + std::to_string(group_size));
        mKernel_qk = runtime->buildKernel("attention_buf", "matmul_qk_decode", buildOption, mOpenCLBackend->getPrecision(), inputs[0], outputs[0]);
        mGlobalWorkSizeQk =  {static_cast<uint32_t>(UP_DIV(mKvSeqlen, 4)), static_cast<uint32_t>(numHead)};
        auto maxWorkGroupSize  = static_cast<uint32_t>(runtime->getMaxWorkGroupSize(mKernel_qk));

        uint32_t index = 0;
        cl_int ret = CL_SUCCESS;
        ret |= mKernel_qk->get().setArg(index++, mGlobalWorkSizeQk[0]);
        ret |= mKernel_qk->get().setArg(index++, mGlobalWorkSizeQk[1]);
        ret |= mKernel_qk->get().setArg(index++, openCLBuffer(query));
        ret |= mKernel_qk->get().setArg(index++, keyBuffer);
        ret |= mKernel_qk->get().setArg(index++, openCLDeferBuffer(mTempQK.get()));
        ret |= mKernel_qk->get().setArg(index++, scale);
        ret |= mKernel_qk->get().setArg(index++, mKvSeqlen);
        ret |= mKernel_qk->get().setArg(index++, mKeyValueMaxlen);
        ret |= mKernel_qk->get().setArg(index++, numHead);
        ret |= mKernel_qk->get().setArg(index++, headDim);
        MNN_CHECK_CL_SUCCESS(ret, "setArg matmul_qk_decode");

        mLocalWorkSizeQk = localWS2DDefault(mGlobalWorkSizeQk, maxWorkGroupSize, runtime, "matmul_qk_decode", mKernel_qk, mOpenCLBackend->getCLTuneLevel()).first;
        mGlobalWorkSizeQk[0] = ROUND_UP(mGlobalWorkSizeQk[0], std::max((uint32_t)1, mLocalWorkSizeQk[0]));
        mGlobalWorkSizeQk[1] = ROUND_UP(mGlobalWorkSizeQk[1], std::max((uint32_t)1, mLocalWorkSizeQk[1]));
        if(mNeedKvCache) {
            // Argument indices as set above: 0 = global size x, 3 = keyBuffer, 4 = mTempQK,
            // 6 = mKvSeqlen, 7 = mKeyValueMaxlen.
            // NOTE(review): mGlobalWorkSizeQk0 is not assigned in this function - confirm it is
            // set elsewhere before the recorded arguments are updated.
            mQkUpdateInfo.update_kernel_args.push_back({0, 0, sizeof(mGlobalWorkSizeQk0), &mGlobalWorkSizeQk0});
            mQkUpdateInfo.update_kernel_args.push_back({0, 3, sizeof(cl_mem), &(*(mKVCacheCLManager->key()))()});
            mQkUpdateInfo.update_kernel_args.push_back({0, 4, sizeof(cl_mem), &openCLDeferBuffer(mTempQK.get())()});
            mQkUpdateInfo.update_kernel_args.push_back({0, 6, sizeof(mKvSeqlen), &mKvSeqlen});
            mQkUpdateInfo.update_kernel_args.push_back({0, 7, sizeof(mKeyValueMaxlen), &mKeyValueMaxlen});
            // The first dimension derives from mKvSeqlen, so the (padded) global size is registered as updatable too.
            mQkGlobal_size[0] = mGlobalWorkSizeQk[0];
            mQkGlobal_size[1] = mGlobalWorkSizeQk[1];
            mQkUpdateInfo.update_global_size.push_back({0, mQkGlobal_size});
            mOpRecordUpdateInfo.emplace_back(&mQkUpdateInfo);
            mOpenCLBackend->recordKernel2d(mKernel_qk, mGlobalWorkSizeQk, mLocalWorkSizeQk, &mQkUpdateInfo);
        } else {
            mOpenCLBackend->recordKernel2d(mKernel_qk, mGlobalWorkSizeQk, mLocalWorkSizeQk);
        }
    }
    {
        // softmax
        int inside  = 1;
        int outside = numHead;
        int localSize = 64;

        std::set<std::string> buildOption;
        buildOption.emplace("-DSOFTMAX_LOCAL_SIZE=" + std::to_string(localSize));
        mKernel_softmax = runtime->buildKernel("softmax_buf", "softmax_in1_buf", buildOption, mOpenCLBackend->getPrecision());
        mGlobalWorkSizeSoftMax = {static_cast<uint32_t>(localSize), static_cast<uint32_t>(inside), static_cast<uint32_t>(outside)};
        auto maxWorkGroupSize  = static_cast<uint32_t>(runtime->getMaxWorkGroupSize(mKernel_softmax));

        uint32_t index = 0;
        cl_int ret = CL_SUCCESS;
        ret |= mKernel_softmax->get().setArg(index++, mGlobalWorkSizeSoftMax[0]);
        ret |= mKernel_softmax->get().setArg(index++, mGlobalWorkSizeSoftMax[1]);
        ret |= mKernel_softmax->get().setArg(index++, mGlobalWorkSizeSoftMax[2]);
        ret |= mKernel_softmax->get().setArg(index++, openCLDeferBuffer(mTempQK.get()));
        ret |= mKernel_softmax->get().setArg(index++, openCLDeferBuffer(mTempSoftMax.get()));
        ret |= mKernel_softmax->get().setArg(index++, inside);
        ret |= mKernel_softmax->get().setArg(index++, outside);
        ret |= mKernel_softmax->get().setArg(index++, mKvSeqlen);
        MNN_CHECK_CL_SUCCESS(ret, "setArg softmax");

        // Fixed local size {localSize, 1, 1}; the tuner is only consulted if localSize is changed to 1.
        mLocalWorkSizeSoftMax = {static_cast<uint32_t>(localSize), 1, 1};
        if(localSize == 1){
            mLocalWorkSizeSoftMax = localWS3DDefault(mGlobalWorkSizeSoftMax, maxWorkGroupSize, runtime, "softmax", mKernel_softmax, mOpenCLBackend->getCLTuneLevel()).first;
        }
        mGlobalWorkSizeSoftMax[0] = ROUND_UP(mGlobalWorkSizeSoftMax[0], std::max((uint32_t)1, mLocalWorkSizeSoftMax[0]));
        mGlobalWorkSizeSoftMax[1] = ROUND_UP(mGlobalWorkSizeSoftMax[1], std::max((uint32_t)1, mLocalWorkSizeSoftMax[1]));
        mGlobalWorkSizeSoftMax[2] = ROUND_UP(mGlobalWorkSizeSoftMax[2], std::max((uint32_t)1, mLocalWorkSizeSoftMax[2]));
        if(mNeedKvCache) {
            // Argument indices as set above: 3 = mTempQK, 4 = mTempSoftMax, 7 = mKvSeqlen.
            mSoftMaxUpdateInfo.update_kernel_args.push_back({0, 3, sizeof(cl_mem), &openCLDeferBuffer(mTempQK.get())()});
            mSoftMaxUpdateInfo.update_kernel_args.push_back({0, 4, sizeof(cl_mem), &openCLDeferBuffer(mTempSoftMax.get())()});
            mSoftMaxUpdateInfo.update_kernel_args.push_back({0, 7, sizeof(mKvSeqlen), &mKvSeqlen});
            mOpRecordUpdateInfo.emplace_back(&mSoftMaxUpdateInfo);
            mOpenCLBackend->recordKernel3d(mKernel_softmax, mGlobalWorkSizeSoftMax, mLocalWorkSizeSoftMax, &mSoftMaxUpdateInfo);
        } else {
            mOpenCLBackend->recordKernel3d(mKernel_softmax, mGlobalWorkSizeSoftMax, mLocalWorkSizeSoftMax);
        }
    }
    {
        // rearrange value
        std::set<std::string> buildOption;

        mKernel_rearrangeV = runtime->buildKernel("attention_buf", "rearrange_v", buildOption, mOpenCLBackend->getPrecision(), inputs[0], outputs[0]);
        auto maxWorkGroupSize  = static_cast<uint32_t>(runtime->getMaxWorkGroupSize(mKernel_rearrangeV));

        mGlobalWorkSizeRearrgV = {static_cast<uint32_t>(UP_DIV(headDim, 4)),
                                static_cast<uint32_t>(1),
                                static_cast<uint32_t>(kvNumHead * batch)};

        uint32_t index = 0;
        cl_int ret = CL_SUCCESS;
        ret |= mKernel_rearrangeV->get().setArg(index++, mGlobalWorkSizeRearrgV[0]);
        ret |= mKernel_rearrangeV->get().setArg(index++, mGlobalWorkSizeRearrgV[1]);
        ret |= mKernel_rearrangeV->get().setArg(index++, mGlobalWorkSizeRearrgV[2]);
        ret |= mKernel_rearrangeV->get().setArg(index++, openCLBuffer(value));
        ret |= mKernel_rearrangeV->get().setArg(index++, valueBuffer);
        ret |= mKernel_rearrangeV->get().setArg(index++, mPastKvSeqlen);
        ret |= mKernel_rearrangeV->get().setArg(index++, mKeyValueMaxlen);
        ret |= mKernel_rearrangeV->get().setArg(index++, seqlen);
        ret |= mKernel_rearrangeV->get().setArg(index++, kvNumHead);
        ret |= mKernel_rearrangeV->get().setArg(index++, headDim);

        MNN_CHECK_CL_SUCCESS(ret, "setArg rearrange_v");
        mLocalWorkSizeRearrgV = localWS3DDefault(mGlobalWorkSizeRearrgV, maxWorkGroupSize, runtime, "rearrange_v", mKernel_rearrangeV, mOpenCLBackend->getCLTuneLevel()).first;
        mGlobalWorkSizeRearrgV[0] = ROUND_UP(mGlobalWorkSizeRearrgV[0], std::max((uint32_t)1, mLocalWorkSizeRearrgV[0]));
        mGlobalWorkSizeRearrgV[1] = ROUND_UP(mGlobalWorkSizeRearrgV[1], std::max((uint32_t)1, mLocalWorkSizeRearrgV[1]));
        mGlobalWorkSizeRearrgV[2] = ROUND_UP(mGlobalWorkSizeRearrgV[2], std::max((uint32_t)1, mLocalWorkSizeRearrgV[2]));
        if(mNeedKvCache) {
            // Argument indices as set above: 4 = valueBuffer, 5 = mPastKvSeqlen, 6 = mKeyValueMaxlen.
            mRgVUpdateInfo.update_kernel_args.push_back({0, 4, sizeof(cl_mem), &(*(mKVCacheCLManager->value()))()});
            mRgVUpdateInfo.update_kernel_args.push_back({0, 5, sizeof(mPastKvSeqlen), &mPastKvSeqlen});
            mRgVUpdateInfo.update_kernel_args.push_back({0, 6, sizeof(mKeyValueMaxlen), &mKeyValueMaxlen});
            mOpRecordUpdateInfo.emplace_back(&mRgVUpdateInfo);
            mOpenCLBackend->recordKernel3d(mKernel_rearrangeV, mGlobalWorkSizeRearrgV, mLocalWorkSizeRearrgV, &mRgVUpdateInfo);
        } else {
            mOpenCLBackend->recordKernel3d(mKernel_rearrangeV, mGlobalWorkSizeRearrgV, mLocalWorkSizeRearrgV);
        }
    }
    // qk * value
    {
        std::set<std::string> buildOption;
        buildOption.emplace("-DNUMHEAD_GROUP_SIZE=" + std::to_string(group_size));
        const int total_kernel = 2;
        std::string kernelName[total_kernel] = {"matmul_qkv_decode_b4", "matmul_qkv_decode_b8"};
        std::string unroll[total_kernel] = {"-DLOOP_UNROLL_4", "-DLOOP_UNROLL_8"};
        // Divisor applied to headDim for the first global dimension of each kernel variant.
        int itemC[total_kernel] = {4, 8};
        int actual_kernel = 2;
        // Candidates are stored flat: index = kernelVariant * total_kernel + unrollVariant.
        std::shared_ptr<KernelWrap> kernel[total_kernel * total_kernel];
        std::vector<uint32_t> globalWorkSize[total_kernel * total_kernel];
        std::vector<uint32_t> localWorkSize[total_kernel * total_kernel];
        std::pair<int, int> min_cost(INT_MAX, 0);//(min_time, min_index)

        // Build and tune every (kernel, unroll) combination, remembering the cheapest one.
        for (int i = 0; i < actual_kernel; i++) {
            for(int j = 0; j < actual_kernel; j++){
                int knl_idx = i * total_kernel + j;
                auto option = buildOption;
                option.emplace(unroll[j]);
                kernel[knl_idx] = runtime->buildKernel("attention_buf", kernelName[i], option, mOpenCLBackend->getPrecision());
                uint32_t maxWorkGroupSize = static_cast<uint32_t>(runtime->getMaxWorkGroupSize(kernel[knl_idx]));
                globalWorkSize[knl_idx] = {static_cast<uint32_t>(UP_DIV(headDim, itemC[i])), static_cast<uint32_t>(numHead)};
                uint32_t index = 0;
                cl_int ret = CL_SUCCESS;
                ret |= kernel[knl_idx]->get().setArg(index++, globalWorkSize[knl_idx][0]);
                ret |= kernel[knl_idx]->get().setArg(index++, globalWorkSize[knl_idx][1]);
                ret |= kernel[knl_idx]->get().setArg(index++, openCLDeferBuffer(mTempSoftMax.get()));
                ret |= kernel[knl_idx]->get().setArg(index++, valueBuffer);
                ret |= kernel[knl_idx]->get().setArg(index++, openCLBuffer(outputs[0]));
                ret |= kernel[knl_idx]->get().setArg(index++, mKvSeqlen);
                ret |= kernel[knl_idx]->get().setArg(index++, mKeyValueMaxlen);
                ret |= kernel[knl_idx]->get().setArg(index++, numHead);
                ret |= kernel[knl_idx]->get().setArg(index++, kvNumHead);
                ret |= kernel[knl_idx]->get().setArg(index++, headDim);
                MNN_CHECK_CL_SUCCESS(ret, "setArg matmul_qkv_decode");
                std::pair<std::vector<uint32_t>, int> retTune;
                retTune = localWS2DDefault(globalWorkSize[knl_idx], maxWorkGroupSize, runtime, kernelName[i] + unroll[j], kernel[knl_idx], mOpenCLBackend->getCLTuneLevel());
                // Strict '>' keeps the earliest candidate when costs tie.
                if(min_cost.first > retTune.second) {
                    min_cost.first = retTune.second;
                    min_cost.second = knl_idx;
                    mLocalWorkSizeQkv = {retTune.first[0], retTune.first[1]};
                }
            }
        }
        // Decode the flat winner index with the same stride that was used to encode it.
        int min_index  = min_cost.second / total_kernel;
        int min_index_unroll  = min_cost.second % total_kernel;
        mGlobalWorkSizeQkv = {globalWorkSize[min_cost.second][0], globalWorkSize[min_cost.second][1]};
        buildOption.emplace(unroll[min_index_unroll]);
        // Rebuild the winning variant, this time passing inputs[0]/outputs[0] like the other recorded kernels.
        mKernel_qkv = runtime->buildKernel("attention_buf", kernelName[min_index], buildOption, mOpenCLBackend->getPrecision(), inputs[0], outputs[0]);

        uint32_t index = 0;
        cl_int ret = CL_SUCCESS;
        ret |= mKernel_qkv->get().setArg(index++, mGlobalWorkSizeQkv[0]);
        ret |= mKernel_qkv->get().setArg(index++, mGlobalWorkSizeQkv[1]);
        ret |= mKernel_qkv->get().setArg(index++, openCLDeferBuffer(mTempSoftMax.get()));
        ret |= mKernel_qkv->get().setArg(index++, valueBuffer);
        ret |= mKernel_qkv->get().setArg(index++, openCLBuffer(outputs[0]));
        ret |= mKernel_qkv->get().setArg(index++, mKvSeqlen);
        ret |= mKernel_qkv->get().setArg(index++, mKeyValueMaxlen);
        ret |= mKernel_qkv->get().setArg(index++, numHead);
        ret |= mKernel_qkv->get().setArg(index++, kvNumHead);
        ret |= mKernel_qkv->get().setArg(index++, headDim);
        MNN_CHECK_CL_SUCCESS(ret, "setArg matmul_qkv_decode");

        mGlobalWorkSizeQkv[0] = ROUND_UP(mGlobalWorkSizeQkv[0], std::max((uint32_t)1, mLocalWorkSizeQkv[0]));
        mGlobalWorkSizeQkv[1] = ROUND_UP(mGlobalWorkSizeQkv[1], std::max((uint32_t)1, mLocalWorkSizeQkv[1]));
        if(mNeedKvCache) {
            // Argument indices as set above: 2 = mTempSoftMax, 3 = valueBuffer,
            // 5 = mKvSeqlen, 6 = mKeyValueMaxlen.
            mQkvUpdateInfo.update_kernel_args.push_back({0, 2, sizeof(cl_mem), &openCLDeferBuffer(mTempSoftMax.get())()});
            mQkvUpdateInfo.update_kernel_args.push_back({0, 3, sizeof(cl_mem), &(*(mKVCacheCLManager->value()))()});
            mQkvUpdateInfo.update_kernel_args.push_back({0, 5, sizeof(mKvSeqlen), &mKvSeqlen});
            mQkvUpdateInfo.update_kernel_args.push_back({0, 6, sizeof(mKeyValueMaxlen), &mKeyValueMaxlen});
            mOpRecordUpdateInfo.emplace_back(&mQkvUpdateInfo);
            mOpenCLBackend->recordKernel2d(mKernel_qkv, mGlobalWorkSizeQkv, mLocalWorkSizeQkv, &mQkvUpdateInfo);
        } else {
            mOpenCLBackend->recordKernel2d(mKernel_qkv, mGlobalWorkSizeQkv, mLocalWorkSizeQkv);
        }
    }
    mOpenCLBackend->endRecord(mRecording);

    return NO_ERROR;
}