in maga_transformer/cpp/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h [54:226]
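// Instantiates and launches a single CUTLASS mixed-precision GEMM: activations A of type T and
// quantized weights B of type WeightType, with the tile shape, stage count, architecture,
// epilogue tag and quantization mode fixed by the surrounding template parameters. When
// `occupancy` is non-null, only the kernel occupancy is reported and no GEMM is launched.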
void generic_mixed_gemm_kernelLauncher(const T* A, const WeightType* B, const T* weight_scales,
const T* weight_zero_points, const T* biases, T* C, int m, int n, int k, const int group_size,
tc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, cudaStream_t stream,
int* occupancy = nullptr)
{
RTP_LLM_LOG_TRACE(__PRETTY_FUNCTION__);
#ifdef ENABLE_BF16
static_assert(cutlass::platform::is_same<T, __nv_bfloat16>::value || cutlass::platform::is_same<T, half>::value
|| cutlass::platform::is_same<T, float>::value,
"Specialized for bfloat16, half, float");
#else
static_assert(cutlass::platform::is_same<T, half>::value || cutlass::platform::is_same<T, float>::value,
"Specialized for half, float");
#endif
static_assert(cutlass::platform::is_same<T, WeightType>::value
|| cutlass::platform::is_same<WeightType, uint8_t>::value
|| cutlass::platform::is_same<WeightType, cutlass::uint4b_t>::value,
"Weights must be the same type as the activations, uint8_t or uint4b_t");
// The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary.
using ElementType_ =
typename cutlass::platform::conditional<cutlass::platform::is_same<T, half>::value, cutlass::half_t, T>::type;
#ifdef ENABLE_BF16
using ElementType =
typename cutlass::platform::conditional<cutlass::platform::is_same<ElementType_, __nv_bfloat16>::value,
cutlass::bfloat16_t, ElementType_>::type;
#else
using ElementType = ElementType_;
#endif
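// Same mapping for the weight type: fp16/bf16 weights are converted to the corresponding cutlass
// types; integer weight types (uint8_t, uint4b_t) pass through unchanged.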
using CutlassWeightType_ =
typename cutlass::platform::conditional<cutlass::platform::is_same<WeightType, half>::value, cutlass::half_t,
WeightType>::type;
#ifdef ENABLE_BF16
using CutlassWeightType =
typename cutlass::platform::conditional<cutlass::platform::is_same<CutlassWeightType_, __nv_bfloat16>::value,
cutlass::bfloat16_t, CutlassWeightType_>::type;
#else
using CutlassWeightType = CutlassWeightType_;
#endif
// We need separate config for each architecture since we will target different tensorcore instructions. For float,
// we do not target TCs.
using MixedGemmArchTraits = cutlass::gemm::kernel::MixedGemmArchTraits<ElementType, CutlassWeightType, arch>;
using ElementAccumulator = typename MixedGemmArchTraits::AccType;
using EpilogueOp = typename tc::Epilogue<ElementType, MixedGemmArchTraits::ElementsPerAccessC, ElementAccumulator,
EpilogueTag>::Op;
// Not used directly; only needed so the operator can be tagged with the quantization mode below.
using Operator = typename MixedGemmArchTraits::Operator;
using TaggedOperator = typename cutlass::arch::TagOperator<Operator, QuantOp>::TaggedOperator;
using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemm<ElementType, cutlass::layout::RowMajor,
MixedGemmArchTraits::ElementsPerAccessA, CutlassWeightType, typename MixedGemmArchTraits::LayoutB,
MixedGemmArchTraits::ElementsPerAccessB, ElementType, cutlass::layout::RowMajor, ElementAccumulator,
cutlass::arch::OpClassTensorOp, arch, ThreadblockShape, WarpShape,
typename MixedGemmArchTraits::InstructionShape, EpilogueOp,
typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, Stages, true,
TaggedOperator>::GemmKernel;
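// Re-wrap the default kernel's mainloop and epilogue into GemmFpAIntB, the mixed-type kernel
// that dequantizes the weights with the scales (and optional zero-points) in the mainloop.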
using GemmKernel = cutlass::gemm::kernel::GemmFpAIntB<typename GemmKernel_::Mma, typename GemmKernel_::Epilogue,
typename GemmKernel_::ThreadblockSwizzle,
arch, // Ensure top level arch is used for dispatch
GemmKernel_::kSplitKSerial>;
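// If the caller only wants the kernel's occupancy, compute it and return without launching.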
if (occupancy != nullptr)
{
*occupancy = tensorrt_llm::cutlass_extensions::compute_occupancy_for_kernel<GemmKernel>();
return;
}
using Gemm = cutlass::gemm::device::GemmUniversalBaseCompat<GemmKernel>;
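// Leading dimension of B depends on the layout chosen by the arch traits: n for row-major
// weights, k * kInterleave for the column-major interleaved layout.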
const int ldb = cutlass::platform::is_same<cutlass::layout::RowMajor, typename MixedGemmArchTraits::LayoutB>::value
? n
: k * GemmKernel::kInterleave;
if (weight_scales == nullptr)
{
throw std::runtime_error("Weight scales must always be set to a non-null value.");
}
if constexpr (cutlass::isFinegrained(QuantOp))
{
if (group_size != 64 && group_size != 128)
{
throw std::runtime_error("Only group size 64 and 128 supported for fine grained kernels.");
}
if constexpr (QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY)
{
if (weight_zero_points != nullptr)
{
throw std::runtime_error("Weight zero pointer must be a nullptr for scale only fine grained");
}
}
else if constexpr (QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS)
{
if (weight_zero_points == nullptr)
{
throw std::runtime_error("Weight zero pointer must be valid for scale and bias fine grained");
}
}
}
else
{
if (group_size != k)
{
throw std::runtime_error("Invalid group size for per col'umn scaling kernels.");
}
if (weight_zero_points != nullptr)
{
throw std::runtime_error("Weight zero-points must be null when running per column scaling");
}
}
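// Fine-grained scales/zeros are [k / group_size, n] matrices with leading dimension n; per-column
// scaling stores a single row, so the stride is 0.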
const int ld_scale_zero = cutlass::isFinegrained(QuantOp) ? n : 0;
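// Enable the bias term in the epilogue only when a bias pointer was provided.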
ElementAccumulator output_op_beta = (biases == nullptr) ? ElementAccumulator(0.f) : ElementAccumulator(1.f);
typename Gemm::Arguments args({m, n, k}, group_size, {reinterpret_cast<ElementType*>(const_cast<T*>(A)), k},
{reinterpret_cast<CutlassWeightType*>(const_cast<WeightType*>(B)), ldb},
{reinterpret_cast<ElementType*>(const_cast<T*>(weight_scales)), ld_scale_zero},
{reinterpret_cast<ElementType*>(const_cast<T*>(weight_zero_points)), ld_scale_zero},
{reinterpret_cast<ElementType*>(const_cast<T*>(biases)), 0}, {reinterpret_cast<ElementType*>(C), n},
gemm_config.split_k_factor, {ElementAccumulator(1.f), output_op_beta});
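// Note: the split-k factor is carried in args.batch_count (see the fallback below that resets it to 1).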
// This assertion is enabled because, for the column-interleaved layout, K MUST be a multiple of
// threadblockK. The reason is that the default pitch-linear iterators are used to walk over the
// interleaved matrix, and the way masking is handled in these does not map to the interleaved
// layout. We would need to write our own predicated iterator in order to relax this limitation.
if (GemmKernel::kInterleave > 1
&& ((k % MixedGemmArchTraits::ThreadblockK)
|| ((k / gemm_config.split_k_factor) % MixedGemmArchTraits::ThreadblockK)))
{
throw std::runtime_error("Temp assertion: k must be multiple of threadblockK");
}
Gemm gemm;
if (gemm.get_workspace_size(args) > workspace_bytes)
{
RTP_LLM_LOG_WARNING(
"Requested split-k but workspace size insufficient. Falling back to non-split-k implementation.");
// If requested split-k factor will require more workspace bytes, revert to standard gemm.
args.batch_count = 1;
}
auto can_implement = gemm.can_implement(args);
if (can_implement != cutlass::Status::kSuccess)
{
std::string err_msg = "fpA_intB cutlass kernel will fail for params. Error: "
+ std::string(cutlassGetStatusString(can_implement));
throw std::runtime_error("[TensorRT-LLm Error][fpA_intB Runner] " + err_msg);
}
auto init_status = gemm.initialize(args, workspace, stream);
if (init_status != cutlass::Status::kSuccess)
{
std::string err_msg
= "Failed to initialize cutlass fpA_intB gemm. Error: " + std::string(cutlassGetStatusString(init_status));
throw std::runtime_error("[TensorRT-LLm Error][fpA_intB Runner] " + err_msg);
}
auto run_status = gemm.run(stream);
if (run_status != cutlass::Status::kSuccess)
{
std::string err_msg
= "Failed to run cutlass fpA_intB gemm. Error: " + std::string(cutlassGetStatusString(run_status));
throw std::runtime_error("[TensorRT-LLm Error][fpA_intB Runner] " + err_msg);
}
}