maga_transformer/cpp/cutlass/cutlass_kernels/cutlass_heuristic.cc (553 lines of code) (raw):

/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "maga_transformer/cpp/cutlass/cutlass_kernels/cutlass_heuristic.h"

#ifndef _WIN32
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#endif // #ifndef _WIN32

#include "cutlass/gemm/gemm.h"
#include "cutlass/numeric_types.h"
#include "maga_transformer/cpp/cutlass/cutlass_kernels/gemm_lut_utils.h"
#include "maga_transformer/cpp/cuda/trt_utils.h"

#ifndef _WIN32
#pragma GCC diagnostic pop
#endif // #ifndef _WIN32

#include <cuda_runtime_api.h>

#include <cstddef>
#include <cstdint>
#include <limits>
#include <set>
#include <stdexcept>
#include <string>
#include <vector>

using namespace tensorrt_llm::cutlass_extensions;

namespace tensorrt_llm
{
namespace kernels
{
namespace cutlass_kernels
{

namespace
{

// Integer division rounding up, done in 64 bits so large problem sizes do not overflow.
// The divisor is expected to be non-zero.
inline int64_t div_round_up(int64_t const numerator, int64_t const denominator)
{
    return (numerator + denominator - 1) / denominator;
}

// Result of distributing a number of CTAs over waves of `ctas_per_wave` CTAs each.
struct WaveEstimate
{
    int64_t num_waves; // number of waves, rounded up
    float score;       // rounded-up wave count minus the fractional wave count
};

// Computes the wave count and the "unused fraction of the last wave" score for a problem.
inline WaveEstimate estimate_waves(int64_t const ctas_for_problem, int64_t const ctas_per_wave)
{
    int64_t const num_waves_total = div_round_up(ctas_for_problem, ctas_per_wave);
    float const num_waves_fractional = ctas_for_problem / static_cast<float>(ctas_per_wave);
    return WaveEstimate{num_waves_total, static_cast<float>(num_waves_total) - num_waves_fractional};
}

// Tile list returned when no specialised GEMM type (SIMT / weight-only / INT8 / grouped FP8) is requested.
std::vector<CutlassTileConfig> get_base_tile_configs(int const sm)
{
    std::vector<CutlassTileConfig> base_configs{CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
        CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64};
    if (sm >= 75)
    {
        base_configs.push_back(CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64);
    }
    return base_configs;
}

} // namespace

/// Expands a CutlassTileConfig enumerator into its six numeric extents
/// (CTA m/n/k followed by warp m/n/k, in the order encoded in the enumerator name).
///
/// @throws std::runtime_error if the enumerator has no entry in the table below.
TileConfig get_tile_config_from_config(CutlassTileConfig tile_config)
{
    switch (tile_config)
    {
    case CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64: return TileConfig{16, 128, 64, 16, 32, 64};
    case CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64: return TileConfig{16, 256, 64, 16, 64, 64};
    // Previously missing here although get_cta_shape_for_config() and the FP8 candidate list use this tile.
    case CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128: return TileConfig{16, 256, 128, 16, 64, 128};
    case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: return TileConfig{32, 128, 64, 32, 32, 64};
    case CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64: return TileConfig{64, 64, 128, 32, 64, 64};
    case CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64: return TileConfig{64, 128, 64, 32, 64, 64};
    case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: return TileConfig{64, 128, 64, 64, 32, 64};
    case CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64: return TileConfig{128, 64, 64, 64, 32, 64};
    case CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8: return TileConfig{128, 128, 8, 64, 64, 8};
    case CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64: return TileConfig{128, 128, 64, 64, 32, 64};
    case CutlassTileConfig::CtaShape128x128x64_WarpShape64x64x64: return TileConfig{128, 128, 64, 64, 64, 64};
    case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: return TileConfig{128, 128, 64, 128, 32, 64};
    case CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64: return TileConfig{128, 256, 64, 64, 64, 64};
    case CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64: return TileConfig{256, 128, 64, 64, 64, 64};
    default:
        throw std::runtime_error("[TensorRT-LLm Error][get_tile_config_from_config] Invalid config: "
            + std::to_string(static_cast<int>(tile_config)));
    }
}

// M and N extents of a CTA tile.
struct TileShape
{
    int m;
    int n;
};

/// Returns the CTA M/N extents for a tile configuration.
///
/// Raises via TLLM_THROW if the enumerator has no entry in the table below.
TileShape get_cta_shape_for_config(CutlassTileConfig tile_config)
{
    switch (tile_config)
    {
    case CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64: return TileShape{16, 128};
    case CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64:
    case CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128: return TileShape{16, 256};
    case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: return TileShape{32, 128};
    case CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64: return TileShape{64, 64};
    case CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64:
    case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: return TileShape{64, 128};
    case CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64: return TileShape{128, 64};
    case CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8:
    case CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64:
    case CutlassTileConfig::CtaShape128x128x64_WarpShape64x64x64:
    case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: return TileShape{128, 128};
    case CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64: return TileShape{128, 256};
    case CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64: return TileShape{256, 128};
    default: TLLM_THROW("[get_cta_shape_for_config] Invalid config");
    }
}

/// Checks whether `split_k_factor` can be used for an (m, n, k) problem with the given CTA tile shape.
///
/// @param m, n, k          Problem extents.
/// @param tile_shape       CTA M/N extents used to count CTAs.
/// @param split_k_factor   Number of K splits; 1 means no split and needs no workspace.
/// @param workspace_bytes  Size of the available workspace.
/// @param is_weight_only   When true, k and k / split_k_factor must both be multiples of 64 and
///                         k must be divisible by split_k_factor.
/// @return true if the divisibility constraints hold and sizeof(int) bytes per output CTA fit in the workspace.
bool is_valid_split_k_factor(int64_t const m, int64_t const n, int64_t const k, TileShape const tile_shape,
    int const split_k_factor, size_t const workspace_bytes, bool const is_weight_only)
{
    // A zero factor would divide by zero below and never describes a usable split.
    if (split_k_factor == 0)
    {
        return false;
    }

    // All tile sizes have a k_tile of 64.
    // NOTE(review): some tiles listed in this file have a CTA K extent of 128
    // (e.g. CtaShape64x64x128, CtaShape16x256x128) - confirm that checking against 64 is intended.
    static constexpr int k_tile = 64;

    // For weight-only quant, we need k and k_elements_per_split to be a multiple of cta_k
    if (is_weight_only)
    {
        if ((k % k_tile) != 0 || (k % split_k_factor) != 0)
        {
            return false;
        }
        int64_t const k_elements_per_split = k / split_k_factor;
        if ((k_elements_per_split % k_tile) != 0)
        {
            return false;
        }
    }

    // Without split-k no workspace is required.
    if (split_k_factor == 1)
    {
        return true;
    }

    // Check that the workspace has sufficient space for this split-k factor.
    // Done in 64 bits: the previous int arithmetic (and the int cast of workspace_bytes in the
    // config-based overload) misbehaved for workspaces or CTA counts beyond the int range.
    int64_t const ctas_in_m_dim = div_round_up(m, tile_shape.m);
    int64_t const ctas_in_n_dim = div_round_up(n, tile_shape.n);
    int64_t const required_ws_bytes = static_cast<int64_t>(sizeof(int)) * ctas_in_m_dim * ctas_in_n_dim;
    return required_ws_bytes <= 0 || static_cast<uint64_t>(required_ws_bytes) <= workspace_bytes;
}

/// Overload taking a full GEMM config: looks up the CTA shape of `gemm_config.tile_config` and applies the
/// same checks as the tile-shape overload with `gemm_config.split_k_factor`.
///
/// @throws std::runtime_error (from get_tile_config_from_config) if the tile configuration is unknown.
bool is_valid_split_k_factor(const int64_t m, const int64_t n, const int64_t k, const CutlassGemmConfig gemm_config,
    const size_t workspace_bytes, const bool is_weight_only)
{
    TileConfig const tile_config = get_tile_config_from_config(gemm_config.tile_config);
    return is_valid_split_k_factor(m, n, k, TileShape{tile_config.block_m, tile_config.block_n},
        gemm_config.split_k_factor, workspace_bytes, is_weight_only);
}

/// Returns the candidate CTA tiles for the requested GEMM flavour (boolean-flag interface).
///
/// Flag precedence is simt_configs_only, then is_weight_only, then int8_configs_only; with no flag set the
/// base tile list for `sm` is returned. The order of the returned tiles is preserved from the original tables.
std::vector<CutlassTileConfig> get_candidate_tiles(
    const int sm, const bool is_weight_only, const bool simt_configs_only, const bool int8_configs_only)
{
    if (simt_configs_only)
    {
        return {CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8};
    }

    if (is_weight_only)
    {
        if (sm >= 75)
        {
            return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
                CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64,
                CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64,
                CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64,
                CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64};
        }
        return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
            CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64};
    }

