void MoeGemmRunner::dispatchToArch()

in maga_transformer/cpp/cutlass/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h [741:871]


void MoeGemmRunner<T, WeightType, QuantOp, OutputType, ScaleBiasType>::dispatchToArch<EpilogueTag>(T const* A,
    WeightType const* B, ScaleBiasType const* weight_scales, ScaleBiasType const* weight_zeros, int group_size,
    ScaleBiasType const* biases, bool bias_is_broadcast, void* C_void, int64_t const* total_tokens_including_expert,
    HopperGroupedGemmInput hopper_input, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
    cutlass_extensions::CutlassGemmConfig gemm_config, bool use_fused_moe, float const** alpha_scale_ptr_array,
    cudaStream_t stream, int* occupancy)
{
    static_assert(std::is_same_v<ScaleBiasType, OutputType>,
        "Separate Scale/Bias type is not supported. This is assumed to be the gemm output type");

    // For now we always cast this to output type.
    // In the future this will vary based on what fusions are applied for FP8
    auto* C = reinterpret_cast<OutputType*>(C_void);

    TLLM_CHECK_WITH_INFO(
        sm_ >= 89 || !hopper_input.isValid(), "Hopper input information is set for the non-specialised implementation");
    TLLM_CHECK_WITH_INFO(
        sm_ == 90 || !gemm_config.is_sm90, "Hopper configuration provided for non-Hopper architecture");

    if (sm_ >= 70 && sm_ < 75)
    {
        if constexpr (cutlass::platform::is_same<T, __nv_bfloat16>::value)
        {
            TLLM_THROW("sm70 no support bf16 moe");
        }
        else
        {
            dispatchMoeGemmToCutlass<T, WeightType, QuantOp, ScaleBiasType, cutlass::arch::Sm70, EpilogueTag>(A, B,
                weight_scales, weight_zeros, group_size, biases, bias_is_broadcast, C, total_tokens_including_expert,
                total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count_, use_fused_moe,
                alpha_scale_ptr_array, stream, occupancy);
        }
    }
    else if (sm_ >= 75 && sm_ < 80)
    {
        if constexpr (cutlass::platform::is_same<T, __nv_bfloat16>::value)
        {
            TLLM_THROW("sm75 no support bf16 moe");
        }
        else
        {
            dispatchMoeGemmToCutlass<T, WeightType, QuantOp, ScaleBiasType, cutlass::arch::Sm75, EpilogueTag>(A, B,
                weight_scales, weight_zeros, group_size, biases, bias_is_broadcast, C, total_tokens_including_expert,
                total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count_, use_fused_moe,
                alpha_scale_ptr_array, stream, occupancy);
        }
    }
    else if (sm_ >= 80 && sm_ < 90)
    {
        if constexpr (use_fp8)
        {
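            // FP8 here is only available on SM89 (Ada); the dedicated SM89 dispatch below is currently commented out.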
#if defined(ENABLE_FP8)
            static_assert(!std::is_same_v<OutputType, __nv_fp8_e4m3> && !std::is_same_v<OutputType, __nv_fp8_e5m2>,
                "FP8 GEMM Output not supported");
#endif

            TLLM_CHECK_WITH_INFO(sm_ == 89, "For sm >= 80 and < 90, fp8 is only supported with sm == 89");
            // dispatchMoeGemmToCutlass<T, WeightType, QuantOp, ScaleBiasType, cutlass::arch::Sm89, EpilogueTag>(A, B,
            //     weight_scales, biases, bias_is_broadcast, C, total_tokens_including_expert, total_rows, gemm_n,
            //     gemm_k, num_experts, gemm_config, multi_processor_count_, use_fused_moe, alpha_scale_ptr_array,
            //     stream, occupancy);
        }
        else
        {
            dispatchMoeGemmToCutlass<T, WeightType, QuantOp, ScaleBiasType, cutlass::arch::Sm80, EpilogueTag>(A, B,
                weight_scales, weight_zeros, group_size, biases, bias_is_broadcast, C, total_tokens_including_expert,
                total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count_, use_fused_moe,
                alpha_scale_ptr_array, stream, occupancy);
        }
    }
    else if (sm_ >= 90)
    {
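        // Prefer the Hopper grouped GEMM path when this type combination has a valid SM90 specialisation.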
        if constexpr (kernels::cutlass_kernels::isValidHopperMOESpecialisation<T, WeightType, EpilogueTag>())
        {

            // We allow both SM90 and SM80 configurations to coexist because for some cases with small numbers of tokens
            // SM80 is faster. We check here to see which is selected
            if (gemm_config.is_sm90)
            {
                TLLM_CHECK_WITH_INFO(biases != nullptr || hopper_input.ptr_c == nullptr,
                    "Input biases and hopper_input.ptr_c disagree on whether bias is enabled");
                TLLM_CHECK_WITH_INFO(
                    occupancy || hopper_input.isValid(), "Calling SM90 configuration with an invalid Hopper config");

                // Select the appropriate fusion function
                auto select_function = [&]()
                {
                    switch (hopper_input.fusion)
                    {
                    case HopperGroupedGemmInput::EpilogueFusion::FINALIZE:
                        return &dispatchMoeGemmSelectTileShapeSM90<T, WeightType, OutputType, EpilogueTag,
                            HopperGroupedGemmInput::EpilogueFusion::FINALIZE>;
                    case HopperGroupedGemmInput::EpilogueFusion::NONE:
                        return &dispatchMoeGemmSelectTileShapeSM90<T, WeightType, OutputType, EpilogueTag,
                            HopperGroupedGemmInput::EpilogueFusion::NONE>;
                    case HopperGroupedGemmInput::EpilogueFusion::ACTIVATION:
                    case HopperGroupedGemmInput::EpilogueFusion::GATED_ACTIVATION:
                    default: TLLM_THROW("Unimplemented fusion %d requested", (int) hopper_input.fusion);
                    }
                };
                auto selected_func = select_function();
                selected_func(
                    hopper_input, num_experts, gemm_config, multi_processor_count_, stream, occupancy, nullptr);
                return;
            }

            // Fallthrough to SM80 impl below
        }

        // Do Ampere case instead
        if constexpr (kernels::cutlass_kernels::isValidAmpereMOESpecialisation<T, WeightType, EpilogueTag>())
        {
            TLLM_CHECK_WITH_INFO(!hopper_input.isValid(),
                "Non-specialised Hopper implementation is being rerouted to the fallback implementation, so Hopper "
                "input information must not be set");
            TLLM_CHECK_WITH_INFO(!gemm_config.is_sm90,
                "GEMM config is for an SM90 configuration, but this configuration is not valid for Hopper");
            dispatchMoeGemmToCutlass<T, WeightType, QuantOp, ScaleBiasType, cutlass::arch::Sm80, EpilogueTag>(A, B,
                weight_scales, weight_zeros, group_size, biases, bias_is_broadcast, C, total_tokens_including_expert,
                total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count_, use_fused_moe,
                alpha_scale_ptr_array, stream, occupancy);
        }
        else
        {
            TLLM_THROW("Configuration expects SM80 but configuration is not supported by SM80 kernels");
        }
    }
    else
    {
        TLLM_THROW("Arch unsupported for MoE GEMM");
    }
}
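
The SM90 branch selects its kernel by returning a function pointer from a switch inside a small lambda
(select_function). Below is a minimal standalone sketch of that pattern, using hypothetical stand-ins (Fusion,
run_none, run_finalize) in place of the real EpilogueFusion enum and template instantiations; it is an illustration,
not the library's API.

#include <cstdio>
#include <stdexcept>

// Hypothetical stand-in for HopperGroupedGemmInput::EpilogueFusion.
enum class Fusion { NONE, FINALIZE, ACTIVATION };

// Hypothetical stand-ins for the fusion-specific kernel entry points.
void run_none(int num_experts)     { std::printf("NONE epilogue, %d experts\n", num_experts); }
void run_finalize(int num_experts) { std::printf("FINALIZE epilogue, %d experts\n", num_experts); }

void dispatch(Fusion fusion, int num_experts)
{
    // Map the runtime enum value to a pointer to the matching function,
    // mirroring select_function() in dispatchToArch() above.
    auto select_function = [&]()
    {
        switch (fusion)
        {
        case Fusion::FINALIZE: return &run_finalize;
        case Fusion::NONE: return &run_none;
        default: throw std::runtime_error("Unimplemented fusion requested");
        }
    };
    auto selected_func = select_function();
    selected_func(num_experts);
}

int main()
{
    dispatch(Fusion::NONE, 8);
    dispatch(Fusion::FINALIZE, 8);
    return 0;
}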