ErrorCode AttentionBufExecution::longPrefillResize()

in source/backend/opencl/execution/buffer/AttentionBufExecution.cpp [395:866]
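
This is the long-prefill resize path: Q, K and V (and the mask, when present) are rearranged into tile-aligned temporary buffers, the Q sequence is optionally split into mQseqSplitNum pieces to bound temporary memory, and for each piece a batched QK GEMM, a softmax, a transpose and a batched score-times-V GEMM are recorded, followed by a final transpose back to the output layout. The listing leans on a few alignment helpers; a minimal sketch of their assumed definitions (mirroring the usual MNN macros, shown only to make the code easier to read) is:

// Assumed helper macros (sketch; the actual definitions live in MNN's core headers).
#define UP_DIV(x, y)   (((x) + (y) - 1) / (y))            // ceiling division
#define ROUND_UP(x, y) ((((x) + (y) - 1) / (y)) * (y))    // round x up to a multiple of y
#define ALIMAX(x, y)   ((x) > (y) ? (x) : (y))            // max of two values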


ErrorCode AttentionBufExecution::longPrefillResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs){
    
    auto query = inputs[0];
    auto key = inputs[1];
    auto value = inputs[2];
    auto runtime = mOpenCLBackend->getOpenCLRuntime();
    auto shape = query->shape();
       
    int batch = shape[0];
    int seqlen = shape[1];
    int numHead = shape[2];
    int kvNumHead = key->shape()[2];
    int headDim = shape[3];
    int group_size = numHead / kvNumHead;
    float scale = 1.0 / sqrt(headDim);

    mAlignQ = 32;
    mAlignKV = 32;
    mAlignHDK = 4;
    mAlignHDN = 32;
    
    float useMemorySize = 1.0 * ROUND_UP(seqlen, mAlignQ) / 1024.0 * ROUND_UP(seqlen, mAlignKV) / 1024.0 * batch * numHead;
    // split along the Q sequence when the padded QK element count exceeds 32M elements
    if(useMemorySize > 32.0) {
        mQseqSplitNum = useMemorySize >= 256.0 ? 8 : ((useMemorySize < 128.0) ? 2 : 4);
    }
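    // Worked example (hypothetical shapes): batch = 1, numHead = 32, seqlen = 2048 gives
    // useMemorySize = 2048/1024 * 2048/1024 * 1 * 32 = 128, which falls in [128, 256),
    // so mQseqSplitNum = 4 and each split covers ROUND_UP(seqlen, mAlignQ) / 4 query rows.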
    
    mKernel_rearrange_vec.resize(1); mGwsRearrgVec.resize(1); mLwsRearrgVec.resize(1);
    mKernel_mask_vec.resize(1);     mGwsMaskVec.resize(1);   mLwsMaskVec.resize(1);
    mKernel_qk_vec.resize(mQseqSplitNum);       mGwsQkVec.resize(mQseqSplitNum);     mLwsQkVec.resize(mQseqSplitNum);
    mKernel_softmax_vec.resize(mQseqSplitNum);   mGwsSoftMaxVec.resize(mQseqSplitNum); mLwsSoftMaxVec.resize(mQseqSplitNum);
    mKernel_trans_vec.resize(mQseqSplitNum);     mGwsTransVec.resize(mQseqSplitNum);  mLwsTransVec.resize(mQseqSplitNum);
    mKernel_qkv_vec.resize(mQseqSplitNum);      mGwsQkvVec.resize(mQseqSplitNum);    mLwsQkvVec.resize(mQseqSplitNum);
    mKernel_clip_vec.resize(1);     mGwsClipVec.resize(1);   mLwsClipVec.resize(1);
    
    mTempQ.reset(Tensor::createDevice<float>({ROUND_UP(seqlen, mAlignQ) * ROUND_UP(headDim, mAlignHDK) * batch * numHead}));
    mTempK.reset(Tensor::createDevice<float>({ROUND_UP(seqlen, mAlignKV) * ROUND_UP(headDim, mAlignHDK) * batch * numHead}));
    mTempV.reset(Tensor::createDevice<float>({ROUND_UP(seqlen, mAlignKV) * ROUND_UP(headDim, mAlignHDN) * batch * numHead}));
    if(mHasMask) {
        if(mIsAddMask) {
            mTempMask.reset(Tensor::createDevice<float>({ROUND_UP(seqlen, mAlignQ) * ROUND_UP(seqlen, mAlignKV) * batch}));
        } else {
            mTempMask.reset(Tensor::createDevice<uint32_t>({ROUND_UP(seqlen, mAlignQ) * ROUND_UP(seqlen, mAlignKV) * batch}));
        }
    }
    mTempQK.reset(Tensor::createDevice<float>({ROUND_UP(seqlen, mAlignQ) * ROUND_UP(seqlen, mAlignKV) * batch * numHead / mQseqSplitNum}));
    mTempSoftMax.reset(Tensor::createDevice<float>({ROUND_UP(seqlen, mAlignQ) * ROUND_UP(seqlen, mAlignKV) * batch * numHead / mQseqSplitNum}));
    mTempQKV.reset(Tensor::createDevice<float>({ROUND_UP(seqlen, mAlignQ) * ROUND_UP(headDim, mAlignHDN) * batch * numHead}));
    
    
    mOpenCLBackend->onAcquireBuffer(mTempQ.get(), Backend::DYNAMIC);
    mOpenCLBackend->onAcquireBuffer(mTempK.get(), Backend::DYNAMIC);
    mOpenCLBackend->onAcquireBuffer(mTempV.get(), Backend::DYNAMIC);
    if(mHasMask) {
        mOpenCLBackend->onAcquireBuffer(mTempMask.get(), Backend::DYNAMIC);
    }
    mOpenCLBackend->onAcquireBuffer(mTempQK.get(), Backend::DYNAMIC);
    mOpenCLBackend->onAcquireBuffer(mTempSoftMax.get(), Backend::DYNAMIC);
    mOpenCLBackend->onAcquireBuffer(mTempQKV.get(), Backend::DYNAMIC);

    mOpenCLBackend->onReleaseBuffer(mTempQ.get(), Backend::DYNAMIC);
    mOpenCLBackend->onReleaseBuffer(mTempK.get(), Backend::DYNAMIC);
    if(mHasMask) {
        mOpenCLBackend->onReleaseBuffer(mTempMask.get(), Backend::DYNAMIC);
    }
    mOpenCLBackend->onReleaseBuffer(mTempSoftMax.get(), Backend::DYNAMIC);
    mOpenCLBackend->onReleaseBuffer(mTempV.get(), Backend::DYNAMIC);
    mOpenCLBackend->onReleaseBuffer(mTempQK.get(), Backend::DYNAMIC);
    mOpenCLBackend->onReleaseBuffer(mTempQKV.get(), Backend::DYNAMIC);
    
    // query: [batch, seqLenQ, headNum, headDim] -> mTempQ: [batch*headNum, ROUND_UP(headDim, mAlignHDK), ROUND_UP(seqLenQ, mAlignQ)]
    // key: [batch, seqLenKV/4, headNum/group, headDim, seqLenKV_4] -> mTempK: [batch*headNum/group, ROUND_UP(headDim, mAlignHDK), ROUND_UP(seqLenKV, mAlignKV)]
    // value: [batch, seqLenKV/4, headNum/group, headDim, seqLenKV_4] -> mTempV: [batch*headNum/group, ROUND_UP(seqLenKV, mAlignKV), ROUND_UP(headDim, mAlignHDN)]
    // key & value -> pastKey & pastValue (copy)
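    // Padding example (hypothetical shapes): headDim = 80 and seqlen = mKvSeqlen = 100 pad to
    // ROUND_UP(100, mAlignQ = 32) = 128, ROUND_UP(100, mAlignKV = 32) = 128,
    // ROUND_UP(80, mAlignHDK = 4) = 80 and ROUND_UP(80, mAlignHDN = 32) = 96.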
    int seq_idx = 0;
    // rearrange qkv
    {
        std::set<std::string> buildOption;
        if((headDim % 4) != 0){
            buildOption.emplace("-DHEADDIM_LEAVE");
        }
        // generate cache for every option
        {
            auto option = buildOption;
            auto kernel = runtime->buildKernel("attention_buf", "rearrange_qkv", option, mOpenCLBackend->getPrecision(), inputs[0], outputs[0]);
        }
        {
            auto option = buildOption;
            option.emplace("-DSEQLEN_LEAVE");
            auto kernel = runtime->buildKernel("attention_buf", "rearrange_qkv", option, mOpenCLBackend->getPrecision(), inputs[0], outputs[0]);
        }
        if((seqlen % 4) != 0){
            buildOption.emplace("-DSEQLEN_LEAVE");
        }
        if(mNeedKvCache) {
            buildOption.emplace("-DSAVE_KV");
        }
        int seq_len_pack_q = ROUND_UP(seqlen, mAlignQ);
        int seq_len_pack_kv = ROUND_UP(mKvSeqlen, mAlignKV);
        
        int head_dim_pack_qk = ROUND_UP(headDim, mAlignHDK);
        int head_dim_pack_v = ROUND_UP(headDim, mAlignHDN);
        
        int tile[4] = {mAlignQ, mAlignKV, mAlignHDK, mAlignHDN};
        int shape[4] = {seqlen, mKvSeqlen, numHead, headDim};
        int param[4] = {group_size, batch, 0, 0};
        mKernel_rearrange_vec[seq_idx] = runtime->buildKernel("attention_buf", "rearrange_qkv", buildOption, mOpenCLBackend->getPrecision(), inputs[0], outputs[0]);
        auto maxWorkGroupSize  = static_cast<uint32_t>(runtime->getMaxWorkGroupSize(mKernel_rearrange_vec[seq_idx]));
        
        mGwsRearrgVec[seq_idx] = {static_cast<uint32_t>(ALIMAX(UP_DIV(seq_len_pack_q, 4), UP_DIV(seq_len_pack_kv, 4))), \
            static_cast<uint32_t>(ALIMAX(UP_DIV(head_dim_pack_qk, 4), UP_DIV(head_dim_pack_v, 4))), \
            static_cast<uint32_t>(batch*numHead)};
        
        uint32_t index = 0;
        cl_int ret = CL_SUCCESS;
        ret |= mKernel_rearrange_vec[seq_idx]->get().setArg(index++, mGwsRearrgVec[seq_idx][0]);
        ret |= mKernel_rearrange_vec[seq_idx]->get().setArg(index++, mGwsRearrgVec[seq_idx][1]);
        ret |= mKernel_rearrange_vec[seq_idx]->get().setArg(index++, mGwsRearrgVec[seq_idx][2]);
        ret |= mKernel_rearrange_vec[seq_idx]->get().setArg(index++, openCLBuffer(query));
        ret |= mKernel_rearrange_vec[seq_idx]->get().setArg(index++, openCLBuffer(key));
        ret |= mKernel_rearrange_vec[seq_idx]->get().setArg(index++, openCLBuffer(value));
        ret |= mKernel_rearrange_vec[seq_idx]->get().setArg(index++, openCLBuffer(mTempQ.get()));
        ret |= mKernel_rearrange_vec[seq_idx]->get().setArg(index++, openCLBuffer(mTempK.get()));
        ret |= mKernel_rearrange_vec[seq_idx]->get().setArg(index++, openCLBuffer(mTempV.get()));
        if(mNeedKvCache) {
            ret |= mKernel_rearrange_vec[seq_idx]->get().setArg(index++, *mKVCacheCLManager->key());
            ret |= mKernel_rearrange_vec[seq_idx]->get().setArg(index++, *mKVCacheCLManager->value());
        }
        ret |= mKernel_rearrange_vec[seq_idx]->get().setArg(index++, tile);
        ret |= mKernel_rearrange_vec[seq_idx]->get().setArg(index++, shape);
        ret |= mKernel_rearrange_vec[seq_idx]->get().setArg(index++, param);
        ret |= mKernel_rearrange_vec[seq_idx]->get().setArg(index++, mKeyValueMaxlen);
        
        MNN_CHECK_CL_SUCCESS(ret, "setArg rearrange_qkv");
        mLwsRearrgVec[seq_idx] = localWS3DDefault(mGwsRearrgVec[seq_idx], maxWorkGroupSize, runtime, "rearrange_qkv", mKernel_rearrange_vec[seq_idx], mOpenCLBackend->getCLTuneLevel()).first;
        mGwsRearrgVec[seq_idx][0] = ROUND_UP(mGwsRearrgVec[seq_idx][0], std::max((uint32_t)1, mLwsRearrgVec[seq_idx][0]));
        mGwsRearrgVec[seq_idx][1] = ROUND_UP(mGwsRearrgVec[seq_idx][1], std::max((uint32_t)1, mLwsRearrgVec[seq_idx][1]));
        mGwsRearrgVec[seq_idx][2] = ROUND_UP(mGwsRearrgVec[seq_idx][2], std::max((uint32_t)1, mLwsRearrgVec[seq_idx][2]));
        if(mNeedKvCache) {
            mRgUpdateInfo.update_kernel_args.push_back({0, 9, sizeof(cl_mem), &(*(mKVCacheCLManager->key()))()});
            mRgUpdateInfo.update_kernel_args.push_back({0, 10, sizeof(cl_mem), &(*(mKVCacheCLManager->value()))()});
        }
        mRgUpdateInfo.update_kernel_args.push_back({0, 14, sizeof(mKeyValueMaxlen), &mKeyValueMaxlen});
        mOpRecordUpdateInfo.emplace_back(&mRgUpdateInfo);
        mOpenCLBackend->recordKernel3d(mKernel_rearrange_vec[seq_idx], mGwsRearrgVec[seq_idx], mLwsRearrgVec[seq_idx], &mRgUpdateInfo);
    }
    
    // mask rearrange
    if(mHasMask)
    {
        std::set<std::string> buildOption;
        
        int seq_len_pack_q = ROUND_UP(seqlen, mAlignQ);
        int seq_len_pack_kv = ROUND_UP(mKvSeqlen, mAlignKV);
        int shape[4] = {seqlen, mKvSeqlen, mAlignQ, mAlignKV};
        
        mKernel_mask_vec[seq_idx] = runtime->buildKernel("attention_buf", "rearrange_mask", buildOption, mOpenCLBackend->getPrecision(), inputs[0], outputs[0]);
        auto maxWorkGroupSize  = static_cast<uint32_t>(runtime->getMaxWorkGroupSize(mKernel_mask_vec[seq_idx]));
        
        mGwsMaskVec[seq_idx] = {static_cast<uint32_t>(UP_DIV(seq_len_pack_q, 4)), \
            static_cast<uint32_t>(UP_DIV(seq_len_pack_kv, 4)), \
            static_cast<uint32_t>(batch)};
        
        uint32_t index = 0;
        cl_int ret = CL_SUCCESS;
        ret |= mKernel_mask_vec[seq_idx]->get().setArg(index++, mGwsMaskVec[seq_idx][0]);
        ret |= mKernel_mask_vec[seq_idx]->get().setArg(index++, mGwsMaskVec[seq_idx][1]);
        ret |= mKernel_mask_vec[seq_idx]->get().setArg(index++, mGwsMaskVec[seq_idx][2]);
        ret |= mKernel_mask_vec[seq_idx]->get().setArg(index++, openCLBuffer(inputs[3]));
        ret |= mKernel_mask_vec[seq_idx]->get().setArg(index++, openCLBuffer(mTempMask.get()));
        ret |= mKernel_mask_vec[seq_idx]->get().setArg(index++, shape);
        
        MNN_CHECK_CL_SUCCESS(ret, "setArg rearrange_mask");
        mLwsMaskVec[seq_idx] = localWS3DDefault(mGwsMaskVec[seq_idx], maxWorkGroupSize, runtime, "rearrange_mask", mKernel_mask_vec[seq_idx], mOpenCLBackend->getCLTuneLevel()).first;
        mGwsMaskVec[seq_idx][0] = ROUND_UP(mGwsMaskVec[seq_idx][0], std::max((uint32_t)1, mLwsMaskVec[seq_idx][0]));
        mGwsMaskVec[seq_idx][1] = ROUND_UP(mGwsMaskVec[seq_idx][1], std::max((uint32_t)1, mLwsMaskVec[seq_idx][1]));
        mGwsMaskVec[seq_idx][2] = ROUND_UP(mGwsMaskVec[seq_idx][2], std::max((uint32_t)1, mLwsMaskVec[seq_idx][2]));
        mOpenCLBackend->recordKernel3d(mKernel_mask_vec[seq_idx], mGwsMaskVec[seq_idx], mLwsMaskVec[seq_idx]);
    }

    for(int seq_idx = 0; seq_idx < mQseqSplitNum; seq_idx++) {
        // qk matmul
        {
            // Q : [batch*headNum, ROUND_UP(headDim, mAlignHDK), ROUND_UP(seqLenQ, mAlignQ) / mQseqSplitNum] -> [B, K, M]
            // K : [batch*headNum/group, ROUND_UP(headDim, mAlignHDK), ROUND_UP(seqLenKV, mAlignKV)] -> [B, K, N]
            // QK: [Batch * numHead, ROUND_UP(seqLenQ, mAlignQ) / mQseqSplitNum, ROUND_UP(seqLenKV, mAlignKV)]   -> [B, M, N]
            int loop = batch * numHead;
            int e_pack = ROUND_UP(seqlen, mAlignQ);
            int e_pack_piece = e_pack / mQseqSplitNum;
            int h_pack = ROUND_UP(mKvSeqlen, mAlignKV);
            int l_pack = ROUND_UP(headDim, mAlignHDK);
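            // Batched GEMM for the attention scores: M = e_pack_piece (queries in this split),
            // N = h_pack (padded KV length), K = l_pack (padded head dim). The 1/sqrt(headDim)
            // scale is folded into alpha below; group[1] = group_size presumably lets several
            // query heads share one KV head (grouped-query attention).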
            
            std::set<std::string> buildOptions;
            
            int biasType = 0;
            std::vector<cl::Buffer> bufferVec = {openCLBuffer(mTempQ.get()), openCLBuffer(mTempK.get()), openCLBuffer(mTempQK.get())};
            if(mHasMask) {
                bufferVec.emplace_back(openCLBuffer(mTempMask.get()));
            }
            if(mIsAddMask) {
                biasType = 2;
            } else if(mHasMask) {
                biasType = 5;// int value mask
            }
            uint32_t layout = 14; // 14 = 10 (mixed-precision compute) + 4 ([M, N] output layout, OUTPUTMN)
            auto param = getGemmParams({(uint32_t)e_pack_piece, (uint32_t)h_pack, (uint32_t)l_pack, layout, (uint32_t)loop, (uint32_t)(biasType + 10*(group_size-1))}, bufferVec, mOpenCLBackend->getOpenCLRuntime(), mOpenCLBackend->getPrecision(), mOpenCLBackend->getCLTuneLevel());
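            // The last tuning key above packs biasType into the ones digit and (group_size - 1)
            // into the tens digit, so different mask types and grouping factors tune separately.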
            
            int KWG=param[0], KWI=param[1], MDIMA=param[2], MDIMC=param[3], MWG=param[4], NDIMB=param[5], NDIMC=param[6], NWG=param[7], SA=param[8], SB=param[9], STRM=param[10], STRN=param[11], VWM=param[12], VWN=param[13];
            buildOptions.emplace("-DKWG=" + std::to_string(KWG));
            buildOptions.emplace("-DKWI=" + std::to_string(KWI));
            buildOptions.emplace("-DMDIMA=" + std::to_string(MDIMA));
            buildOptions.emplace("-DMDIMC=" + std::to_string(MDIMC));
            buildOptions.emplace("-DMWG=" + std::to_string(MWG));
            buildOptions.emplace("-DNDIMB=" + std::to_string(NDIMB));
            buildOptions.emplace("-DNDIMC=" + std::to_string(NDIMC));
            buildOptions.emplace("-DNWG=" + std::to_string(NWG));
            buildOptions.emplace("-DSA=" + std::to_string(SA));
            buildOptions.emplace("-DSB=" + std::to_string(SB));
            buildOptions.emplace("-DSTRM=" + std::to_string(STRM));
            buildOptions.emplace("-DSTRN=" + std::to_string(STRN));
            buildOptions.emplace("-DVWM=" + std::to_string(VWM));
            buildOptions.emplace("-DVWN=" + std::to_string(VWN));
            if(layout >= 4) {
                buildOptions.emplace("-DOUTPUTMN");
            }
            
            int tileM = MWG;
            int tileN = NWG;
            int localM = MDIMC;
            int localN = NDIMC;
            
            if(mOpenCLBackend->getOpenCLRuntime()->getGpuType() == GpuType::ADRENO) {
                buildOptions.emplace("-DUSE_CL_MAD=1");
                buildOptions.emplace("-DRELAX_WORKGROUP_SIZE=1");
            }
            buildOptions.emplace("-DONLY_HAVE_ALPHA");
            if(biasType >= 1) {
                buildOptions.emplace("-DBIAS_TYPE=" + std::to_string(biasType));
            }
            
            buildOptions.emplace("-DPRECISION_COMPUTE=float -DCONVERT_PRECISION_COMPUTE=convert_float");
            buildOptions.emplace("-DPRECISION_COMPUTE2=float2 -DCONVERT_PRECISION_COMPUTE2=convert_float2");
            buildOptions.emplace("-DPRECISION_COMPUTE4=float4 -DCONVERT_PRECISION_COMPUTE4=convert_float4");
            buildOptions.emplace("-DPRECISION_COMPUTE8=float8 -DCONVERT_PRECISION_COMPUTE8=convert_float8");
            buildOptions.emplace("-DPRECISION_COMPUTE16=float16 -DCONVERT_PRECISION_COMPUTE16=convert_float16");
            
            mKernel_qk_vec[seq_idx] = mOpenCLBackend->getOpenCLRuntime()->buildKernel("matmul_params_buf", "XgemmBatched", buildOptions, mOpenCLBackend->getPrecision());
            
            int out_per_thread_m = tileM / localM;
            int out_per_thread_n = tileN / localN;
            
            mGwsQkVec[seq_idx] = {static_cast<uint32_t>(e_pack_piece/out_per_thread_m), static_cast<uint32_t>(h_pack/out_per_thread_n), static_cast<uint32_t>(loop)};
            mLwsQkVec[seq_idx] = {static_cast<uint32_t>(localM), static_cast<uint32_t>(localN), 1};
            
            float alpha = scale;
            float beta = 0.0f;
            int batch_offset_a = e_pack * l_pack;
            int batch_offset_b = h_pack * l_pack;
            int batch_offset_c = e_pack_piece * h_pack;
            
            int batch_offset[4] = {batch_offset_a, batch_offset_b, batch_offset_c, 0};
            int base_ptr_offset[4] = {e_pack_piece * seq_idx, 0, 0, batch_offset_c * seq_idx};
            int stride[4] = {e_pack, h_pack, h_pack, h_pack};
            int group[4] = {1, group_size, 1, loop};
            
            int idx            = 0;
            cl_int ret = CL_SUCCESS;
            ret |= mKernel_qk_vec[seq_idx]->get().setArg(idx++, static_cast<int>(e_pack_piece));
            ret |= mKernel_qk_vec[seq_idx]->get().setArg(idx++, static_cast<int>(h_pack));
            ret |= mKernel_qk_vec[seq_idx]->get().setArg(idx++, static_cast<int>(l_pack));
            ret |= mKernel_qk_vec[seq_idx]->get().setArg(idx++, alpha);
            ret |= mKernel_qk_vec[seq_idx]->get().setArg(idx++, beta);
            ret |= mKernel_qk_vec[seq_idx]->get().setArg(idx++, openCLBuffer(mTempQ.get()));
            ret |= mKernel_qk_vec[seq_idx]->get().setArg(idx++, openCLBuffer(mTempK.get()));
            if(mHasMask) {
                ret |= mKernel_qk_vec[seq_idx]->get().setArg(idx++, openCLBuffer(mTempMask.get()));
            }
            ret |= mKernel_qk_vec[seq_idx]->get().setArg(idx++, openCLBuffer(mTempQK.get()));
            ret |= mKernel_qk_vec[seq_idx]->get().setArg(idx++, batch_offset);
            ret |= mKernel_qk_vec[seq_idx]->get().setArg(idx++, base_ptr_offset);
            ret |= mKernel_qk_vec[seq_idx]->get().setArg(idx++, stride);
            ret |= mKernel_qk_vec[seq_idx]->get().setArg(idx++, group);
            MNN_CHECK_CL_SUCCESS(ret, "setArg Self-Attention batchmatmul qk Kernel");
            mOpenCLBackend->recordKernel3d(mKernel_qk_vec[seq_idx], mGwsQkVec[seq_idx], mLwsQkVec[seq_idx]);
        }
        
        // softmax
        {
            // QK:      [Batch * numHead, ROUND_UP(seqLenQ, mAlignQ) / mQseqSplitNum, ROUND_UP(seqLenKV, mAlignKV)]
            // Softmax: [Batch * numHead, ROUND_UP(seqLenQ, mAlignQ) / mQseqSplitNum, ROUND_UP(seqLenKV, mAlignKV)]
            // axis  : 2 (last dim)
            int softmaxShape[4];
            softmaxShape[0] = batch*numHead;
            softmaxShape[1] = ROUND_UP(seqlen, mAlignQ) / mQseqSplitNum;
            softmaxShape[2] = ROUND_UP(mKvSeqlen, mAlignKV);
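            // softmaxShape[2] is the padded KV length; the unpadded mKvSeqlen is also passed to the
            // kernel below, presumably so the padded columns are excluded from the max/sum reduction.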
            
            auto MaxLocalSize = std::min(std::min(runtime->getMaxWorkItemSizes()[0], mMaxWorkGroupSize), static_cast<uint32_t>(256));
            int localSize = 64;
            
            std::set<std::string> buildOption;
            buildOption.emplace("-DSOFTMAX_LOCAL_SIZE=" + std::to_string(localSize));
            
            mKernel_softmax_vec[seq_idx] = runtime->buildKernel("self_attention_buf", "softmax_inside", buildOption, mOpenCLBackend->getPrecision(), inputs[0], outputs[0]);
            mGwsSoftMaxVec[seq_idx] =  {static_cast<uint32_t>(localSize), static_cast<uint32_t>(softmaxShape[1]), static_cast<uint32_t>(softmaxShape[0])};
            
            uint32_t index = 0;
            cl_int ret = CL_SUCCESS;
            ret |= mKernel_softmax_vec[seq_idx]->get().setArg(index++, mGwsSoftMaxVec[seq_idx][0]);
            ret |= mKernel_softmax_vec[seq_idx]->get().setArg(index++, mGwsSoftMaxVec[seq_idx][1]);
            ret |= mKernel_softmax_vec[seq_idx]->get().setArg(index++, mGwsSoftMaxVec[seq_idx][2]);
            ret |= mKernel_softmax_vec[seq_idx]->get().setArg(index++, openCLBuffer(mTempQK.get()));
            ret |= mKernel_softmax_vec[seq_idx]->get().setArg(index++, openCLBuffer(mTempSoftMax.get()));
            ret |= mKernel_softmax_vec[seq_idx]->get().setArg(index++, mKvSeqlen);
            ret |= mKernel_softmax_vec[seq_idx]->get().setArg(index++, softmaxShape);
            MNN_CHECK_CL_SUCCESS(ret, "setArg Attention softmax");
            
            mLwsSoftMaxVec[seq_idx] = {static_cast<uint32_t>(localSize), 1, 1};
            mOpenCLBackend->recordKernel3d(mKernel_softmax_vec[seq_idx], mGwsSoftMaxVec[seq_idx], mLwsSoftMaxVec[seq_idx]);
        }
        {
            // Softmax: [Batch * numHead, ROUND_UP(seqLenQ, mAlignQ) / mQseqSplitNum, ROUND_UP(seqLenKV, mAlignKV)]
            // Trans:  [Batch * numHead, ROUND_UP(seqLenKV, mAlignKV), ROUND_UP(seqLenQ, mAlignQ) / mQseqSplitNum]
            int loop = batch * numHead;
            int transDimW = ROUND_UP(seqlen, mAlignQ) / mQseqSplitNum;
            int transDimH = ROUND_UP(mKvSeqlen, mAlignKV);
            
            std::set<std::string> buildOptions;
            mKernel_trans_vec[seq_idx] = runtime->buildKernel("self_attention_buf", "trans_3d_buf", buildOptions, mOpenCLBackend->getPrecision(), inputs[0], outputs[0]);
            uint32_t maxWorkGroupSize = static_cast<uint32_t>(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(mKernel_trans_vec[seq_idx]));
            
            mGwsTransVec[seq_idx] = {(uint32_t)transDimW/8, (uint32_t)transDimH/8, (uint32_t)(loop)};
            
            uint32_t index = 0;
            cl_int ret = CL_SUCCESS;
            ret |= mKernel_trans_vec[seq_idx]->get().setArg(index++, mGwsTransVec[seq_idx][0]);
            ret |= mKernel_trans_vec[seq_idx]->get().setArg(index++, mGwsTransVec[seq_idx][1]);
            ret |= mKernel_trans_vec[seq_idx]->get().setArg(index++, mGwsTransVec[seq_idx][2]);
            ret |= mKernel_trans_vec[seq_idx]->get().setArg(index++, openCLBuffer(mTempSoftMax.get()));
            ret |= mKernel_trans_vec[seq_idx]->get().setArg(index++, openCLBuffer(mTempQK.get()));
            ret |= mKernel_trans_vec[seq_idx]->get().setArg(index++, loop);
            ret |= mKernel_trans_vec[seq_idx]->get().setArg(index++, transDimW);
            ret |= mKernel_trans_vec[seq_idx]->get().setArg(index++, transDimH);
            MNN_CHECK_CL_SUCCESS(ret, "setArg Attention transpose");
            mLwsTransVec[seq_idx] = localWS3DDefault(mGwsTransVec[seq_idx], maxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), "trans_3d_buf", mKernel_trans_vec[seq_idx], mOpenCLBackend->getCLTuneLevel()).first;
            
            mGwsTransVec[seq_idx][0] = ROUND_UP(mGwsTransVec[seq_idx][0], std::max((uint32_t)1, mLwsTransVec[seq_idx][0]));
            mGwsTransVec[seq_idx][1] = ROUND_UP(mGwsTransVec[seq_idx][1], std::max((uint32_t)1, mLwsTransVec[seq_idx][1]));
            mGwsTransVec[seq_idx][2] = ROUND_UP(mGwsTransVec[seq_idx][2], std::max((uint32_t)1, mLwsTransVec[seq_idx][2]));
            
            mOpenCLBackend->recordKernel3d(mKernel_trans_vec[seq_idx], mGwsTransVec[seq_idx], mLwsTransVec[seq_idx]);
        }
        
        // qk * value
        {
            // Trans: [Batch * numHead, ROUND_UP(seqLenKV, mAlignKV), ROUND_UP(seqLenQ, mAlignQ) / mQseqSplitNum]   -> [B, K, M]
            // V :     [Batch * numHead / group, ROUND_UP(seqLenKV, mAlignKV), ROUND_UP(headDim, mAlignHDN)] -> [B, K, N]
            // QKV :   [Batch * numHead, ROUND_UP(headDim, mAlignHDN), ROUND_UP(seqLenQ, mAlignQ) / mQseqSplitNum] -> [B, N, M]
            
            int loop = batch * numHead;
            int e_pack = ROUND_UP(seqlen, mAlignQ);
            int e_pack_piece = e_pack / mQseqSplitNum;
            int l_pack = ROUND_UP(mKvSeqlen, mAlignKV);
            int h_pack = ROUND_UP(headDim, mAlignHDN);
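            // Unlike the QK GEMM, layout = 0 here (no -DOUTPUTMN), so each batch of the result is
            // stored as [N, M] per the QKV comment above; the final qkv_transpose_output step
            // converts it back to the output tensor layout.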
            
            std::set<std::string> buildOptions;
            
            uint32_t layout = 0;
            auto param = getGemmParams({(uint32_t)e_pack_piece, (uint32_t)h_pack, (uint32_t)l_pack, layout, (uint32_t)loop, (uint32_t)0}, {openCLBuffer(mTempQK.get()), openCLBuffer(mTempV.get()), openCLBuffer(mTempQKV.get())}, mOpenCLBackend->getOpenCLRuntime(), mOpenCLBackend->getPrecision(), mOpenCLBackend->getCLTuneLevel());
            
            int KWG=param[0], KWI=param[1], MDIMA=param[2], MDIMC=param[3], MWG=param[4], NDIMB=param[5], NDIMC=param[6], NWG=param[7], SA=param[8], SB=param[9], STRM=param[10], STRN=param[11], VWM=param[12], VWN=param[13];
            buildOptions.emplace("-DKWG=" + std::to_string(KWG));
            buildOptions.emplace("-DKWI=" + std::to_string(KWI));
            buildOptions.emplace("-DMDIMA=" + std::to_string(MDIMA));
            buildOptions.emplace("-DMDIMC=" + std::to_string(MDIMC));
            buildOptions.emplace("-DMWG=" + std::to_string(MWG));
            buildOptions.emplace("-DNDIMB=" + std::to_string(NDIMB));
            buildOptions.emplace("-DNDIMC=" + std::to_string(NDIMC));
            buildOptions.emplace("-DNWG=" + std::to_string(NWG));
            buildOptions.emplace("-DSA=" + std::to_string(SA));
            buildOptions.emplace("-DSB=" + std::to_string(SB));
            buildOptions.emplace("-DSTRM=" + std::to_string(STRM));
            buildOptions.emplace("-DSTRN=" + std::to_string(STRN));
            buildOptions.emplace("-DVWM=" + std::to_string(VWM));
            buildOptions.emplace("-DVWN=" + std::to_string(VWN));
            if(layout >= 4) {
                buildOptions.emplace("-DOUTPUTMN");
            }
            
            int tileM = MWG;
            int tileN = NWG;
            int localM = MDIMC;
            int localN = NDIMC;
            
            if(mOpenCLBackend->getOpenCLRuntime()->getGpuType() == GpuType::ADRENO) {
                buildOptions.emplace("-DUSE_CL_MAD=1");
                buildOptions.emplace("-DRELAX_WORKGROUP_SIZE=1");
            }
            
            mKernel_qkv_vec[seq_idx] = mOpenCLBackend->getOpenCLRuntime()->buildKernel("matmul_params_buf", "XgemmBatched", buildOptions, mOpenCLBackend->getPrecision());
            
            int out_per_thread_m = tileM / localM;
            int out_per_thread_n = tileN / localN;
            
            mGwsQkvVec[seq_idx] = {static_cast<uint32_t>(e_pack_piece/out_per_thread_m), static_cast<uint32_t>(h_pack/out_per_thread_n), static_cast<uint32_t>(loop)};
            mLwsQkvVec[seq_idx] = {static_cast<uint32_t>(localM), static_cast<uint32_t>(localN), 1};
            
            float alpha = 1.0f;
            float beta = 0.0f;
            int batch_offset_a = e_pack_piece * l_pack;
            int batch_offset_b = h_pack * l_pack;
            int batch_offset_c = e_pack * h_pack;
            int batch_offset[4] = {batch_offset_a, batch_offset_b, batch_offset_c, 0};
            int base_ptr_offset[4] = {0, 0, e_pack_piece * seq_idx, 0};
            int stride[4] = {e_pack_piece, h_pack, e_pack, h_pack};
            int group[4] = {1, group_size, 1, loop};
            
            int idx            = 0;
            cl_int ret = CL_SUCCESS;
            ret |= mKernel_qkv_vec[seq_idx]->get().setArg(idx++, static_cast<int>(e_pack_piece));
            ret |= mKernel_qkv_vec[seq_idx]->get().setArg(idx++, static_cast<int>(h_pack));
            ret |= mKernel_qkv_vec[seq_idx]->get().setArg(idx++, static_cast<int>(l_pack));
            ret |= mKernel_qkv_vec[seq_idx]->get().setArg(idx++, alpha);
            ret |= mKernel_qkv_vec[seq_idx]->get().setArg(idx++, beta);
            ret |= mKernel_qkv_vec[seq_idx]->get().setArg(idx++, openCLBuffer(mTempQK.get()));
            ret |= mKernel_qkv_vec[seq_idx]->get().setArg(idx++, openCLBuffer(mTempV.get()));
            ret |= mKernel_qkv_vec[seq_idx]->get().setArg(idx++, openCLBuffer(mTempQKV.get()));
            ret |= mKernel_qkv_vec[seq_idx]->get().setArg(idx++, batch_offset);
            ret |= mKernel_qkv_vec[seq_idx]->get().setArg(idx++, base_ptr_offset);
            ret |= mKernel_qkv_vec[seq_idx]->get().setArg(idx++, stride);
            ret |= mKernel_qkv_vec[seq_idx]->get().setArg(idx++, group);
            MNN_CHECK_CL_SUCCESS(ret, "setArg Self-Attention batchmatmul qkv Kernel");
            mOpenCLBackend->recordKernel3d(mKernel_qkv_vec[seq_idx], mGwsQkvVec[seq_idx], mLwsQkvVec[seq_idx]);
        }
    }
    
    seq_idx = 0;
    // transpose to output
    {
        // QKV :   [Batch * numHead, ROUND_UP(headDim, mAlignHDN), ROUND_UP(seqLenQ, mAlignQ)] -> [B, N, M]
        // output: [batch, seqLenQ/4, headNum, headDim, seqLenQ_4]
        std::set<std::string> buildOption;
        
        mKernel_clip_vec[seq_idx] = runtime->buildKernel("attention_buf", "qkv_transpose_output", buildOption, mOpenCLBackend->getPrecision(), inputs[0], outputs[0]);
        auto maxWorkGroupSize  = static_cast<uint32_t>(runtime->getMaxWorkGroupSize(mKernel_clip_vec[seq_idx]));
        
        mGwsClipVec[seq_idx] = {static_cast<uint32_t>(UP_DIV(seqlen, 4)), static_cast<uint32_t>(UP_DIV(headDim, 4)), static_cast<uint32_t>(batch*numHead)};
        
        uint32_t index = 0;
        cl_int ret = CL_SUCCESS;
        ret |= mKernel_clip_vec[seq_idx]->get().setArg(index++, mGwsClipVec[seq_idx][0]);
        ret |= mKernel_clip_vec[seq_idx]->get().setArg(index++, mGwsClipVec[seq_idx][1]);
        ret |= mKernel_clip_vec[seq_idx]->get().setArg(index++, mGwsClipVec[seq_idx][2]);
        ret |= mKernel_clip_vec[seq_idx]->get().setArg(index++, openCLBuffer(mTempQKV.get()));
        ret |= mKernel_clip_vec[seq_idx]->get().setArg(index++, openCLBuffer(outputs[0]));
        ret |= mKernel_clip_vec[seq_idx]->get().setArg(index++, mAlignQ);
        ret |= mKernel_clip_vec[seq_idx]->get().setArg(index++, mAlignHDN);
        ret |= mKernel_clip_vec[seq_idx]->get().setArg(index++, seqlen);
        ret |= mKernel_clip_vec[seq_idx]->get().setArg(index++, numHead);
        ret |= mKernel_clip_vec[seq_idx]->get().setArg(index++, headDim);
        
        mLwsClipVec[seq_idx] = localWS3DDefault(mGwsClipVec[seq_idx], maxWorkGroupSize, runtime, "qkv_transpose_output", mKernel_clip_vec[seq_idx], mOpenCLBackend->getCLTuneLevel()).first;
        mGwsClipVec[seq_idx][0] = ROUND_UP(mGwsClipVec[seq_idx][0], std::max((uint32_t)1, mLwsClipVec[seq_idx][0]));
        mGwsClipVec[seq_idx][1] = ROUND_UP(mGwsClipVec[seq_idx][1], std::max((uint32_t)1, mLwsClipVec[seq_idx][1]));
        mGwsClipVec[seq_idx][2] = ROUND_UP(mGwsClipVec[seq_idx][2], std::max((uint32_t)1, mLwsClipVec[seq_idx][2]));
        
        MNN_CHECK_CL_SUCCESS(ret, "setArg qkv_transpose_output");
        mOpenCLBackend->recordKernel3d(mKernel_clip_vec[seq_idx], mGwsClipVec[seq_idx], mLwsClipVec[seq_idx]);
    }
    mOpenCLBackend->endRecord(mRecording);

    return NO_ERROR;
}