void cublasMMWrapper::cublasLtGemm()

in maga_transformer/cpp/cuda/cublas/cublasMMWrapper.cc [87:256]


// Computes C = alpha * op(A) * op(B) + beta * C via the cublasLt API.
//
// - A_scale / B_scale: optional device pointers installed as per-tensor
//   dequantization scale pointers (FP8 path); skipped when null.
// - is_fp16_computeType selects FP16 vs FP32 compute/scale types.
// - findAlgo + info: replay a previously-searched algo configuration; if the
//   cached algo needs more workspace than available we fall back to the
//   cublasLt heuristic (algo == NULL).
// - math_sm_count > 0 caps the number of SMs the matmul may occupy.
// - fast_accum toggles CUBLASLT_MATMUL_DESC_FAST_ACCUM.
// - Matmuls on a stream other than the wrapper's own stream_ get a lazily
//   allocated, dedicated workspace so concurrent streams never share scratch.
//
// NOTE(review): the `handle` parameter is unused — the call goes through the
// member cublaslt_handle_; kept for signature compatibility with callers.
void cublasMMWrapper::cublasLtGemm(cublasHandle_t        handle,
                                cublasOperation_t        transa,
                                cublasOperation_t        transb,
                                int                      m,
                                int                      n,
                                int                      k,
                                const void*              alpha, /* host or device pointer */
                                const void*              A,
                                const void*              A_scale,
                                cudaDataType             Atype,
                                int                      lda,
                                const void*              B,
                                const void*              B_scale,
                                cudaDataType             Btype,
                                int                      ldb,
                                const void*              beta, /* host or device pointer */
                                void*                    C,
                                cudaDataType             Ctype,
                                int                      ldc,
                                bool                     is_fp16_computeType,
                                cublasLtMatmulAlgo_info  info,
                                bool                     findAlgo,
                                int                      math_sm_count,
                                int8_t                   fast_accum,
                                cudaStream_t             stream) {
    cublasLtMatrixLayout_t Adesc;
    cublasLtMatrixLayout_t Bdesc;
    cublasLtMatrixLayout_t Cdesc;
    cublasLtMatmulDesc_t operationDesc;

    // cublasLt changed the compute-type enum in CUDA 11.
    cudaDataType_t         scaleType;
#if (CUDART_VERSION >= 11000)
    cublasComputeType_t computeType;
#else
    cudaDataType_t computeType;
#endif

    if (is_fp16_computeType) {
#if (CUDART_VERSION >= 11000)
        computeType = CUBLAS_COMPUTE_16F;
#else
        computeType = CUDA_R_16F;
#endif
        scaleType = CUDA_R_16F;
    } else {
#if (CUDART_VERSION >= 11000)
        computeType = CUBLAS_COMPUTE_32F;
#else
        computeType = CUDA_R_32F;
#endif
        scaleType = CUDA_R_32F;
    }

    // --------------------------------------
    // Create descriptors for the original matrices (cuBLAS is column-major:
    // rows/cols swap when the operand is transposed).
    check_cuda_error(cublasLtMatrixLayoutCreate(&Adesc, Atype, 
            transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda));
    FT_SCOPE_GUARD([&](){ cublasLtMatrixLayoutDestroy(Adesc); });
    check_cuda_error(cublasLtMatrixLayoutCreate(&Bdesc, Btype,
            transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb));
    FT_SCOPE_GUARD([&](){ cublasLtMatrixLayoutDestroy(Bdesc); });
    check_cuda_error(cublasLtMatrixLayoutCreate(&Cdesc, Ctype, m, n, ldc));
    FT_SCOPE_GUARD([&](){ cublasLtMatrixLayoutDestroy(Cdesc); });
    // Removed: a Ddesc layout was previously created here (with Btype) but was
    // never passed to the matmul — Cdesc serves as both C and D descriptor.
#if (CUDART_VERSION >= 11000)
    check_cuda_error(cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType));
#else
    check_cuda_error(cublasLtMatmulDescCreate(&operationDesc, computeType));
