in maga_transformer/cpp/cutlass/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h [741:871]
void MoeGemmRunner<T, WeightType, QuantOp, OutputType, ScaleBiasType>::dispatchToArch<EpilogueTag>(T const* A,
WeightType const* B, ScaleBiasType const* weight_scales, ScaleBiasType const* weight_zeros, int group_size,
ScaleBiasType const* biases, bool bias_is_broadcast, void* C_void, int64_t const* total_tokens_including_expert,
HopperGroupedGemmInput hopper_input, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
cutlass_extensions::CutlassGemmConfig gemm_config, bool use_fused_moe, float const** alpha_scale_ptr_array,
cudaStream_t stream, int* occupancy)
{
static_assert(std::is_same_v<ScaleBiasType, OutputType>,
"Separate Scale/Bias type is not supported. This is assumed to be the gemm output type");
// For now we always cast this to output type.
// In the future this will vary based on what fusions are applied for FP8
auto* C = reinterpret_cast<OutputType*>(C_void);
TLLM_CHECK_WITH_INFO(
sm_ >= 89 || !hopper_input.isValid(), "Hopper input information is set for a non-specialised implementation");
TLLM_CHECK_WITH_INFO(
sm_ == 90 || !gemm_config.is_sm90, "Hopper configuration provided for non-Hopper architecture");
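// SM70-74 (Volta): dispatch to the Sm70 CUTLASS kernels; bf16 is not supported on this architecture.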
if (sm_ >= 70 && sm_ < 75)
{
if constexpr (cutlass::platform::is_same<T, __nv_bfloat16>::value)
{
TLLM_THROW("sm70 no support bf16 moe");
}
else
{
dispatchMoeGemmToCutlass<T, WeightType, QuantOp, ScaleBiasType, cutlass::arch::Sm70, EpilogueTag>(A, B,
weight_scales, weight_zeros, group_size, biases, bias_is_broadcast, C, total_tokens_including_expert,
total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count_, use_fused_moe,
alpha_scale_ptr_array, stream, occupancy);
}
}
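// SM75-79 (Turing): dispatch to the Sm75 CUTLASS kernels; bf16 is likewise unsupported.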
else if (sm_ >= 75 && sm_ < 80)
{
if constexpr (cutlass::platform::is_same<T, __nv_bfloat16>::value)
{
TLLM_THROW("sm75 no support bf16 moe");
}
else
{
dispatchMoeGemmToCutlass<T, WeightType, QuantOp, ScaleBiasType, cutlass::arch::Sm75, EpilogueTag>(A, B,
weight_scales, weight_zeros, group_size, biases, bias_is_broadcast, C, total_tokens_including_expert,
total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count_, use_fused_moe,
alpha_scale_ptr_array, stream, occupancy);
}
}
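// SM80-89 (Ampere/Ada): FP8 is only permitted on SM89 (Ada); all other cases use the Sm80 kernels.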
else if (sm_ >= 80 && sm_ < 90)
{
if constexpr (use_fp8)
{
#if defined(ENABLE_FP8)
static_assert(!std::is_same_v<OutputType, __nv_fp8_e4m3> && !std::is_same_v<OutputType, __nv_fp8_e5m2>,
"FP8 GEMM Output not supported");
#endif
TLLM_CHECK_WITH_INFO(sm_ == 89, "For sm >= 80 and < 90, fp8 is only supported with sm == 89");
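// NOTE: the SM89 FP8 dispatch below is currently commented out, so this branch only performs the
// checks above and does not launch a kernel.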
// dispatchMoeGemmToCutlass<T, WeightType, QuantOp, ScaleBiasType, cutlass::arch::Sm89, EpilogueTag>(A, B,
// weight_scales, biases, bias_is_broadcast, C, total_tokens_including_expert, total_rows, gemm_n,
// gemm_k, num_experts, gemm_config, multi_processor_count_, use_fused_moe, alpha_scale_ptr_array,
// stream, occupancy);
}
else
{
dispatchMoeGemmToCutlass<T, WeightType, QuantOp, ScaleBiasType, cutlass::arch::Sm80, EpilogueTag>(A, B,
weight_scales, weight_zeros, group_size, biases, bias_is_broadcast, C, total_tokens_including_expert,
total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count_, use_fused_moe,
alpha_scale_ptr_array, stream, occupancy);
}
}
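// SM90+ (Hopper): prefer the specialised grouped GEMM when this type combination supports it,
// otherwise fall back to the SM80 kernels below.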
else if (sm_ >= 90)
{
if constexpr (kernels::cutlass_kernels::isValidHopperMOESpecialisation<T, WeightType, EpilogueTag>())
{
// We allow both SM90 and SM80 configurations to coexist because, for some cases with small numbers of
// tokens, SM80 is faster. Check here which one was selected.
if (gemm_config.is_sm90)
{
TLLM_CHECK_WITH_INFO(biases != nullptr || hopper_input.ptr_c == nullptr,
"Hopper input has a bias pointer set, but no input biases were provided");
TLLM_CHECK_WITH_INFO(
occupancy || hopper_input.isValid(), "Calling SM90 configuration with an invalid hopper config");
// Select the appropriate fusion function
auto select_function = [&]()
{
switch (hopper_input.fusion)
{
case HopperGroupedGemmInput::EpilogueFusion::FINALIZE:
return &dispatchMoeGemmSelectTileShapeSM90<T, WeightType, OutputType, EpilogueTag,
HopperGroupedGemmInput::EpilogueFusion::FINALIZE>;
case HopperGroupedGemmInput::EpilogueFusion::NONE:
return &dispatchMoeGemmSelectTileShapeSM90<T, WeightType, OutputType, EpilogueTag,
HopperGroupedGemmInput::EpilogueFusion::NONE>;
case HopperGroupedGemmInput::EpilogueFusion::ACTIVATION:
case HopperGroupedGemmInput::EpilogueFusion::GATED_ACTIVATION:
default: TLLM_THROW("Unimplemented fusion %d requested", (int) hopper_input.fusion);
}
};
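// Resolve and invoke the SM90 tile-shape dispatcher for the chosen epilogue fusion.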
auto selected_func = select_function();
selected_func(
hopper_input, num_experts, gemm_config, multi_processor_count_, stream, occupancy, nullptr);
return;
}
// Fallthrough to SM80 impl below
}
// Do Ampere case instead
if constexpr (kernels::cutlass_kernels::isValidAmpereMOESpecialisation<T, WeightType, EpilogueTag>())
{
TLLM_CHECK_WITH_INFO(!hopper_input.isValid(),
"The non-specialised Hopper path is rerouted to the fallback implementation, so Hopper input "
"information must not be set");
TLLM_CHECK_WITH_INFO(!gemm_config.is_sm90,
"GEMM config selects the SM90 configuration, but this type combination is not valid for Hopper");
dispatchMoeGemmToCutlass<T, WeightType, QuantOp, ScaleBiasType, cutlass::arch::Sm80, EpilogueTag>(A, B,
weight_scales, weight_zeros, group_size, biases, bias_is_broadcast, C, total_tokens_including_expert,
total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count_, use_fused_moe,
alpha_scale_ptr_array, stream, occupancy);
}
else
{
TLLM_THROW("Configuration expects SM80 but configuration is not supported by SM80 kernels");
}
}
else
{
TLLM_THROW("Arch unsupported for MoE GEMM");
}
}