void CutlassMoeFCRunner::runMoe()

in maga_transformer/cpp/cutlass/cutlass_kernels/moe_gemm/moe_kernels.inl [1054:1232]


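// Member definition for the class template CutlassMoeFCRunner<T, WeightType, QuantOp, OutputType, ScaleBiasType, Enable>;
// the enclosing template<...> parameter list sits just above this excerpt in moe_kernels.inl.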
void CutlassMoeFCRunner<T, WeightType, QuantOp, OutputType, ScaleBiasType, Enable>::runMoe(
    void const*                     input_activations_void,
    float const*                    gating_output,
    float const*                    gating_output_with_bias,
    void const*                     fc1_expert_weights_void,
    void const*                     fc1_expert_biases_void,
    ActivationType                  fc1_activation_type,
    void const*                     fc2_expert_weights_void,
    void const*                     fc2_expert_biases_void,
    QuantParams                     quant_params,
    int64_t const                   num_rows,
    int64_t const                   hidden_size,
    int64_t const                   inter_size,
    int const                       num_experts,
    int const                       k,
    char*                           workspace_ptr,
    void*                           final_output_void,
    bool const*                     finished,
    int64_t const                   active_rows,
    void*                           token_topk_final_scales_void,
    int*                            expanded_source_row_to_expanded_dest_row,
    int*                            expert_for_source_row,
    float                           sparse_mixer_epsilon,
    MOEParallelismConfig            parallelism_config,
    MOEExpertScaleNormalizationMode normalization_mode,
    bool                            use_lora,
    LoraParams&                     lora_params,
    cudaStream_t                    stream) {
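    // Weight-only int (uint8/uint4) weights require integer weight scales; FP8 weights require FP8 (de)quantization scales.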
    static constexpr bool int_scales_required
        = std::is_same<WeightType, uint8_t>::value || std::is_same<WeightType, cutlass::uint4b_t>::value;
    static constexpr bool fp8_scales_required
        = std::is_same<WeightType, __nv_fp8_e4m3>::value || std::is_same<WeightType, __nv_fp8_e5m2>::value;

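    // Cast the type-erased arguments to the concrete types of this template instantiation.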
    auto const* input_activations = static_cast<T const*>(input_activations_void);
    auto const* fc1_expert_weights = static_cast<WeightType const*>(fc1_expert_weights_void);
    auto const* fc1_expert_biases = reinterpret_cast<ScaleBiasType const*>(fc1_expert_biases_void);
    auto const* fc2_expert_weights = static_cast<WeightType const*>(fc2_expert_weights_void);
    auto const* fc1_int_scales = reinterpret_cast<ScaleBiasType const*>(quant_params.fc1_weight_scales);
    auto const* fc1_int_zeros = reinterpret_cast<ScaleBiasType const*>(quant_params.fc1_weight_zeros);
    auto const* fc2_int_scales = reinterpret_cast<ScaleBiasType const*>(quant_params.fc2_weight_scales);
    auto const* fc2_int_zeros = reinterpret_cast<ScaleBiasType const*>(quant_params.fc2_weight_zeros);
    int const group_size = quant_params.group_size;

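    // Per-tensor FP8 quant/dequant scales; the checks below require these to be null unless the FP8 path is active.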
    auto const* fc1_fp8_dequant = quant_params.dequant_fc1;
    auto const* fc2_fp8_quant = quant_params.quant_fc2;
    auto const* fc2_fp8_dequant = quant_params.dequant_fc2;
    auto const* input_fp8_dequant = quant_params.dequant_input;
    auto const* fc2_expert_biases = reinterpret_cast<ScaleBiasType const*>(fc2_expert_biases_void);
    auto* final_output = static_cast<OutputType*>(final_output_void);
    auto* token_topk_unpermuted_scales = static_cast<float*>(token_topk_final_scales_void);

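    // Validate the argument combination before any kernels are launched.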
    TLLM_CHECK_WITH_INFO(finished == nullptr, "Using 'finished' is deprecated and will be removed in future versions");
    TLLM_CHECK_WITH_INFO(
        num_rows == active_rows, "Using 'finished' is deprecated and will be removed in future versions");
    TLLM_CHECK(input_activations);
    TLLM_CHECK(fc1_expert_weights);
    TLLM_CHECK(fc2_expert_weights);
    TLLM_CHECK(workspace_ptr);
    TLLM_CHECK(token_topk_unpermuted_scales);
    TLLM_CHECK(expanded_source_row_to_expanded_dest_row);
    TLLM_CHECK(expert_for_source_row);
    TLLM_CHECK(num_experts % parallelism_config.ep_size == 0);
    TLLM_CHECK_WITH_INFO(hidden_size >= 128 / cutlass::sizeof_bits<WeightType>::value,
        "Hidden size is too small to meet alignment requirements for MOE GEMM");
    TLLM_CHECK_WITH_INFO(hidden_size % (128 / cutlass::sizeof_bits<WeightType>::value) == 0,
        "Hidden size does not meet minimum alignment requirements for MOE GEMM");
    TLLM_CHECK_WITH_INFO(inter_size % (128 / cutlass::sizeof_bits<WeightType>::value) == 0,
        "Inter size does not meet minimum alignment requirements for MOE GEMM");

    // These values must fit into an int for building the source maps
    TLLM_CHECK_WITH_INFO(num_rows <= std::numeric_limits<int>::max(), "Number of rows is too large");
    TLLM_CHECK_WITH_INFO(
        num_rows * num_experts <= std::numeric_limits<int>::max(), "Number of rows * num_experts is too large");
    TLLM_CHECK_WITH_INFO(k * num_experts <= std::numeric_limits<int>::max(), "k * num_experts is too large");

    TLLM_CHECK_WITH_INFO(gemm1_config_.has_value(), "MOE GEMM1 Config is not set");
    TLLM_CHECK_WITH_INFO(gemm2_config_.has_value(), "MOE GEMM2 Config is not set");

    if (int_scales_required)
    {
        TLLM_CHECK_WITH_INFO(
            fc1_int_scales != nullptr, "Weight scales expected but scale for first matmul is a null pointer");
        TLLM_CHECK_WITH_INFO(
            fc2_int_scales != nullptr, "Weight scales expected but scale for second matmul is a null pointer");

        TLLM_CHECK_WITH_INFO(fc1_fp8_dequant == nullptr && fc2_fp8_quant == nullptr && fc2_fp8_dequant == nullptr,
            "FP8 scales are provided for integer quantization");
    }
    else if (fp8_scales_required)
    {
        TLLM_CHECK_WITH_INFO(fc1_expert_biases == nullptr, "Bias is not supported with FP8");
        TLLM_CHECK_WITH_INFO(fc2_expert_biases == nullptr, "Bias is not supported with FP8");

        TLLM_CHECK_WITH_INFO(
            fc1_fp8_dequant != nullptr, "FP8 scales expected but dequant scale for FC1 is a null pointer");
        TLLM_CHECK_WITH_INFO(fc2_fp8_quant != nullptr, "FP8 scales expected but quant scale for FC2 is a null pointer");
        TLLM_CHECK_WITH_INFO(
            fc2_fp8_dequant != nullptr, "FP8 scales expected but dequant scale for FC2 is a null pointer");

        TLLM_CHECK_WITH_INFO(
            fc1_int_scales == nullptr && fc2_int_scales == nullptr, "Integer scales are provided for FP8 quantization");
    }
    else if (use_lora && use_fp8)
    {
        TLLM_CHECK_WITH_INFO(
            input_fp8_dequant != nullptr, "FP8 scales expected but dequant scale for input is a null pointer");
    }
    else
    {
        TLLM_CHECK_WITH_INFO(
            fc1_int_scales == nullptr, "Scales are ignored for fp32/fp16/bf16 but received weight scale for FC1");
        TLLM_CHECK_WITH_INFO(
            fc2_int_scales == nullptr, "Scales are ignored for fp32/fp16/bf16 but received weight scale for FC2");
        TLLM_CHECK_WITH_INFO(
            fc1_fp8_dequant == nullptr, "Scales are ignored for fp32/fp16/bf16 but received dequant scale for FC1");
        TLLM_CHECK_WITH_INFO(
            fc2_fp8_quant == nullptr, "Scales are ignored for fp32/fp16/bf16 but received quant scale for FC2");
        TLLM_CHECK_WITH_INFO(
            fc2_fp8_dequant == nullptr, "Scales are ignored for fp32/fp16/bf16 but received dequant scale for FC2");
    }

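    // With expert parallelism, each EP rank owns a contiguous block of num_experts / ep_size experts.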
    int const num_experts_per_node = num_experts / parallelism_config.ep_size;

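    // Partition the caller-provided workspace into the intermediate buffers (permuted activations, GEMM results, sorter scratch) used below.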
    configureWsPtrs(workspace_ptr, num_rows, hidden_size, inter_size, num_experts, num_experts_per_node, k,
        fc1_activation_type, normalization_mode, use_lora);

    int const start_expert = num_experts_per_node * parallelism_config.ep_rank;
    int const end_expert = start_expert + num_experts_per_node;

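    // Build the flattened source-row index for every (token, top-k) pair; start_expert/end_expert identify the experts owned by this EP rank.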
    genSourceRow(expert_for_source_row, source_rows_, num_rows, k, num_experts, start_expert, end_expert, stream);
    sync_check_cuda_error();

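    // Sort the expanded rows by expert and scan the result into per-expert first-token offsets.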
    sortAndScanSoftmaxOutput(expert_for_source_row, source_rows_, permuted_experts_, permuted_rows_,
        expert_first_token_offset_, num_rows, num_experts, num_experts_per_node, k, sorter_,
        static_cast<void*>(sorter_ws_), stream);

    sync_check_cuda_error();

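    // Every token is routed to k experts, so the permuted buffers hold k * num_rows rows.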
    int64_t const expanded_num_rows = k * num_rows;
    bool is_gated_activation = isGatedActivation(fc1_activation_type);

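    // LoRA needs the permutation and per-expert offsets on the host: copy them asynchronously and record an event so consumers can wait for the transfer.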
    if (use_lora)
    {
        std::vector<int>& host_permuted_rows = host_lora_workspace_.host_permuted_rows;
        std::vector<int64_t>& host_expert_first_token_offset = host_lora_workspace_.host_expert_first_token_offset;
        host_permuted_rows.resize(expanded_num_rows);
        TLLM_CUDA_CHECK(cudaMemcpyAsync(host_permuted_rows.data(), permuted_rows_, expanded_num_rows * sizeof(int),
            cudaMemcpyDeviceToHost, stream));
        host_expert_first_token_offset.resize(num_experts_per_node + 1);
        TLLM_CUDA_CHECK(cudaMemcpyAsync(host_expert_first_token_offset.data(), expert_first_token_offset_,
            (num_experts_per_node + 1) * sizeof(int64_t), cudaMemcpyDeviceToHost, stream));
        TLLM_CUDA_CHECK(cudaEventRecord(*(lora_params.memcpy_event_ptr), stream));
    }

    // Actually permute the data
    bool const needs_num_valid = finished || parallelism_config.ep_size > 1;
    int64_t const* num_valid_tokens_ptr = needs_num_valid ? expert_first_token_offset_ + num_experts_per_node : nullptr;
    expandInputRowsKernelLauncher(input_activations, permuted_data_, token_topk_unpermuted_scales, permuted_scales_,
        permuted_rows_, expanded_source_row_to_expanded_dest_row, num_rows, num_valid_tokens_ptr, hidden_size, k,
        stream);

    sync_check_cuda_error();

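    // FC1: grouped GEMM over the expert-grouped rows plus the fc1 activation; glu_inter_result_ carries the extra intermediate needed for gated activations.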
    Self::gemm1(moe_gemm_runner_, permuted_data_, fc1_result_, glu_inter_result_, expert_first_token_offset_,
        hopper_grouped_gemm_input_, fc1_expert_weights, fc1_expert_biases, num_valid_tokens_ptr, fc1_int_scales,
        fc1_int_zeros, group_size, fc1_fp8_dequant, fc2_fp8_quant, expanded_num_rows, hidden_size, inter_size,
        num_experts_per_node, fc1_activation_type, alpha_scale_ptr_array_, !use_lora, stream, *gemm1_config_);

    sync_check_cuda_error();

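    // FC2: grouped GEMM, then un-permute and reduce the top-k expert outputs (weighted by token_topk_unpermuted_scales) into final_output.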
    Self::gemm2(moe_gemm_runner_, fc1_result_, fc2_result_, final_output, expert_first_token_offset_,
        hopper_grouped_gemm_input_, fc2_expert_weights, fc2_expert_biases, fc2_int_scales, fc2_int_zeros, group_size,
        fc2_fp8_dequant, token_topk_unpermuted_scales, permuted_scales_, expanded_source_row_to_expanded_dest_row,
        permuted_rows_, expert_for_source_row, num_valid_tokens_ptr, num_rows, expanded_num_rows, hidden_size,
        inter_size, num_experts_per_node, k, !use_deterministic_hopper_reduce_, alpha_scale_ptr_array_, use_lora,
        lora_fc2_result_, stream, parallelism_config, *gemm2_config_);

    sync_check_cuda_error();
}