in maga_transformer/cpp/cutlass/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h [71:194]
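// Launches a CUTLASS grouped GEMM for the MoE expert FC layers: one GEMM problem per expert, with each expert's
// row count taken from total_tokens_including_expert. weight_scales / weight_zeros / group_size carry the
// group-wise weight-only quantization parameters, alpha_scale_ptr_array carries the scale pointers used by the
// FP8 epilogue, and kernel_occupancy (when non-null) makes the call only report occupancy and return without
// launching anything.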
void genericMoeGemmKernelLauncher(T const* A, WeightType const* B, GemmOutputType const* weight_scales,
    GemmOutputType const* weight_zeros, int group_size, GemmOutputType const* biases, bool bias_is_broadcast,
    GemmOutputType* C, int64_t const* total_tokens_including_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k,
    int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, int const multi_processor_count,
    bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream, int* kernel_occupancy = nullptr)
{
#if defined(ENABLE_FP8)
    static_assert(cutlass::platform::is_same<T, __nv_bfloat16>::value || cutlass::platform::is_same<T, half>::value
            || cutlass::platform::is_same<T, __nv_fp8_e4m3>::value
            || cutlass::platform::is_same<T, __nv_fp8_e5m2>::value || cutlass::platform::is_same<T, float>::value,
        "Specialized for fp8, bfloat16, half, float");
#elif defined(ENABLE_BF16)
    static_assert(cutlass::platform::is_same<T, __nv_bfloat16>::value || cutlass::platform::is_same<T, half>::value
            || cutlass::platform::is_same<T, float>::value,
        "Specialized for bfloat16, half, float");
#else
    static_assert(cutlass::platform::is_same<T, half>::value || cutlass::platform::is_same<T, float>::value,
        "Specialized for half, float");
#endif
    static_assert(cutlass::platform::is_same<T, WeightType>::value
            || cutlass::platform::is_same<WeightType, uint8_t>::value
            || cutlass::platform::is_same<WeightType, cutlass::uint4b_t>::value,
        "Weight type must match the activation type, or be uint8_t / uint4b_t for weight-only quantization");
    static_assert(!cutlass::platform::is_same<arch, cutlass::arch::Sm90>::value,
        "Sm90 architecture should use specialised kernels");
    // The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary.
    using ElementType = typename TllmToCutlassTypeAdapter<T>::type;
    using CutlassGemmOutputType = typename TllmToCutlassTypeAdapter<GemmOutputType>::type;
    using CutlassWeightType = typename TllmToCutlassTypeAdapter<WeightType>::type;
    if (!use_fused_moe)
    {
        // We need separate config for each architecture since we will target different tensorcore instructions. For
        // float, we do not target TCs.
        using MixedGemmArchTraits = cutlass::gemm::kernel::MixedGemmArchTraits<ElementType, CutlassWeightType, arch>;
        using ElementAccumulator = typename MixedGemmArchTraits::AccType;

        using EpilogueOp = typename tensorrt_llm::cutlass_extensions::Epilogue<CutlassGemmOutputType,
            MixedGemmArchTraits::ElementsPerAccessC, ElementAccumulator, EpilogueTag>::Op;
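        // Output scaling for the epilogue: 1.0 on the accumulator, and 1.0 on the bias source only when a bias
        // pointer is supplied (0.0 disables the bias term).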
        typename EpilogueOp::Params epilogue_op(
            ElementAccumulator(1.f), biases ? ElementAccumulator(1.f) : ElementAccumulator(0.f));

        using Operator = typename MixedGemmArchTraits::Operator;
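        // When the weights are stored in a different (quantized) type than the activations, tag the MMA operator
        // with QuantOp so the kernel picks the matching weight-only dequantization scheme.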
        using TaggedOperator = typename std::conditional<cutlass::platform::is_same<T, WeightType>::value,
            typename MixedGemmArchTraits::Operator,
            typename cutlass::arch::TagOperator<Operator, QuantOp>::TaggedOperator>::type;
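        // For FP8 inputs with the default epilogue, dequantization scaling comes from alpha_scale_ptr_array applied
        // in the epilogue rather than from weight_scales / biases.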
#if defined(ENABLE_FP8)
        if constexpr ((std::is_same_v<T, __nv_fp8_e4m3> || std::is_same_v<T, __nv_fp8_e5m2>)
            && std::is_same_v<EpilogueTag, cutlass_extensions::EpilogueOpDefault>)
        {
            TLLM_CHECK_WITH_INFO(weight_scales == nullptr && biases == nullptr && alpha_scale_ptr_array,
                "weight_scales and biases should be nullptr and alpha_scale_ptr_array shouldn't be nullptr for FP8 "
                "Ada");
            epilogue_op.alpha_ptr_array = alpha_scale_ptr_array;
        }
#endif
        // Finally, set up the kernel.
        using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemmGrouped<ElementType, cutlass::layout::RowMajor,
            cutlass::ComplexTransform::kNone, MixedGemmArchTraits::ElementsPerAccessA, CutlassWeightType,
            typename MixedGemmArchTraits::LayoutB, cutlass::ComplexTransform::kNone,
            MixedGemmArchTraits::ElementsPerAccessB, CutlassGemmOutputType, cutlass::layout::RowMajor,
            ElementAccumulator, typename MixedGemmArchTraits::OperatorClass, arch, ThreadblockShape, WarpShape,
            typename MixedGemmArchTraits::InstructionShape, EpilogueOp,
            cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, Stages,
            cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, TaggedOperator>::GemmKernel;

        using GemmKernel = cutlass::gemm::kernel::MoeFCGemm<typename GemmKernel_::Mma, typename GemmKernel_::Epilogue,
            typename GemmKernel_::ThreadblockSwizzle,
            arch, // Ensure top level arch is used for dispatch
            GemmKernel_::kGroupScheduleMode>;

        using GemmGrouped = cutlass::gemm::device::GemmGrouped<GemmKernel>;
        if (kernel_occupancy != nullptr)
        {
            *kernel_occupancy = tensorrt_llm::cutlass_extensions::compute_occupancy_for_kernel<GemmKernel>();
            return;
        }
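        // The grouped GEMM launches a fixed grid of persistent threadblocks (at most 2 per SM) that the scheduler
        // farms the per-expert problems out to.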
        int occupancy = std::min(2, GemmGrouped::maximum_active_blocks());
        TLLM_CHECK_WITH_INFO(occupancy > 0, "GPU lacks the shared memory resources to run GroupedGEMM kernel");
        int const threadblock_count = multi_processor_count * occupancy;
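        // No explicit per-problem shapes are passed: each expert's M is derived from total_tokens_including_expert,
        // while N = gemm_n and K = gemm_k are shared by all experts.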
        typename GemmGrouped::Arguments args(num_experts, threadblock_count, group_size, epilogue_op,
            reinterpret_cast<ElementType const*>(A), reinterpret_cast<CutlassWeightType const*>(B),
            reinterpret_cast<CutlassGemmOutputType const*>(weight_scales),
            reinterpret_cast<CutlassGemmOutputType const*>(weight_zeros),
            reinterpret_cast<CutlassGemmOutputType const*>(biases), bias_is_broadcast,
            reinterpret_cast<CutlassGemmOutputType*>(C), total_tokens_including_expert, gemm_n, gemm_k);

        GemmGrouped gemm;
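        // Fail fast with the CUTLASS status string if this configuration is unsupported, then initialize and launch
        // on the caller-provided stream.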
        auto can_implement = gemm.can_implement(args);
        TLLM_CHECK_WITH_INFO(can_implement == cutlass::Status::kSuccess,
            "MoE FC kernel will fail for params. Error: " + std::string(cutlassGetStatusString(can_implement)));

        auto init_status = gemm.initialize(args);
        TLLM_CHECK_WITH_INFO(init_status == cutlass::Status::kSuccess,
            "Failed to initialize cutlass grouped gemm. Error: " + std::string(cutlassGetStatusString(init_status)));

        auto run_status = gemm.run(stream);
        TLLM_CHECK_WITH_INFO(run_status == cutlass::Status::kSuccess,
            "Failed to run cutlass grouped gemm. Error: " + std::string(cutlassGetStatusString(run_status)));
    }
    else if constexpr (sizeof(ElementType) == 2 && sizeof(CutlassWeightType) == 2
        && (std::is_same_v<EpilogueTag, cutlass_extensions::EpilogueOpDefaultSilu>
            || std::is_same_v<EpilogueTag, cutlass_extensions::EpilogueOpDefaultFtGelu>))
    {
        // Use the fused MoE GEMM kernel (only fp16 and bf16 are supported).
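        // This path dispatches to a dedicated SM80 grouped GEMM whose epilogue applies the SiLU / FtGelu activation
        // selected by EpilogueTag.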
        sm80_generic_fused_moe_gemm_kernelLauncher<ElementType, CutlassWeightType, ThreadblockShape::kM,
            ThreadblockShape::kN, ThreadblockShape::kK, Stages, EpilogueTag>(reinterpret_cast<ElementType const*>(A),
            reinterpret_cast<CutlassWeightType const*>(B), reinterpret_cast<ElementType const*>(biases),
            bias_is_broadcast, reinterpret_cast<ElementType*>(C), total_tokens_including_expert, num_rows, gemm_n,
            gemm_k, num_experts, multi_processor_count, stream, kernel_occupancy);
    }
    else
    {
        TLLM_LOG_ERROR("NOT IMPLEMENTED YET");
    }
}