void genericMoeGemmKernelLauncher()

in maga_transformer/cpp/cutlass/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h [71:194]


// Launches the grouped (one problem per expert) MoE GEMM for a single fixed template configuration.
//
// Template parameters come from the enclosing template header (outside this view): T, WeightType, GemmOutputType,
// arch, QuantOp, EpilogueTag, ThreadblockShape, WarpShape, Stages.
//
// Two code paths, selected at runtime by `use_fused_moe`:
//   * use_fused_moe == false: a CUTLASS grouped GEMM (MoeFCGemm) over `num_experts` problems. The grid is
//     multi_processor_count * min(2, GemmGrouped::maximum_active_blocks()) threadblocks.
//   * use_fused_moe == true: forwards to sm80_generic_fused_moe_gemm_kernelLauncher. That path is only compiled in
//     for 16-bit activation/weight element types with a SiLU or FtGelu epilogue. Any other combination is rejected
//     (see the final else branch).
//
// Parameters:
//   A, B                           activation and weight pointers (presumably device memory — confirm with caller).
//   weight_scales, weight_zeros,   forwarded to the grouped GEMM arguments (non-fused path only).
//   group_size
//   biases, bias_is_broadcast      optional bias. When `biases` is non-null the epilogue beta is 1, otherwise 0.
//   C                              GEMM output.
//   total_tokens_including_expert  per-expert row bookkeeping forwarded to both kernels. Presumably an inclusive
//                                  running total of tokens per expert — TODO confirm.
//   num_rows                       only consumed by the fused path.
//   gemm_n, gemm_k                 GEMM N and K dimensions.
//   num_experts                    number of grouped problems.
//   gemm_config                    not read in this function (tiling is fixed by the template parameters).
//   multi_processor_count          SM count used to size the grid.
//   use_fused_moe                  selects the fused kernel path described above.
//   alpha_scale_ptr_array          only used for FP8 inputs with EpilogueOpDefault, where it must be non-null.
//   stream                         stream the kernel is launched on.
//   kernel_occupancy               when non-null, the non-fused path only writes the occupancy and returns without
//                                  launching. The fused path forwards the pointer to its own launcher.
//
// Failures (CUTLASS can_implement/initialize/run errors, a missing SM resource, or an unsupported fused-path
// request) are reported through TLLM_CHECK_WITH_INFO.
void genericMoeGemmKernelLauncher(T const* A, WeightType const* B, GemmOutputType const* weight_scales,
    GemmOutputType const* weight_zeros, int group_size, GemmOutputType const* biases, bool bias_is_broadcast,
    GemmOutputType* C, int64_t const* total_tokens_including_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k,
    int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, int const multi_processor_count,
    bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream, int* kernel_occupancy = nullptr)
{
#if defined(ENABLE_FP8)
    static_assert(cutlass::platform::is_same<T, __nv_bfloat16>::value || cutlass::platform::is_same<T, half>::value
            || cutlass::platform::is_same<T, __nv_fp8_e4m3>::value
            || cutlass::platform::is_same<T, __nv_fp8_e5m2>::value || cutlass::platform::is_same<T, float>::value,
        "Specialized for fp8, bfloat16, half, float");
#elif defined(ENABLE_BF16)
    static_assert(cutlass::platform::is_same<T, __nv_bfloat16>::value || cutlass::platform::is_same<T, half>::value
            || cutlass::platform::is_same<T, float>::value,
        "Specialized for bfloat16, half, float");
#else
    static_assert(cutlass::platform::is_same<T, half>::value || cutlass::platform::is_same<T, float>::value,
        "Specialized for half, float");
#endif

    // Weights either share the activation type or use one of the narrow integer storage types; the latter selects
    // the QuantOp-tagged mainloop operator below.
    static_assert(cutlass::platform::is_same<T, WeightType>::value
            || cutlass::platform::is_same<WeightType, uint8_t>::value
            || cutlass::platform::is_same<WeightType, cutlass::uint4b_t>::value,
        "WeightType must equal T, or be uint8_t or cutlass::uint4b_t");

    static_assert(!cutlass::platform::is_same<arch, cutlass::arch::Sm90>::value,
        "Sm90 architecture should use specialised kernels");

    // The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary.
    using ElementType = typename TllmToCutlassTypeAdapter<T>::type;
    using CutlassGemmOutputType = typename TllmToCutlassTypeAdapter<GemmOutputType>::type;
    using CutlassWeightType = typename TllmToCutlassTypeAdapter<WeightType>::type;
    if (!use_fused_moe)
    {
        // We need separate config for each architecture since we will target different tensorcore instructions. For
        // float, we do not target TCs.
        using MixedGemmArchTraits = cutlass::gemm::kernel::MixedGemmArchTraits<ElementType, CutlassWeightType, arch>;
        using ElementAccumulator = typename MixedGemmArchTraits::AccType;

        using EpilogueOp = typename tensorrt_llm::cutlass_extensions::Epilogue<CutlassGemmOutputType,
            MixedGemmArchTraits::ElementsPerAccessC, ElementAccumulator, EpilogueTag>::Op;

        // alpha = 1; beta = 1 only when a bias is supplied, so the bias term is ignored otherwise.
        typename EpilogueOp::Params epilogue_op(
            ElementAccumulator(1.f), biases ? ElementAccumulator(1.f) : ElementAccumulator(0.f));

        // Same-type GEMMs use the plain operator; mixed-type GEMMs tag it with QuantOp.
        using Operator = typename MixedGemmArchTraits::Operator;
        using TaggedOperator = typename std::conditional<cutlass::platform::is_same<T, WeightType>::value,
            typename MixedGemmArchTraits::Operator,
            typename cutlass::arch::TagOperator<Operator, QuantOp>::TaggedOperator>::type;

#if defined(ENABLE_FP8)
        if constexpr ((std::is_same_v<T, __nv_fp8_e4m3> || std::is_same_v<T, __nv_fp8_e5m2>)
            && std::is_same_v<EpilogueTag, cutlass_extensions::EpilogueOpDefault>)
        {
            TLLM_CHECK_WITH_INFO(weight_scales == nullptr && biases == nullptr && alpha_scale_ptr_array,
                "weight_scales and biases should be nullptr and alpha_scale_ptr_array shouldn't be nullptr for FP8 "
                "Ada");
            epilogue_op.alpha_ptr_array = alpha_scale_ptr_array;
        }
#endif

        // Finally, set up the kernel.
        using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemmGrouped<ElementType, cutlass::layout::RowMajor,
            cutlass::ComplexTransform::kNone, MixedGemmArchTraits::ElementsPerAccessA, CutlassWeightType,
            typename MixedGemmArchTraits::LayoutB, cutlass::ComplexTransform::kNone,
            MixedGemmArchTraits::ElementsPerAccessB, CutlassGemmOutputType, cutlass::layout::RowMajor,
            ElementAccumulator, typename MixedGemmArchTraits::OperatorClass, arch, ThreadblockShape, WarpShape,
            typename MixedGemmArchTraits::InstructionShape, EpilogueOp,
            cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, Stages,
            cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, TaggedOperator>::GemmKernel;

        using GemmKernel = cutlass::gemm::kernel::MoeFCGemm<typename GemmKernel_::Mma, typename GemmKernel_::Epilogue,
            typename GemmKernel_::ThreadblockSwizzle,
            arch, // Ensure top level arch is used for dispatch
            GemmKernel_::kGroupScheduleMode>;

        using GemmGrouped = cutlass::gemm::device::GemmGrouped<GemmKernel>;

        // Occupancy query only: report and return without launching anything.
        if (kernel_occupancy != nullptr)
        {
            *kernel_occupancy = tensorrt_llm::cutlass_extensions::compute_occupancy_for_kernel<GemmKernel>();
            return;
        }
        // Cap residency at 2 blocks per SM; a non-positive value means the kernel cannot be resident at all.
        int occupancy = std::min(2, GemmGrouped::maximum_active_blocks());
        TLLM_CHECK_WITH_INFO(occupancy > 0, "GPU lacks the shared memory resources to run GroupedGEMM kernel");
        int const threadblock_count = multi_processor_count * occupancy;

        typename GemmGrouped::Arguments args(num_experts, threadblock_count, group_size, epilogue_op,
            reinterpret_cast<ElementType const*>(A), reinterpret_cast<CutlassWeightType const*>(B),
            reinterpret_cast<CutlassGemmOutputType const*>(weight_scales),
            reinterpret_cast<CutlassGemmOutputType const*>(weight_zeros),
            reinterpret_cast<CutlassGemmOutputType const*>(biases), bias_is_broadcast,
            reinterpret_cast<CutlassGemmOutputType*>(C), total_tokens_including_expert, gemm_n, gemm_k);

        GemmGrouped gemm;

        auto can_implement = gemm.can_implement(args);
        TLLM_CHECK_WITH_INFO(can_implement == cutlass::Status::kSuccess,
            "MoE FC kernel will fail for params. Error: " + std::string(cutlassGetStatusString(can_implement)));

        auto init_status = gemm.initialize(args);
        TLLM_CHECK_WITH_INFO(init_status == cutlass::Status::kSuccess,
            "Failed to initialize cutlass grouped gemm. Error: " + std::string(cutlassGetStatusString(init_status)));

        auto run_status = gemm.run(stream);
        TLLM_CHECK_WITH_INFO(run_status == cutlass::Status::kSuccess,
            "Failed to run cutlass grouped gemm. Error: " + std::string(cutlassGetStatusString(run_status)));
    }
    else if constexpr (sizeof(ElementType) == 2 && sizeof(CutlassWeightType) == 2
        && (std::is_same_v<EpilogueTag, cutlass_extensions::EpilogueOpDefaultSilu>
            || std::is_same_v<EpilogueTag, cutlass_extensions::EpilogueOpDefaultFtGelu>) )
    {
        // Fused MoE GEMM kernel: only 16-bit (fp16 / bf16) activations and weights with a SiLU or FtGelu epilogue.
        sm80_generic_fused_moe_gemm_kernelLauncher<ElementType, CutlassWeightType, ThreadblockShape::kM,
            ThreadblockShape::kN, ThreadblockShape::kK, Stages, EpilogueTag>(reinterpret_cast<ElementType const*>(A),
            reinterpret_cast<CutlassWeightType const*>(B), reinterpret_cast<ElementType const*>(biases),
            bias_is_broadcast, reinterpret_cast<ElementType*>(C), total_tokens_including_expert, num_rows, gemm_n,
            gemm_k, num_experts, multi_processor_count, stream, kernel_occupancy);
    }
    else
    {
        // The caller requested the fused kernel, but this instantiation does not meet its compile-time requirements.
        // Merely logging here would leave C unwritten (and *kernel_occupancy unset) with no error signal, so fail
        // explicitly instead.
        if (kernel_occupancy != nullptr)
        {
            // Report "cannot run" instead of leaving the occupancy output uninitialized.
            // NOTE(review): assumes the config heuristic treats occupancy 0 as an unusable config — confirm.
            *kernel_occupancy = 0;
            return;
        }
        TLLM_CHECK_WITH_INFO(false,
            "Fused MoE GEMM requested, but it is only implemented for 16-bit (fp16/bf16) activation and weight types "
            "with a SiLU or FtGelu epilogue");
    }
}