    if (int8_configs_only)
    {
        return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
            CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64,
            CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64,
            CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64,
            CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64,
            CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64};
    }

    return get_base_tile_configs(sm);
}

/// Builds every candidate GEMM config (tile x stages x split-k) for the boolean-flag interface.
///
/// For each tile and stage count one NO_SPLIT_K config is emitted, followed (when sm >= 70) by one
/// SPLIT_K_SERIAL config per factor in [2, max_split_k].
std::vector<CutlassGemmConfig> get_candidate_configs(int sm, const bool is_weight_only, const bool simt_configs_only,
    const bool int8_configs_only, const int max_split_k)
{
    std::vector<CutlassTileConfig> const tiles
        = get_candidate_tiles(sm, is_weight_only, simt_configs_only, int8_configs_only);

    bool const sm80_or_newer = sm >= 80;
    int const min_stages = (int8_configs_only && sm80_or_newer) ? 3 : 2;
    int max_stages = 2;
    if (sm80_or_newer)
    {
        max_stages = int8_configs_only ? 6 : 4;
    }

    std::vector<CutlassGemmConfig> candidate_configs;
    for (auto const& tile_config : tiles)
    {
        for (int stages = min_stages; stages <= max_stages; ++stages)
        {
            candidate_configs.push_back(CutlassGemmConfig{tile_config, SplitKStyle::NO_SPLIT_K, 1, stages});
            if (sm >= 70)
            {
                for (int split_k_factor = 2; split_k_factor <= max_split_k; ++split_k_factor)
                {
                    candidate_configs.push_back(
                        CutlassGemmConfig{tile_config, SplitKStyle::SPLIT_K_SERIAL, split_k_factor, stages});
                }
            }
        }
    }
    return candidate_configs;
}

