std::vector getGemmParams()

in source/backend/opencl/core/OpenCLGemmTune.cpp [181:559]


/**
 * Selects the parameter set used to build the "Xgemm" / "XgemmBatched" OpenCL kernels
 * for one GEMM problem.
 *
 * The function tries, in this order:
 *   1. the per-runtime cache (runtime->tunedGemmParamsMap());
 *   2. two closed-form heuristics gated on runtime->getGpuLevel() and the problem size;
 *   3. GemmlocalWSTune(), whose result is returned as-is when it reports success;
 *   4. when tuneLevel == None, a small table of hand-picked parameter sets;
 *   5. otherwise, a timed search: every valid candidate is compiled, launched once on
 *      tensorMemory and the fastest one is kept and stored in the cache.
 * Only the outcome of step 5 is written to the cache.
 *
 * @param gemmSize     Six values: {M, N, K, layout + 10 * mixPrecision, batch,
 *                     biasType + 10 * (groupSize - 1)}. M and N must be multiples of 16,
 *                     K a multiple of 4 (asserted).
 * @param tensorMemory Buffers bound to the kernel while timing candidates: three entries
 *                     when biasType == 0, four otherwise. [0] and [1] are bound first,
 *                     [3] (only when biasType >= 1) next, [2] last - presumably A, B,
 *                     bias and the output; confirm against the kernel source.
 * @param runtime      OpenCL runtime supplying the caches, kernel builder, command queue
 *                     and timing helper.
 * @param precision    Precision level forwarded to buildKernel(). For the cache key, a
 *                     value of 2 is replaced by 0 when mixPrecision > 0.
 * @param tuneLevel    Tuning effort; compared against None, Fast and Wide (the numeric
 *                     ordering of these enumerators is defined outside this function).
 * @return Parameter vector ordered as
 *         {KWG, KWI, MDIMA, MDIMC, MWG, NDIMB, NDIMC, NWG, SA, SB, STRM, STRN, VWM, VWN}.
 *         (The vector returned via GemmlocalWSTune() is produced elsewhere; its layout is
 *         assumed to match - TODO confirm.)
 */
