in maga_transformer/cpp/cutlass/cutlass_kernels/moe_gemm/moe_kernels.inl [1054:1232]
void CutlassMoeFCRunner<T, WeightType, QuantOp, OutputType, ScaleBiasType, Enable>::runMoe(
    void const* input_activations_void,
    float const* gating_output,
    float const* gating_output_with_bias,
    void const* fc1_expert_weights_void,
    void const* fc1_expert_biases_void,
    ActivationType fc1_activation_type,
    void const* fc2_expert_weights_void,
    void const* fc2_expert_biases_void,
    QuantParams quant_params,
    int64_t const num_rows,
    int64_t const hidden_size,
    int64_t const inter_size,
    int const num_experts,
    int const k,
    char* workspace_ptr,
    void* final_output_void,
    bool const* finished,
    int64_t const active_rows,
    void* token_topk_final_scales_void,
    int* expanded_source_row_to_expanded_dest_row,
    int* expert_for_source_row,
    float sparse_mixer_epsilon,
    MOEParallelismConfig parallelism_config,
    MOEExpertScaleNormalizationMode normalization_mode,
    bool use_lora,
    LoraParams& lora_params,
    cudaStream_t stream) {
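    // Compile-time classification of the weight type: integer (uint8/uint4) weights require
    // weight scales, while FP8 (e4m3/e5m2) weights require quant/dequant scale factors instead.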
    static constexpr bool int_scales_required
        = std::is_same<WeightType, uint8_t>::value || std::is_same<WeightType, cutlass::uint4b_t>::value;
    static constexpr bool fp8_scales_required
        = std::is_same<WeightType, __nv_fp8_e4m3>::value || std::is_same<WeightType, __nv_fp8_e5m2>::value;
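    // Recover typed views of the type-erased activation/weight/bias/output pointers and unpack the
    // quantization parameters (weight scales and zeros, group size, FP8 quant/dequant factors).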
    auto const* input_activations = static_cast<T const*>(input_activations_void);
    auto const* fc1_expert_weights = static_cast<WeightType const*>(fc1_expert_weights_void);
    auto const* fc1_expert_biases = reinterpret_cast<ScaleBiasType const*>(fc1_expert_biases_void);
    auto const* fc2_expert_weights = static_cast<WeightType const*>(fc2_expert_weights_void);
    auto const* fc1_int_scales = reinterpret_cast<ScaleBiasType const*>(quant_params.fc1_weight_scales);
    auto const* fc1_int_zeros = reinterpret_cast<ScaleBiasType const*>(quant_params.fc1_weight_zeros);
    auto const* fc2_int_scales = reinterpret_cast<ScaleBiasType const*>(quant_params.fc2_weight_scales);
    auto const* fc2_int_zeros = reinterpret_cast<ScaleBiasType const*>(quant_params.fc2_weight_zeros);
    int const group_size = quant_params.group_size;
    auto const* fc1_fp8_dequant = quant_params.dequant_fc1;
    auto const* fc2_fp8_quant = quant_params.quant_fc2;
    auto const* fc2_fp8_dequant = quant_params.dequant_fc2;
    auto const* input_fp8_dequant = quant_params.dequant_input;
    auto const* fc2_expert_biases = reinterpret_cast<ScaleBiasType const*>(fc2_expert_biases_void);
    auto* final_output = static_cast<OutputType*>(final_output_void);
    auto* token_topk_unpermuted_scales = static_cast<float*>(token_topk_final_scales_void);
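    // Basic argument validation: required pointers, expert/EP divisibility, CUTLASS alignment of
    // hidden_size and inter_size, int-range limits for the source-row maps, and GEMM tactic configs.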
    TLLM_CHECK_WITH_INFO(finished == nullptr, "Using 'finished' is deprecated and will be removed in future versions");
    TLLM_CHECK_WITH_INFO(
        num_rows == active_rows, "Using 'finished' is deprecated and will be removed in future versions");
    TLLM_CHECK(input_activations);
    TLLM_CHECK(fc1_expert_weights);
    TLLM_CHECK(fc2_expert_weights);
    TLLM_CHECK(workspace_ptr);
    TLLM_CHECK(token_topk_unpermuted_scales);
    TLLM_CHECK(expanded_source_row_to_expanded_dest_row);
    TLLM_CHECK(expert_for_source_row);
    TLLM_CHECK(num_experts % parallelism_config.ep_size == 0);
    TLLM_CHECK_WITH_INFO(hidden_size >= 128 / cutlass::sizeof_bits<WeightType>::value,
        "Hidden size is too small to meet alignment requirements for MOE GEMM");
    TLLM_CHECK_WITH_INFO(hidden_size % (128 / cutlass::sizeof_bits<WeightType>::value) == 0,
        "Hidden size does not meet minimum alignment requirements for MOE GEMM");
    TLLM_CHECK_WITH_INFO(inter_size % (128 / cutlass::sizeof_bits<WeightType>::value) == 0,
        "Inter size does not meet minimum alignment requirements for MOE GEMM");
    // These values must fit into an int for building the source maps
    TLLM_CHECK_WITH_INFO(num_rows <= std::numeric_limits<int>::max(), "Number of rows is too large");
    TLLM_CHECK_WITH_INFO(
        num_rows * num_experts <= std::numeric_limits<int>::max(), "Number of rows * num_experts is too large");
    TLLM_CHECK_WITH_INFO(k * num_experts <= std::numeric_limits<int>::max(), "k * num_experts is too large");
    TLLM_CHECK_WITH_INFO(gemm1_config_.has_value(), "MOE GEMM1 Config is not set");
    TLLM_CHECK_WITH_INFO(gemm2_config_.has_value(), "MOE GEMM2 Config is not set");
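    // Cross-check that the supplied scales are consistent with the quantization mode implied by WeightType.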
    if (int_scales_required)
    {
        TLLM_CHECK_WITH_INFO(
            fc1_int_scales != nullptr, "Weight scales expected but scale for first matmul is a null pointer");
        TLLM_CHECK_WITH_INFO(
            fc2_int_scales != nullptr, "Weight scales expected but scale for second matmul is a null pointer");
        TLLM_CHECK_WITH_INFO(fc1_fp8_dequant == nullptr && fc2_fp8_quant == nullptr && fc2_fp8_dequant == nullptr,
            "FP8 scales are provided for integer quantization");
    }
    else if (fp8_scales_required)
    {
        TLLM_CHECK_WITH_INFO(fc1_expert_biases == nullptr, "Bias is not supported with FP8");
        TLLM_CHECK_WITH_INFO(fc2_expert_biases == nullptr, "Bias is not supported with FP8");
        TLLM_CHECK_WITH_INFO(
            fc1_fp8_dequant != nullptr, "FP8 scales expected but dequant scale for FC1 is a null pointer");
        TLLM_CHECK_WITH_INFO(fc2_fp8_quant != nullptr, "FP8 scales expected but quant scale for FC2 is a null pointer");
        TLLM_CHECK_WITH_INFO(
            fc2_fp8_dequant != nullptr, "FP8 scales expected but dequant scale for FC2 is a null pointer");
        TLLM_CHECK_WITH_INFO(
            fc1_int_scales == nullptr && fc2_int_scales == nullptr, "Integer scales are provided for FP8 quantization");
    }
    else if (use_lora && use_fp8)
    {
        TLLM_CHECK_WITH_INFO(
            input_fp8_dequant != nullptr, "FP8 scales expected but dequant scale for the input is a null pointer");
    }
    else
    {
        TLLM_CHECK_WITH_INFO(
            fc1_int_scales == nullptr, "Scales are ignored for fp32/fp16/bf16 but received weight scale for FC1");
        TLLM_CHECK_WITH_INFO(
            fc2_int_scales == nullptr, "Scales are ignored for fp32/fp16/bf16 but received weight scale for FC2");
        TLLM_CHECK_WITH_INFO(
            fc1_fp8_dequant == nullptr, "Scales are ignored for fp32/fp16/bf16 but received dequant scale for FC1");
        TLLM_CHECK_WITH_INFO(
            fc2_fp8_quant == nullptr, "Scales are ignored for fp32/fp16/bf16 but received quant scale for FC2");
        TLLM_CHECK_WITH_INFO(
            fc2_fp8_dequant == nullptr, "Scales are ignored for fp32/fp16/bf16 but received dequant scale for FC2");
    }
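    // Expert-parallel partitioning: this rank owns experts [start_expert, end_expert), and the
    // workspace is carved up for that per-node expert count.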
    int const num_experts_per_node = num_experts / parallelism_config.ep_size;
    configureWsPtrs(workspace_ptr, num_rows, hidden_size, inter_size, num_experts, num_experts_per_node, k,
        fc1_activation_type, normalization_mode, use_lora);
    int const start_expert = num_experts_per_node * parallelism_config.ep_rank;
    int const end_expert = start_expert + num_experts_per_node;
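    // Route tokens: build the (expert, source row) pairs, then sort by expert and scan so that
    // expert_first_token_offset_ holds the first permuted row of each expert's contiguous block.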
    genSourceRow(expert_for_source_row, source_rows_, num_rows, k, num_experts, start_expert, end_expert, stream);
    sync_check_cuda_error();
    sortAndScanSoftmaxOutput(expert_for_source_row, source_rows_, permuted_experts_, permuted_rows_,
        expert_first_token_offset_, num_rows, num_experts, num_experts_per_node, k, sorter_,
        static_cast<void*>(sorter_ws_), stream);
    sync_check_cuda_error();
    int64_t const expanded_num_rows = k * num_rows;
    bool is_gated_activation = isGatedActivation(fc1_activation_type);
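    // The LoRA path needs the permutation and per-expert offsets on the host; copy them back
    // asynchronously and record an event so the consumer can wait on the transfer.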
    if (use_lora)
    {
        std::vector<int>& host_permuted_rows = host_lora_workspace_.host_permuted_rows;
        std::vector<int64_t>& host_expert_first_token_offset = host_lora_workspace_.host_expert_first_token_offset;
        host_permuted_rows.resize(expanded_num_rows);
        TLLM_CUDA_CHECK(cudaMemcpyAsync(host_permuted_rows.data(), permuted_rows_, expanded_num_rows * sizeof(int),
            cudaMemcpyDeviceToHost, stream));
        host_expert_first_token_offset.resize(num_experts_per_node + 1);
        TLLM_CUDA_CHECK(cudaMemcpyAsync(host_expert_first_token_offset.data(), expert_first_token_offset_,
            (num_experts_per_node + 1) * sizeof(int64_t), cudaMemcpyDeviceToHost, stream));
        TLLM_CUDA_CHECK(cudaEventRecord(*(lora_params.memcpy_event_ptr), stream));
    }
    // Actually permute the data into expert-major order (each token is duplicated k times)
    bool const needs_num_valid = finished || parallelism_config.ep_size > 1;
    int64_t const* num_valid_tokens_ptr = needs_num_valid ? expert_first_token_offset_ + num_experts_per_node : nullptr;
    expandInputRowsKernelLauncher(input_activations, permuted_data_, token_topk_unpermuted_scales, permuted_scales_,
        permuted_rows_, expanded_source_row_to_expanded_dest_row, num_rows, num_valid_tokens_ptr, hidden_size, k,
        stream);
    sync_check_cuda_error();
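    // FC1 grouped GEMM over the permuted rows; the gemm1 helper applies the (possibly gated)
    // activation and, when relevant, the FP8 quant scale for the FC2 input.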
    Self::gemm1(moe_gemm_runner_, permuted_data_, fc1_result_, glu_inter_result_, expert_first_token_offset_,
        hopper_grouped_gemm_input_, fc1_expert_weights, fc1_expert_biases, num_valid_tokens_ptr, fc1_int_scales,
        fc1_int_zeros, group_size, fc1_fp8_dequant, fc2_fp8_quant, expanded_num_rows, hidden_size, inter_size,
        num_experts_per_node, fc1_activation_type, alpha_scale_ptr_array_, !use_lora, stream, *gemm1_config_);
    sync_check_cuda_error();
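    // FC2 grouped GEMM followed by the unpermute/finalize step that scales each row by its top-k
    // routing weight and accumulates into final_output; see the gemm2 helper for details.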
    Self::gemm2(moe_gemm_runner_, fc1_result_, fc2_result_, final_output, expert_first_token_offset_,
        hopper_grouped_gemm_input_, fc2_expert_weights, fc2_expert_biases, fc2_int_scales, fc2_int_zeros, group_size,
        fc2_fp8_dequant, token_topk_unpermuted_scales, permuted_scales_, expanded_source_row_to_expanded_dest_row,
        permuted_rows_, expert_for_source_row, num_valid_tokens_ptr, num_rows, expanded_num_rows, hidden_size,
        inter_size, num_experts_per_node, k, !use_deterministic_hopper_reduce_, alpha_scale_ptr_array_, use_lora,
        lora_fc2_result_, stream, parallelism_config, *gemm2_config_);
    sync_check_cuda_error();
}