/// Returns the candidate configs whose occupancy entry is non-zero, in their original order.
///
/// @throws std::runtime_error if the two vectors differ in length.
std::vector<CutlassGemmConfig> get_valid_config_from_occupancies(
    const std::vector<CutlassGemmConfig>& candidate_configs, const std::vector<int>& occupancies)
{
    if (occupancies.size() != candidate_configs.size())
    {
        throw std::runtime_error(
            "[TensorRT-LLm Error][get_valid_config_from_occupancies] occpancies and "
            "candidate configs vectors must have equal length.");
    }

    std::vector<CutlassGemmConfig> valid_configs;
    for (size_t ii = 0; ii < candidate_configs.size(); ++ii)
    {
        if (occupancies[ii] != 0)
        {
            valid_configs.push_back(candidate_configs[ii]);
        }
    }
    return valid_configs;
}

/// Picks the candidate config with the least wasted capacity in the last wave.
///
/// @param candidate_configs      Configs to choose from (each carries its own split_k_factor).
/// @param occupancies            Per-config occupancy; entries equal to 0 are skipped.
/// @param m, n                   Problem extents used to count CTAs. `k` is not used by this overload.
/// @param multi_processor_count  Multiplied by the occupancy to get the CTAs per wave; must be positive.
/// @throws std::runtime_error if the vectors differ in length, multi_processor_count is not positive,
///         a candidate has an unknown tile configuration, or no candidate is usable.
CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector<CutlassGemmConfig>& candidate_configs,
    const std::vector<int>& occupancies, const int64_t m, const int64_t n, const int64_t k,
    const int multi_processor_count)
{
    if (occupancies.size() != candidate_configs.size())
    {
        throw std::runtime_error(
            "[TensorRT-LLm Error][estimate_best_config_from_occupancies] occpancies and "
            "candidate configs vectors must have equal length.");
    }
    // Guards the division by (occupancy * multi_processor_count) below.
    if (multi_processor_count <= 0)
    {
        throw std::runtime_error(
            "[TensorRT-LLm Error][estimate_best_config_from_occupancies] multi_processor_count must be positive.");
    }

    CutlassGemmConfig best_config;
    // Score will be [0, 1]. The objective is to minimize this score.
    // It represents the fraction of SM resources unused in the last wave.
    float config_score = 1.0f;
    int64_t config_waves = std::numeric_limits<int64_t>::max();
    int current_m_tile = 0;
    float const score_slack = 0.1f;

    for (size_t ii = 0; ii < candidate_configs.size(); ++ii)
    {
        CutlassGemmConfig const& candidate_config = candidate_configs[ii];
        TileConfig const tile_config = get_tile_config_from_config(candidate_config.tile_config);
        int const occupancy = occupancies[ii];

        if (occupancy == 0)
        {
            continue;
        }

        // Keep small tile sizes when possible.
        if (best_config.tile_config != CutlassTileConfig::ChooseWithHeuristic && m < current_m_tile
            && current_m_tile < tile_config.block_m)
        {
            continue;
        }

        // 64-bit arithmetic: the CTA count product can exceed the int range for large problems.
        int64_t const ctas_in_m_dim = div_round_up(m, tile_config.block_m);
        int64_t const ctas_in_n_dim = div_round_up(n, tile_config.block_n);
        int64_t const ctas_per_wave = static_cast<int64_t>(occupancy) * multi_processor_count;
        int64_t const ctas_for_problem = ctas_in_m_dim * ctas_in_n_dim * candidate_config.split_k_factor;
        WaveEstimate const waves = estimate_waves(ctas_for_problem, ctas_per_wave);

        bool const improves_score = waves.score < config_score;
        bool const fewer_waves_within_slack
            = (config_waves > waves.num_waves) && (waves.score < config_score + score_slack);

        if (improves_score || fewer_waves_within_slack)
        {
            config_score = waves.score;
            config_waves = waves.num_waves;
            best_config = candidate_config;
            current_m_tile = tile_config.block_m;
        }
        else if (waves.score == config_score
            && (best_config.stages < candidate_config.stages
                || candidate_config.split_k_factor < best_config.split_k_factor
                || current_m_tile < tile_config.block_m))
        {
            // Prefer deeper pipeline or smaller split-k
            best_config = candidate_config;
            current_m_tile = tile_config.block_m;
            config_waves = waves.num_waves;
        }
    }

    if (best_config.tile_config == CutlassTileConfig::ChooseWithHeuristic)
    {
        throw std::runtime_error("[TensorRT-LLm Error] Heurisitc failed to find a valid config.");
    }

    return best_config;
}