#endif    
    FT_SCOPE_GUARD([&](){ cublasLtMatmulDescDestroy(operationDesc); });

    if (math_sm_count > 0) {
        check_cuda_error(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET, &math_sm_count, sizeof(math_sm_count)));
    }

    check_cuda_error(cublasLtMatmulDescSetAttribute(operationDesc,
            CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(cublasOperation_t)));
    check_cuda_error(cublasLtMatmulDescSetAttribute(operationDesc,
            CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(cublasOperation_t)));

    check_cuda_error(cublasLtMatmulDescSetAttribute(operationDesc, 
            CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fast_accum, sizeof(int8_t)));
    if (A_scale != nullptr) {
        check_cuda_error(cublasLtMatmulDescSetAttribute(operationDesc,
                    CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &A_scale, sizeof(void*)));
    }
    if (B_scale != nullptr) {
        check_cuda_error(cublasLtMatmulDescSetAttribute(operationDesc,
                    CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &B_scale, sizeof(void*)));
    }

    // --------------------------------------
    // Workspace selection: the wrapper's default workspace for its own stream,
    // or a lazily allocated per-stream workspace otherwise.
    cublasLtMatmulAlgo_t algo;
    void*                workSpace     = cublas_workspace_;
    uint64_t             workspaceSize = cublas_workspace_ == NULL ? 0 : CUBLAS_WORKSPACE_SIZE;
    if (stream != stream_) {
        if (cublas_workspces_map_.count(stream) == 0) {
            void* additional_cublas_workspace = nullptr;
            additional_cublas_workspace = allocator_->reMalloc(additional_cublas_workspace, CUBLAS_WORKSPACE_SIZE);
            additional_cublas_workspaces_.push_back(additional_cublas_workspace);
            cublas_workspces_map_[stream] = additional_cublas_workspaces_.size() - 1;
        }
        workSpace = additional_cublas_workspaces_[cublas_workspces_map_[stream]];
        // BUGFIX: the per-stream workspace is always CUBLAS_WORKSPACE_SIZE
        // bytes; this previously still reported 0 whenever cublas_workspace_
        // was NULL, starving workspace-requiring algos on extra streams.
        workspaceSize = CUBLAS_WORKSPACE_SIZE;
        // BUGFIX: cudaStream_t is a pointer and the map index is size_t;
        // "%d" for both was a format-specifier mismatch (UB on LP64).
        RTP_LLM_LOG_DEBUG("stream %p, idx %zu", (void*)stream, (size_t)cublas_workspces_map_[stream]);
    }
    if (findAlgo) {
        if (info.workspaceSize > workspaceSize) {
            // Cached algo needs more scratch than we have: fall back to
            // the cublasLt heuristic (algo == NULL below).
            findAlgo = false;
        } else {
            // Replay the searched algo configuration attribute by attribute.
            check_cuda_error(cublasLtMatmulAlgoInit(
                cublaslt_handle_, computeType, scaleType, Atype, Btype, Ctype, Ctype, info.algoId, &algo));
            check_cuda_error(cublasLtMatmulAlgoConfigSetAttribute(
                &algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &(info.customOption), sizeof(info.customOption)));
            check_cuda_error(cublasLtMatmulAlgoConfigSetAttribute(
                &algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &(info.tile), sizeof(info.tile)));
            check_cuda_error(cublasLtMatmulAlgoConfigSetAttribute(
                &algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &(info.splitK_val), sizeof(info.splitK_val)));
            check_cuda_error(cublasLtMatmulAlgoConfigSetAttribute(
                &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(info.swizzle), sizeof(info.swizzle)));
            check_cuda_error(cublasLtMatmulAlgoConfigSetAttribute(&algo,
                                                    CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME,
                                                    &(info.reductionScheme),
                                                    sizeof(info.reductionScheme)));

#if (CUDART_VERSION >= 11000)
            check_cuda_error(cublasLtMatmulAlgoConfigSetAttribute(
                &algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(info.stages), sizeof(info.stages)));
#endif

// Shape-id attribute names differ across cuBLAS 11.11 patch levels.
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
            check_cuda_error(cublasLtMatmulAlgoConfigSetAttribute(
                &algo, CUBLASLT_ALGO_CONFIG_INNER_SHAPE_ID, &(info.inner_shapeId), sizeof(info.inner_shapeId)));
            check_cuda_error(cublasLtMatmulAlgoConfigSetAttribute(&algo,
                                                    CUBLASLT_ALGO_CONFIG_CLUSTER_SHAPE_ID,
                                                    &(info.cluster_shapeId),
                                                    sizeof(info.cluster_shapeId)));
#elif (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH < 3)
            check_cuda_error(cublasLtMatmulAlgoConfigSetAttribute(
                &algo, CUBLASLT_ALGO_CONFIG_MMA_SHAPE_ID, &(info.mma_shapeId), sizeof(info.mma_shapeId)));
            check_cuda_error(cublasLtMatmulAlgoConfigSetAttribute(
                &algo, CUBLASLT_ALGO_CONFIG_CGA_SHAPE_ID, &(info.cga_shapeId), sizeof(info.cga_shapeId)));
            check_cuda_error(cublasLtMatmulAlgoConfigSetAttribute(
                &algo, CUBLASLT_ALGO_CONFIG_SCHEDULING_MODE, &(info.sche_mode), sizeof(info.sche_mode)));
#endif
        }
    }

    // C doubles as the D (output) matrix — in-place epilogue layout.
    check_cuda_error(cublasLtMatmulWrapper(cublaslt_handle_,
                    operationDesc,
                    alpha,
                    A,
                    Adesc,
                    B,
                    Bdesc,
                    beta,
                    C,
                    Cdesc,
                    C,
                    Cdesc,
                    (findAlgo ? (&algo) : NULL),
                    workSpace,
                    workspaceSize,
                    stream,
                    /* find_best = */false));

    sync_check_cuda_error();
}