void int4Gemm_impl()

in maga_transformer/cpp/rocm/int4_gemm_kernels/int4_dequant_comm.h [123:201]


/// @brief Configure and launch the int4 GEMM device instance described by `params`.
///
/// Reads the problem sizes (M, N, K), the strides, the quantization group size
/// and the input/output/scale pointers from `params`, picks a KBatch value from
/// a small table of tuned shapes, builds the kernel argument and runs it on
/// `params.stream`.
///
/// If the device instance reports the argument as unsupported, a message is
/// written to std::cerr and the function returns without launching anything.
///
/// @param params Problem description: sizes, strides, buffers and stream.
void int4Gemm_impl(const ckGemmParam& params) {
    // Get input information.
    const auto M         = params.M;
    const auto N         = params.N;
    const auto K         = params.K;
    const auto GroupSize = params.Group_size;
    const auto StrideA   = params.StrideA;
    const auto StrideB   = params.StrideB;
    const auto StrideC   = params.StrideC;
    // Number of scale groups along K, rounded up (ceil(K / GroupSize)).
    // NOTE(review): GroupSize == 0 would divide by zero here — presumably
    // callers always pass a positive group size; confirm.
    const auto StrideScaleB = (K + GroupSize - 1) / GroupSize;

    // KBatch tuning: shapes not listed below keep the default of 1.
    // NOTE(review): the values look like empirically tuned split-K factors for
    // specific (M, N, K) shapes — confirm against the tuning data.
    auto KBatch = 1;
    if (N == 29696 && K == 8192) {
        if (M == 1 || M == 16 || M == 48 || M == 64) {
            KBatch = 2;
        } else if (M == 32) {
            KBatch = 4;
        }
    } else if (N == 8192 && K == 29696) {
        if (M == 64) {
            KBatch = 2;
        } else if (M == 16 || M == 48) {
            KBatch = 4;
        } else if (M == 1 || M == 32 || M == 80 || M == 96 || M == 112 || M == 128) {
            KBatch = 8;
        }
    } else if (N == 10240 && K == 8192) {
        if (M == 16 || M == 80 || M == 112 || M == 128) {
            KBatch = 2;
        } else if (M == 1 || M == 32) {
            KBatch = 4;
        }
    } else if (N == 8192 && K == 8192) {
        if (M == 112 || M == 128) {
            KBatch = 2;
        } else if (M == 1 || M == 16) {
            KBatch = 4;
        }
    }

    // Create gemm launcher and arguments.
    auto gemm    = DeviceInt4GemmInstance{};
    auto invoker = gemm.MakeInvoker();

    auto a_element_op = AElementOp{};
    auto b_element_op = BElementOp{};
    auto c_element_op = CElementOp{};

    auto argument = gemm.MakeArgument(static_cast<ADataType*>(params.A_input),
                                      static_cast<BDataType*>(params.B_input),
                                      static_cast<CDataType*>(params.C_input),
                                      M,
                                      N,
                                      K,
                                      StrideA,
                                      StrideB,
                                      StrideC,
                                      StrideScaleB,
                                      static_cast<BScaleDataType*>(params.B_scales_input),
                                      KBatch,
                                      a_element_op,
                                      b_element_op,
                                      c_element_op);

    if (!gemm.IsSupportedArgument(argument)) {
        std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;

        return;
    }

    // The workspace owner must outlive invoker.Run(): `argument` only stores
    // the pointer handed to SetWorkSpacePointer, so declaring the buffer inside
    // the `if` would destroy it before the kernel is launched.
    // NOTE(review): assumes DeviceMem accepts a zero-byte size and that its
    // destructor releases the allocation — confirm against the DeviceMem
    // implementation in use.
    const std::size_t workspace_size = gemm.GetWorkSpaceSize(&argument);
    DeviceMem         gemm_desc_workspace(workspace_size);
    if (workspace_size != 0) {
        gemm.SetWorkSpacePointer(&argument, gemm_desc_workspace.GetDeviceBuffer());
    }

    invoker.Run(argument, StreamConfig{params.stream, false});
}