/// Returns the candidate CTA tiles for the requested GEMM flavour (bit-flag interface).
///
/// Flag precedence is SIMT_ONLY, WEIGHT_ONLY, INT8_ONLY, then FP8_ONLY. FP8 has a dedicated list only when
/// GROUPED_GEMM is also set (non-empty for sm == 89 only); every other combination gets the base tile list.
std::vector<CutlassTileConfig> get_candidate_tiles(
    int const sm, CutlassGemmConfig::CandidateConfigTypeParam const config_type_param)
{
    if (config_type_param & CutlassGemmConfig::SIMT_ONLY)
    {
        return {CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8};
    }

    if (config_type_param & CutlassGemmConfig::WEIGHT_ONLY)
    {
        if (sm >= 75)
        {
            return {CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64,
                CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64,
                CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
                CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64,
                CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64};
        }
        return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
            CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64};
    }

    if (config_type_param & CutlassGemmConfig::INT8_ONLY)
    {
        return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
            CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64,
            CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64,
            CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64,
            CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64,
            CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64};
    }

    // Only grouped FP8 GEMMs have their own list; plain FP8 deliberately uses the base tiles
    // (this was an implicit switch fallthrough before).
    if ((config_type_param & CutlassGemmConfig::FP8_ONLY) && (config_type_param & CutlassGemmConfig::GROUPED_GEMM))
    {
        if (sm == 89)
        {
            return {CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128,
                CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
                CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64,
                CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64,
                CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64,
                CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64,
                CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64};
        }
        // no valid ampere style fp8 configs for sm90
        return {};
    }

    return get_base_tile_configs(sm);
}