std::vector<uint32_t> getGemmParams(const std::vector<uint32_t> &gemmSize, const std::vector<cl::Buffer> tensorMemory,
                                       OpenCLRuntime *runtime, int precision, int tuneLevel) {
    MNN_ASSERT(gemmSize.size() == 6); // M, N, K, Layout+Precision, Batch, Bias+GroupSize
    MNN_ASSERT(gemmSize[0] % 16 == 0);
    MNN_ASSERT(gemmSize[1] % 16 == 0);
    MNN_ASSERT(gemmSize[2] % 4 == 0);

    // gemmSize[3] and gemmSize[5] each pack two values as decimal digits.
    int layoutType = gemmSize[3] % 10;
    int mixPrecision = gemmSize[3] / 10;
    int biasType = gemmSize[5] % 10;
    int groupSize = gemmSize[5] / 10 + 1;
    MNN_ASSERT((biasType == 0 && tensorMemory.size() == 3) || (biasType >= 1 && tensorMemory.size() == 4));
    auto& tunedGemmParams = runtime->tunedGemmParamsMap();
    auto& tuneLws = runtime->getTuneLwsMap();

    // Cache key: the problem description followed by the effective precision.
    std::vector<uint32_t> info(gemmSize);
    uint32_t precisionType = precision;
    if(precisionType == 2 && mixPrecision > 0) {
        precisionType = 0;
    }
    info.emplace_back(precisionType);

    // Single lookup: reuse the iterator instead of find() followed by operator[].
    auto cached = tunedGemmParams.find(info);
    if (cached != tunedGemmParams.end()) {
        return cached->second;
    }

    // Largest of {128, 64, 32} dividing num, or 16 when none does.
    auto getMaxDivisor = [](uint32_t num) -> uint32_t {
        for (uint32_t divisor : {128u, 64u, 32u}) {
            if (num % divisor == 0) {
                return divisor;
            }
        }
        return 16;
    };

    // Heuristic 1: GPU level compares >= MEDIUM and the problem is large relative to 256^3.
    if(runtime->getGpuLevel() >= MEDIUM){
        // total computation
        auto compute_ratio = 1.0 * gemmSize[4] * gemmSize[0] / 256.0 * gemmSize[1] / 256.0 * gemmSize[2] / 256.0;
        auto thread_ratio = 1.0 * gemmSize[4] * gemmSize[0] / 256.0 * gemmSize[1] / 256.0;

        // each dimension is even
        bool is_even =  gemmSize[0] >= 256 && gemmSize[1] >= 128 && gemmSize[2] >= 128;
        is_even |= gemmSize[1] >= 128 && gemmSize[2] >= 128 && gemmSize[4] >= 4;
        bool is_div = gemmSize[0] % 64 == 0 && gemmSize[1] % 32 == 0;
        if(compute_ratio >= 1.0 && thread_ratio >= 1.0 && is_even && is_div) {
            int maxDivsorM = getMaxDivisor(gemmSize[0]);
            int maxDivsorN = getMaxDivisor(gemmSize[1]);
            // Work-group tile is capped at 64 x 32 on this path.
            maxDivsorM = maxDivsorM > 64 ? 64 : maxDivsorM;
            maxDivsorN = maxDivsorN > 32 ? 32 : maxDivsorN;
            std::vector<uint32_t> params_prefer = {16, 2, 16, 16, 64, 8, 8, 32, 0, 0, 0, 0, 4, 4};
            params_prefer[2] = maxDivsorM / 4;
            params_prefer[3] = maxDivsorM / 4;
            params_prefer[4] = maxDivsorM;
            params_prefer[5] = maxDivsorN / 4;
            params_prefer[6] = maxDivsorN / 4;
            params_prefer[7] = maxDivsorN;

            return params_prefer;
        }
    }

    // Heuristic 2: GPU level is TOP, tuning is None or Fast, problem is large relative to 512^3.
    if(runtime->getGpuLevel() == TOP && (tuneLevel == None || tuneLevel == Fast)) {
        // total computation
        auto compute_ratio = 1.0 * gemmSize[4] * gemmSize[0] / 512.0 * gemmSize[1] / 512.0 * gemmSize[2] / 512.0;
        auto thread_ratio = 1.0 * gemmSize[4] * gemmSize[0] / 512.0 * gemmSize[1] / 512.0;

        // each dimension is even
        bool is_even =  gemmSize[0] >= 512 && gemmSize[1] >= 256 && gemmSize[2] >= 256;
        is_even |= gemmSize[1] >= 128 && gemmSize[2] >= 128 && gemmSize[4] >= 4;
        bool is_div = gemmSize[0] % 64 == 0 && gemmSize[1] % 64 == 0;
        if(compute_ratio >= 1.0 && thread_ratio >= 1.0 && is_even && is_div) {
            int maxDivsorM = getMaxDivisor(gemmSize[0]);
            int maxDivsorN = getMaxDivisor(gemmSize[1]);
            std::vector<uint32_t> params_prefer = {16, 2, 16, 16, 128, 16, 16, 128, 0, 0, 0, 0, 8, 8};
            params_prefer[4] = maxDivsorM;
            params_prefer[7] = maxDivsorN;
            params_prefer[12] = maxDivsorM / 16;
            params_prefer[13] = maxDivsorN / 16;

            return params_prefer;
        }
    }

    // Delegate to the local-work-size tuner; a true return short-circuits everything below.
    std::vector<uint32_t> tuneLwsRes;
    if(GemmlocalWSTune(tuneLws, gemmSize, tuneLwsRes, runtime, precision)){
        return tuneLwsRes;
    }

    // Fallback parameters, also used as the first search candidate.
    std::vector<uint32_t> params_prefer = {16, 2, 4, 4, 16, 4, 4, 16, 0, 0, 1, 0, 2, 2};

    auto thread_ratio = 1.0 * gemmSize[4] * gemmSize[0] / 512.0 * gemmSize[1] / 512.0;
    bool is_div = gemmSize[0] % 64 == 0 && gemmSize[1] % 32 == 0;
    // Start from a reasonably good candidate so the initial choice is not too slow.
    if(thread_ratio >= 1.0 && is_div) {
        params_prefer.assign({16, 2, 16, 16, 64 , 8 , 8 , 32 , 0, 0, 0, 0, 4, 4});
    }

    // No tuning requested: choose from a fixed table and return without touching the cache.
    if(tuneLevel == None) {
        float multiNum = 1.0 * gemmSize[0] / 512.0 * gemmSize[1] / 512.0 * gemmSize[2] / 512.0;
        int maxDivsorM = getMaxDivisor(gemmSize[0]);
        int maxDivsorN = getMaxDivisor(gemmSize[1]);

        if(gemmSize[4] == 1) {// Gemm
            if(gemmSize[0] >= 256 && gemmSize[1] >= 256 && gemmSize[2] >= 256) {
                if(multiNum > 8.0) {
                    if(maxDivsorM >= 128 && maxDivsorN >= 64) {
                        return {16, 2, 16, 16, 128, 8, 8, 64, 0, 0, 0, 1, 8, 8};
                    }
                }
                if(maxDivsorM >= 64 && maxDivsorN >= 64) {
                    return {16, 2, 8, 8, 64, 8, 8, 64, 0, 0, 0, 1, 8, 8};
                }
            }
        } else {// BatchGemm
            if(maxDivsorM >= 64 && maxDivsorN >= 128) {
                return {16, 2, 16, 16, 64, 8, 8, 128, 0, 0, 1, 0, 2, 8};
            } else if(maxDivsorM >= 64 && maxDivsorN >= 64) {
                return {16, 2, 8, 8, 64, 8, 8, 64, 0, 0, 1, 0, 4, 4};
            }
        }
        return params_prefer;
    }

    // Candidate list for the timed search. Order matters: on equal cost the earlier one wins.
    std::vector<std::vector<uint32_t>> totalCombinations; // save total candidate combinations
    totalCombinations.emplace_back(params_prefer);
    uint32_t min_cost = UINT_MAX;

    if(tuneLevel >= Wide) {
        // Curated candidates.
        totalCombinations.push_back({16, 2, 16, 16, 64 , 8 , 8 , 128, 0, 0, 0, 0, 4, 8});//12
        totalCombinations.push_back({16, 2, 16, 16, 128, 8 , 8 , 64 , 0, 0, 0, 0, 8, 8});//11 ..
        totalCombinations.push_back({16, 2, 16, 16, 128, 16, 16, 128, 0, 0, 0, 0, 8, 8});//1
        totalCombinations.push_back({16, 2, 16, 16, 128, 8 , 8 , 32 , 0, 0, 0, 1, 8, 4});//1
        totalCombinations.push_back({16, 2, 8 , 8 , 16 , 8 , 8 , 64, 0, 0, 0, 0, 2, 8});
        totalCombinations.push_back({16, 2, 16, 16, 64 , 8 , 8 , 128, 0, 0, 0, 1, 4, 8});//10

        totalCombinations.push_back({16, 2, 8,  8 , 32 , 8 , 8 , 128, 0, 0, 1, 0, 2, 8});//2
        totalCombinations.push_back({16, 2, 16, 16, 64 , 8 , 8 , 128, 0, 0, 1, 1, 2, 8});//12
        totalCombinations.push_back({16, 2, 16, 16, 128, 8 , 8 , 64 , 0, 0, 1, 1, 2, 8});//2
        totalCombinations.push_back({16, 2, 16, 16, 128, 8 , 8 , 128, 0, 0, 0, 0, 8, 8});
        totalCombinations.push_back({16, 2, 8 , 8 , 16 , 8 , 8 , 128, 0, 0, 0, 0, 2, 8});
        totalCombinations.push_back({16, 2, 4, 4, 32, 8, 8, 32, 0, 0, 0, 0, 8, 2});
        totalCombinations.push_back({16, 2, 4, 4, 16, 8, 8, 32, 0, 0, 0, 0, 4, 2});

        // Extra candidates, added only when tuneLevel compares below Fast.
        if(tuneLevel < Fast) {
            totalCombinations.push_back({16, 2, 16, 16, 128, 8 , 8 , 64 , 0, 0, 1, 0, 8, 8});//4
            totalCombinations.push_back({16, 2, 16, 16, 128, 8 , 8 , 64 , 0, 0, 0, 1, 8, 8});//6
            totalCombinations.push_back({16, 2, 16, 16, 128, 8 , 8 , 64 , 0, 0, 1, 1, 8, 8});//4

            totalCombinations.push_back({16, 2, 16, 16, 128, 8 , 8 , 64 , 0, 0, 1, 0, 2, 8});//3
            totalCombinations.push_back({16, 2, 8,  8 , 64 , 8 , 8 , 64 , 0, 0, 1, 0, 2, 8});//1
            totalCombinations.push_back({16, 2, 16, 16, 128, 8 , 8 , 64 , 0, 0, 1, 1, 4, 4});//1
            totalCombinations.push_back({16, 2, 16, 16, 64 , 8 , 8 , 128, 0, 0, 1, 0, 2, 8});//3

            totalCombinations.push_back({16, 2, 16, 16, 128, 8 , 8 , 32 , 0, 0, 0, 0, 4, 4});//1
            totalCombinations.push_back({16, 2, 16, 16, 128, 16, 16, 128, 0, 0, 0, 1, 8, 8});//2
            totalCombinations.push_back({16, 2, 16, 16, 128, 16, 16, 128, 0, 0, 1, 0, 8, 8});//1
            totalCombinations.push_back({16, 2, 8 , 8 , 16 , 8 , 8 , 128, 0, 0, 1, 0, 2, 8});//1
            totalCombinations.push_back({16, 2, 8 , 8 , 16 , 8 , 8 , 128, 0, 0, 1, 1, 2, 8});//1

            totalCombinations.push_back({16, 2, 16, 16, 64 , 8 , 8 , 32 , 0, 0, 0, 1, 4, 4});//1
            totalCombinations.push_back({16, 2, 16, 16, 64 , 8 , 8 , 32 , 0, 0, 1, 0, 4, 4});
            totalCombinations.push_back({16, 2, 16, 16, 128, 8 , 8 , 64 , 0, 0, 1, 0, 4, 8});
            totalCombinations.push_back({16, 2, 16, 16, 128, 8 , 8 , 128, 0, 0, 0, 1, 8, 8});
            totalCombinations.push_back({16, 2, 16, 16, 128, 8 , 8 , 128, 0, 0, 1, 1, 8, 8});

            totalCombinations.push_back({16, 2, 8, 8, 32, 8, 8, 32, 0, 0, 1, 0, 2, 4});
            totalCombinations.push_back({16, 2, 8, 8, 16, 8, 8, 32, 0, 0, 1, 1, 2, 4});
            totalCombinations.push_back({16, 2, 4, 4, 16, 8, 8, 64, 0, 0, 0, 0, 2, 8});
            totalCombinations.push_back({16, 2, 4, 4, 64, 8, 8, 32, 0, 0, 1, 0, 4, 4});
            totalCombinations.push_back({16, 2, 4, 4, 32, 8, 8, 64, 0, 0, 0, 1, 2, 4});
        }
    } else {
        // tuneLevel < Wide: enumerate combinations of the per-parameter options below.
        std::vector<std::vector<uint32_t>> candidates = {
            {16, 32},         // KWG
            {2},              // KWI
            {4, 8, 16},          // MDIMA
            {4, 8, 16},          // MDIMC
            {16, 32, 64, 128}, // MWG
            {8, 16},          // NDIMB
            {8, 16},          // NDIMC
            {16, 32, 64, 128}, // NWG
            {0},              // SA
            {0},              // SB
            {0, 1},           // STRM
            {0, 1},           // STRN
            {2, 4, 8},        // VWM
            {2, 4, 8}        // VWN
        };

        std::vector<uint32_t> currentCombination(candidates.size());
        generateCombinations(candidates, currentCombination, totalCombinations, 0);
    }

    const bool isBatched = gemmSize[4] > 1;
    const char* kernelName = isBatched ? "XgemmBatched" : "Xgemm";

    // Build, launch and time every valid candidate; keep the fastest.
    for(const auto& candidate : totalCombinations) {
        const uint32_t kwg   = candidate[0];
        const uint32_t kwi   = candidate[1];
        const uint32_t mdima = candidate[2];
        const uint32_t mdimc = candidate[3];
        const uint32_t mwg   = candidate[4];
        const uint32_t ndimb = candidate[5];
        const uint32_t ndimc = candidate[6];
        const uint32_t nwg   = candidate[7];
        const uint32_t sa    = candidate[8];
        const uint32_t sb    = candidate[9];
        const uint32_t strm  = candidate[10];
        const uint32_t strn  = candidate[11];
        const uint32_t vwm   = candidate[12];
        const uint32_t vwn   = candidate[13];

        if(!isCandidateValid(kwg, kwi, mwg, mdimc, vwm, nwg, ndimc, vwn, mdima, ndimb, sa, sb, runtime, gemmSize, precision)) {
            continue;
        }

        std::set<std::string> buildOptions;
        buildOptions.emplace("-DKWG="   + std::to_string(kwg));
        buildOptions.emplace("-DKWI="   + std::to_string(kwi));
        buildOptions.emplace("-DMDIMA=" + std::to_string(mdima));
        buildOptions.emplace("-DMDIMC=" + std::to_string(mdimc));
        buildOptions.emplace("-DMWG="   + std::to_string(mwg));
        buildOptions.emplace("-DNDIMB=" + std::to_string(ndimb));
        buildOptions.emplace("-DNDIMC=" + std::to_string(ndimc));
        buildOptions.emplace("-DNWG="   + std::to_string(nwg));
        buildOptions.emplace("-DSA="    + std::to_string(sa));
        buildOptions.emplace("-DSB="    + std::to_string(sb));
        buildOptions.emplace("-DSTRM="  + std::to_string(strm));
        buildOptions.emplace("-DSTRN="  + std::to_string(strn));
        buildOptions.emplace("-DVWM="   + std::to_string(vwm));
        buildOptions.emplace("-DVWN="   + std::to_string(vwn));

        if(layoutType >= 4) {
            buildOptions.emplace(" -DOUTPUTMN");
        }
        if(runtime->getGpuType() == GpuType::ADRENO) {
            buildOptions.emplace(" -DUSE_CL_MAD=1");
            buildOptions.emplace(" -DRELAX_WORKGROUP_SIZE=1");
        }

        if(biasType >= 1) {
            buildOptions.emplace(" -DBIAS_TYPE=" + std::to_string((int)biasType));
        }
        if(mixPrecision > 0) {
            buildOptions.emplace("-DPRECISION_COMPUTE=float -DCONVERT_PRECISION_COMPUTE=convert_float");
            buildOptions.emplace("-DPRECISION_COMPUTE2=float2 -DCONVERT_PRECISION_COMPUTE2=convert_float2");
            buildOptions.emplace("-DPRECISION_COMPUTE4=float4 -DCONVERT_PRECISION_COMPUTE4=convert_float4");
            buildOptions.emplace("-DPRECISION_COMPUTE8=float8 -DCONVERT_PRECISION_COMPUTE8=convert_float8");
            buildOptions.emplace("-DPRECISION_COMPUTE16=float16 -DCONVERT_PRECISION_COMPUTE16=convert_float16");
        }

        int localM = mdimc;
        int localN = ndimc;

        std::shared_ptr<KernelWrap> kernel = runtime->buildKernel("matmul_params_buf", kernelName, buildOptions, precision);
        if(kernel == nullptr) {
            continue;
        }
        // Skip candidates whose work-group does not fit this kernel on this device.
        if(localM * localN > runtime->getMaxWorkGroupSize(kernel)) {
            continue;
        }
        int tileM = mwg;
        int tileN = nwg;
        int out_per_thread_m = tileM / localM;
        int out_per_thread_n = tileN / localN;

        std::vector<uint32_t>  globalWorkSize = {static_cast<uint32_t>(gemmSize[0]/out_per_thread_m), static_cast<uint32_t>(gemmSize[1]/out_per_thread_n), gemmSize[4]};
        std::vector<uint32_t>  localWorkSize = {static_cast<uint32_t>(localM), static_cast<uint32_t>(localN), 1};

        float alpha = 1.0;
        float beta = 0.0f;
        // A: [n, l, e]
        // B: [n, l, h]

        int cost_time = 0;
        int idx = 0;
        cl_int ret = CL_SUCCESS;
        ret |= kernel->get().setArg(idx++, static_cast<int>(gemmSize[0]));
        ret |= kernel->get().setArg(idx++, static_cast<int>(gemmSize[1]));
        ret |= kernel->get().setArg(idx++, static_cast<int>(gemmSize[2]));
        ret |= kernel->get().setArg(idx++, alpha);
        ret |= kernel->get().setArg(idx++, beta);

        // Buffer arguments are identical for both kernels: [0], [1], optional bias [3], then [2].
        ret |= kernel->get().setArg(idx++, tensorMemory[0]);
        ret |= kernel->get().setArg(idx++, tensorMemory[1]);
        if(biasType >= 1) {
            ret |= kernel->get().setArg(idx++, tensorMemory[3]);
        }
        ret |= kernel->get().setArg(idx++, tensorMemory[2]);

        int stride[4] = {(int)gemmSize[0], (int)gemmSize[1], (int)gemmSize[1], (int)gemmSize[1]};
        if(layoutType < 4) {
            stride[2] = gemmSize[0]; // output: [N, M]
        }
        if(isBatched) {
            int batch_offset_a = gemmSize[0] * gemmSize[2];
            int batch_offset_b = gemmSize[1] * gemmSize[2];
            int batch_offset_c = gemmSize[0] * gemmSize[1];
            int batch_offset[4] = {batch_offset_a, batch_offset_b, batch_offset_c, 0};
            int base_ptr_offset[4] = {0, 0, 0, 0};

            int group[4] = {1, (int)groupSize, 1, (int)gemmSize[4]};

            ret |= kernel->get().setArg(idx++, sizeof(batch_offset), batch_offset);
            // Size each argument from its own array rather than a neighbour of equal size.
            ret |= kernel->get().setArg(idx++, sizeof(base_ptr_offset), base_ptr_offset);
            ret |= kernel->get().setArg(idx++, sizeof(stride), stride);
            ret |= kernel->get().setArg(idx++, sizeof(group), group);

            MNN_CHECK_CL_SUCCESS(ret, "setArg getGemmParams XgemmBatched Kernel");

            cl::Event event;
            cl_int res = runtime->commandQueue().enqueueNDRangeKernel(kernel->get(), cl::NullRange, {globalWorkSize[0], globalWorkSize[1], globalWorkSize[2]}, {localWorkSize[0], localWorkSize[1], localWorkSize[2]}, nullptr, &event);
            if (res != CL_SUCCESS) {
                MNN_PRINT("XgemmBatched params tune error: %d\n", res);
                continue;
            }

            cost_time = (int)runtime->getCostTime(&event);
        } else {
            int offset[4] = {0, 0, 0, 0};

            ret |= kernel->get().setArg(idx++, offset);
            ret |= kernel->get().setArg(idx++, stride);

            MNN_CHECK_CL_SUCCESS(ret, "setArg getGemmParams Xgemm Kernel");

            cl::Event event;
            cl_int res = runtime->commandQueue().enqueueNDRangeKernel(kernel->get(), cl::NullRange, {globalWorkSize[0], globalWorkSize[1]}, {localWorkSize[0], localWorkSize[1]}, nullptr, &event);
            if (res != CL_SUCCESS) {
                MNN_PRINT("Xgemm params tune error: %d\n", res);
                continue;
            }
            cost_time = (int)runtime->getCostTime(&event);
        }

        // Non-positive timings are ignored; strict '<' keeps the earliest of equally fast candidates.
        if(cost_time > 0 && static_cast<uint32_t>(cost_time) < min_cost) {
            min_cost = cost_time;
            params_prefer = {kwg, kwi, mdima, mdimc, mwg, ndimb, ndimc, nwg, sa, sb, strm, strn, vwm, vwn};
            #ifdef TIME_TUNE_LOG
            for(auto &iter : params_prefer) {
                MNN_PRINT("%d ", iter);
            }
            MNN_PRINT(": %d us, shape:%d %d %d batch:%d, layout:%d bias:%d, flops:%f GFLOPS\n", min_cost, gemmSize[0], gemmSize[1], gemmSize[2], gemmSize[4], gemmSize[3], gemmSize[5], 2.0 / 1000.0 * gemmSize[0] * gemmSize[1] * gemmSize[2] * gemmSize[4] / min_cost);
            #endif
        }
    }

    // insert() leaves an existing entry untouched, so no separate find() is needed.
    tunedGemmParams.insert(std::make_pair(info, params_prefer));

    return params_prefer;
}