/// Returns the candidate SM90 tiles; grouped GEMMs get a different list than regular GEMMs.
/// `sm` is currently not consulted.
std::vector<CutlassTileConfigSM90> get_candidate_tiles_sm90(
    int const sm, CutlassGemmConfig::CandidateConfigTypeParam const config)
{
#ifdef FAST_BUILD
    // Fast build disables all configs except this one for SM90
    return {CutlassTileConfigSM90::CtaShape128x128x128B};
#else
    if (config & CutlassGemmConfig::GROUPED_GEMM)
    {
        return {CutlassTileConfigSM90::CtaShape128x16x128B, CutlassTileConfigSM90::CtaShape128x32x128B,
            CutlassTileConfigSM90::CtaShape128x64x128B, CutlassTileConfigSM90::CtaShape128x128x128B,
            CutlassTileConfigSM90::CtaShape128x256x128B, CutlassTileConfigSM90::CtaShape256x128x128B};
    }
    return {CutlassTileConfigSM90::CtaShape64x16x128B, CutlassTileConfigSM90::CtaShape64x32x128B,
        CutlassTileConfigSM90::CtaShape64x64x128B, CutlassTileConfigSM90::CtaShape64x128x128B,
        CutlassTileConfigSM90::CtaShape64x256x128B, CutlassTileConfigSM90::CtaShape128x16x128B,
        CutlassTileConfigSM90::CtaShape128x32x128B, CutlassTileConfigSM90::CtaShape128x64x128B,
        CutlassTileConfigSM90::CtaShape128x128x128B, CutlassTileConfigSM90::CtaShape128x256x128B};
#endif
}

// We only compile CUTLASS kernels with multi-cast along M if the M tile is >= 128. This is purely to improve
// compilation speed.
// Implemented as a switch so that no std::set is allocated on every call.
bool supports_mcast_along_m(CutlassTileConfigSM90 const tile)
{
    switch (tile)
    {
    case CutlassTileConfigSM90::CtaShape128x16x128B:
    case CutlassTileConfigSM90::CtaShape128x32x128B:
    case CutlassTileConfigSM90::CtaShape128x64x128B:
    case CutlassTileConfigSM90::CtaShape128x128x128B:
    case CutlassTileConfigSM90::CtaShape128x256x128B:
    case CutlassTileConfigSM90::CtaShape256x128x128B: return true;
    default: return false;
    }
}

// We only compile CUTLASS kernels with multi-cast along N if the N tile is >= 128. This is purely to improve
// compilation speed.
// Implemented as a switch so that no std::set is allocated on every call.
bool supports_mcast_along_n(CutlassTileConfigSM90 const tile)
{
    switch (tile)
    {
    case CutlassTileConfigSM90::CtaShape64x128x128B:
    case CutlassTileConfigSM90::CtaShape64x256x128B:
    case CutlassTileConfigSM90::CtaShape128x128x128B:
    case CutlassTileConfigSM90::CtaShape128x256x128B:
    case CutlassTileConfigSM90::CtaShape256x128x128B: return true;
    default: return false;
    }
}

/// Builds every candidate GEMM config for the bit-flag interface.
///
/// When sm == 90 and HOPPER is set, each SM90 tile is emitted with cluster shape 1x1x1 plus the 2x1x1, 1x2x1
/// and 2x2x1 variants the tile supports. Otherwise each tile is combined with every stage count, once without
/// split-k and (when sm >= 75) once per SPLIT_K_SERIAL factor in [2, max_split_k].
std::vector<CutlassGemmConfig> get_candidate_configs(
    int sm, int const max_split_k, CutlassGemmConfig::CandidateConfigTypeParam const config_type_param)
{
    std::vector<CutlassGemmConfig> candidate_configs;

    if (sm == 90 && (config_type_param & CutlassGemmConfig::HOPPER))
    {
        std::vector<CutlassTileConfigSM90> const tiles = get_candidate_tiles_sm90(sm, config_type_param);
        for (auto const& tile_config : tiles)
        {
            candidate_configs.push_back(CutlassGemmConfig(tile_config, MainloopScheduleType::AUTO,
                EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1));

            bool const has_m_mcast = supports_mcast_along_m(tile_config);
            bool const has_n_mcast = supports_mcast_along_n(tile_config);
            if (has_m_mcast)
            {
                candidate_configs.push_back(CutlassGemmConfig(tile_config, MainloopScheduleType::AUTO,
                    EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1));
            }
            if (has_n_mcast)
            {
                candidate_configs.push_back(CutlassGemmConfig(tile_config, MainloopScheduleType::AUTO,
                    EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x2x1));
            }
            if (has_m_mcast && has_n_mcast)
            {
                candidate_configs.push_back(CutlassGemmConfig(tile_config, MainloopScheduleType::AUTO,
                    EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x2x1));
            }
        }
        return candidate_configs;
    }

    std::vector<CutlassTileConfig> const tiles = get_candidate_tiles(sm, config_type_param);

    bool const int8_configs_only = (config_type_param & CutlassGemmConfig::INT8_ONLY) != 0;
    int const min_stages = int8_configs_only ? 3 : 2;
    int const max_stages = int8_configs_only ? 6 : (sm >= 80 ? 4 : 2);

    for (auto const& tile_config : tiles)
    {
        for (int stages = min_stages; stages <= max_stages; ++stages)
        {
            candidate_configs.push_back(CutlassGemmConfig(tile_config, SplitKStyle::NO_SPLIT_K, 1, stages));
            if (sm >= 75)
            {
                for (int split_k_factor = 2; split_k_factor <= max_split_k; ++split_k_factor)
                {
                    candidate_configs.push_back(
                        CutlassGemmConfig{tile_config, SplitKStyle::SPLIT_K_SERIAL, split_k_factor, stages});
                }
            }
        }
    }

    return candidate_configs;
}

/// Picks the best (tile, stages, split-k) combination by minimising the wasted capacity in the last wave.
///
/// Unlike the legacy overload, the split-k factor of each candidate is ignored; instead every factor in
/// [1, max_split_k] accepted by is_valid_split_k_factor() is scored for each candidate tile/stage pair.
///
/// @param candidate_configs      Configs supplying tile configuration and stage count.
/// @param occupancies            Per-config occupancy; entries equal to 0 are skipped.
/// @param m, n, k                Problem extents.
/// @param num_experts            Not used by this function.
/// @param split_k_limit          Largest split-k factor tried, unless n >= multi_processor_count * 256,
///                               in which case only a factor of 1 is tried.
/// @param workspace_bytes        Workspace size forwarded to is_valid_split_k_factor().
/// @param multi_processor_count  Multiplied by the occupancy to get the CTAs per wave; must be positive.
/// @param is_weight_only         Non-zero enables the weight-only divisibility checks on k.
/// Raises via TLLM_THROW on mismatched vector lengths, a non-positive multi_processor_count, an unknown
/// tile configuration, or when no usable config exists.
CutlassGemmConfig estimate_best_config_from_occupancies(std::vector<CutlassGemmConfig> const& candidate_configs,
    std::vector<int> const& occupancies, int64_t const m, int64_t const n, int64_t const k, int64_t const num_experts,
    int const split_k_limit, size_t const workspace_bytes, int const multi_processor_count, int const is_weight_only)
{
    if (occupancies.size() != candidate_configs.size())
    {
        TLLM_THROW(
            "[estimate_best_config_from_occupancies] occpancies and "
            "candidate configs vectors must have equal length.");
    }
    // Guards the division by (occupancy * multi_processor_count) below.
    if (multi_processor_count <= 0)
    {
        TLLM_THROW("[estimate_best_config_from_occupancies] multi_processor_count must be positive.");
    }

    CutlassGemmConfig best_config;
    // Score will be [0, 1]. The objective is to minimize this score.
    // It represents the fraction of SM resources unused in the last wave.
    float config_score = 1.0f;
    int64_t config_waves = std::numeric_limits<int64_t>::max();
    int current_m_tile = 0;
    float const score_slack = 0.1f;
    bool const weight_only = is_weight_only != 0;

    int const max_split_k = n >= static_cast<int64_t>(multi_processor_count) * 256 ? 1 : split_k_limit;

    for (size_t ii = 0; ii < candidate_configs.size(); ++ii)
    {
        CutlassGemmConfig const& candidate_config = candidate_configs[ii];
        TileShape const tile_shape = get_cta_shape_for_config(candidate_config.tile_config);
        int const occupancy = occupancies[ii];

        if (occupancy == 0)
        {
            continue;
        }

        // Keep small tile sizes when possible.
        if (best_config.tile_config != CutlassTileConfig::ChooseWithHeuristic && m < current_m_tile
            && current_m_tile < tile_shape.m)
        {
            continue;
        }

        // These do not depend on the split-k factor, so compute them once per candidate.
        int64_t const ctas_in_m_dim = div_round_up(m, tile_shape.m);
        int64_t const ctas_in_n_dim = div_round_up(n, tile_shape.n);
        int64_t const ctas_per_wave = static_cast<int64_t>(occupancy) * multi_processor_count;

        for (int split_k_factor = 1; split_k_factor <= max_split_k; ++split_k_factor)
        {
            if (!is_valid_split_k_factor(m, n, k, tile_shape, split_k_factor, workspace_bytes, weight_only))
            {
                continue;
            }

            int64_t const ctas_for_problem = ctas_in_m_dim * ctas_in_n_dim * split_k_factor;
            WaveEstimate const waves = estimate_waves(ctas_for_problem, ctas_per_wave);
            SplitKStyle const split_style
                = split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K;

            bool const improves_score = waves.score < config_score;
            bool const fewer_waves_within_slack
                = (config_waves > waves.num_waves) && (waves.score < config_score + score_slack);

            if (improves_score || fewer_waves_within_slack)
            {
                config_score = waves.score;
                config_waves = waves.num_waves;
                best_config = CutlassGemmConfig(
                    candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages);
                current_m_tile = tile_shape.m;
            }
            else if (waves.score == config_score
                && (best_config.stages < candidate_config.stages || split_k_factor < best_config.split_k_factor
                    || current_m_tile < tile_shape.m))
            {
                // Prefer deeper pipeline or smaller split-k
                best_config = CutlassGemmConfig(
                    candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages);
                current_m_tile = tile_shape.m;
                config_waves = waves.num_waves;
            }
        }
    }

    if (best_config.tile_config == CutlassTileConfig::ChooseWithHeuristic)
    {
        TLLM_THROW("Heurisitc failed to find a valid config.");
    }

    return best_config;
}

} // namespace cutlass_kernels
} // namespace kernels
} // namespace tensorrt_llm