maga_transformer/cpp/cutlass/gen.py (1,104 lines of code) (raw):

import enum import os from itertools import product ################################################################################ # Epilogue Tag enum and string utils ################################################################################################# # # Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are met: # # 1. Redistributions of source code must retain the above copyright notice, this # list of conditions and the following disclaimer. # # 2. Redistributions in binary form must reproduce the above copyright notice, # this list of conditions and the following disclaimer in the documentation # and/or other materials provided with the distribution. # # 3. Neither the name of the copyright holder nor the names of its # contributors may be used to endorse or promote products derived from # this software without specific prior written permission. # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
#################################################################################################
"""
Data types and tags used for emitting CUTLASS C++ kernels
"""

import enum
import re

# enum.auto() is absent from some Python 3.5 builds (e.g. the stock 3.5.2 on
# Ubuntu 16.04), so fall back to a module-level counter when it is missing.
# https://codereview.stackexchange.com/questions/177309/reimplementing-pythons-enum-auto-for-compatibility
try:
    from enum import auto as enum_auto
except ImportError:
    __cutlass_library_auto_enum = 0

    def enum_auto() -> int:
        global __cutlass_library_auto_enum
        i = __cutlass_library_auto_enum
        __cutlass_library_auto_enum += 1
        return i

###################################################################################################


class GeneratorTarget(enum.Enum):
    """Output targets supported by the generator."""
    Library = enum_auto()


# Maps a generator target to its procedural-name fragment.
GeneratorTargetNames = {
    GeneratorTarget.Library: 'library',
}

###################################################################################################


class DataType(enum.Enum):
    """Element data types understood by the kernel emitters."""
    void = enum_auto()  # primarily used to disable C tensor for epilogues
    b1 = enum_auto()
    u4 = enum_auto()
    u8 = enum_auto()
    u16 = enum_auto()
    u32 = enum_auto()
    u64 = enum_auto()
    s4 = enum_auto()
    s8 = enum_auto()
    s16 = enum_auto()
    s32 = enum_auto()
    s64 = enum_auto()
    e4m3 = enum_auto()
    e5m2 = enum_auto()
    f16 = enum_auto()
    bf16 = enum_auto()
    f32 = enum_auto()
    tf32 = enum_auto()
    f64 = enum_auto()
    cf16 = enum_auto()
    cbf16 = enum_auto()
    cf32 = enum_auto()
    ctf32 = enum_auto()
    cf64 = enum_auto()
    cs4 = enum_auto()
    cs8 = enum_auto()
    cs16 = enum_auto()
    cs32 = enum_auto()
    cs64 = enum_auto()
    cu4 = enum_auto()
    cu8 = enum_auto()
    cu16 = enum_auto()
    cu32 = enum_auto()
    cu64 = enum_auto()
    invalid = enum_auto()


# One-letter (BLAS-style) abbreviations used in very short kernel names.
ShortDataTypeNames = {
    DataType.s32: 'i',
    DataType.e4m3: 'e4m3',
    DataType.e5m2: 'e5m2',
    DataType.f16: 'h',
    DataType.f32: 's',
    DataType.f64: 'd',
    DataType.cf32: 'c',
    DataType.cf64: 'z',
}

# Long-form names used in procedural kernel names.
DataTypeNames = {
    DataType.void: "void",
    DataType.b1: "b1",
    DataType.u4: "u4", DataType.u8: "u8", DataType.u16: "u16",
    DataType.u32: "u32", DataType.u64: "u64",
    DataType.s4: "s4", DataType.s8: "s8", DataType.s16: "s16",
    DataType.s32: "s32", DataType.s64: "s64",
    DataType.e4m3: "e4m3", DataType.e5m2: "e5m2",
    DataType.f16: "f16", DataType.bf16: "bf16",
    DataType.f32: "f32", DataType.tf32: "tf32", DataType.f64: "f64",
    DataType.cf16: "cf16", DataType.cbf16: "cbf16",
    DataType.cf32: "cf32", DataType.ctf32: "ctf32", DataType.cf64: "cf64",
    DataType.cu4: "cu4", DataType.cu8: "cu8", DataType.cu16: "cu16",
    DataType.cu32: "cu32", DataType.cu64: "cu64",
    DataType.cs4: "cs4", DataType.cs8: "cs8", DataType.cs16: "cs16",
    DataType.cs32: "cs32", DataType.cs64: "cs64",
}

# C++ type spelling emitted into generated CUTLASS source.
DataTypeTag = {
    DataType.void: "void",
    DataType.b1: "cutlass::uint1b_t",
    DataType.u4: "cutlass::uint4b_t",
    DataType.u8: "uint8_t", DataType.u16: "uint16_t",
    DataType.u32: "uint32_t", DataType.u64: "uint64_t",
    DataType.s4: "cutlass::int4b_t",
    DataType.s8: "int8_t", DataType.s16: "int16_t",
    DataType.s32: "int32_t", DataType.s64: "int64_t",
    DataType.e4m3: "cutlass::float_e4m3_t",
    DataType.e5m2: "cutlass::float_e5m2_t",
    DataType.f16: "cutlass::half_t",
    DataType.bf16: "cutlass::bfloat16_t",
    DataType.f32: "float",
    DataType.tf32: "cutlass::tfloat32_t",
    DataType.f64: "double",
    DataType.cf16: "cutlass::complex<cutlass::half_t>",
    DataType.cbf16: "cutlass::complex<cutlass::bfloat16_t>",
    DataType.cf32: "cutlass::complex<float>",
    DataType.ctf32: "cutlass::complex<cutlass::tfloat32_t>",
    DataType.cf64: "cutlass::complex<double>",
    DataType.cu4: "cutlass::complex<cutlass::uint4b_t>",
    DataType.cu8: "cutlass::complex<cutlass::uint8_t>",
    DataType.cu16: "cutlass::complex<cutlass::uint16_t>",
    DataType.cu32: "cutlass::complex<cutlass::uint32_t>",
    DataType.cu64: "cutlass::complex<cutlass::uint64_t>",
    DataType.cs4: "cutlass::complex<cutlass::int4b_t>",
    DataType.cs8: "cutlass::complex<cutlass::int8_t>",
    DataType.cs16: "cutlass::complex<cutlass::int16_t>",
    DataType.cs32: "cutlass::complex<cutlass::int32_t>",
    DataType.cs64: "cutlass::complex<cutlass::int64_t>",
}

# Element width in bits (used e.g. for shared-memory sizing).
DataTypeSize = {
    DataType.void: 0,
    DataType.b1: 1,
    DataType.u4: 4, DataType.u8: 8, DataType.u16: 16,
    DataType.u32: 32, DataType.u64: 64,
    DataType.s4: 4, DataType.s8: 8, DataType.s16: 16,
    DataType.s32: 32, DataType.s64: 64,
    DataType.e4m3: 8, DataType.e5m2: 8,
    DataType.f16: 16, DataType.bf16: 16,
    DataType.f32: 32, DataType.tf32: 32, DataType.f64: 64,
    DataType.cf16: 32, DataType.cbf16: 32,
    DataType.cf32: 64, DataType.ctf32: 32, DataType.cf64: 128,
    DataType.cu4: 8, DataType.cu8: 16, DataType.cu16: 32,
    DataType.cu32: 64, DataType.cu64: 128,
    DataType.cs4: 8, DataType.cs8: 16, DataType.cs16: 32,
    DataType.cs32: 64, DataType.cs64: 128,
}

###################################################################################################


class BlasMode(enum.Enum):
    """Symmetry mode for BLAS3-style operations."""
    symmetric = enum_auto()
    hermitian = enum_auto()


BlasModeTag = {
    BlasMode.symmetric: 'cutlass::BlasMode::kSymmetric',
    BlasMode.hermitian: 'cutlass::BlasMode::kHermitian',
}


class ComplexTransform(enum.Enum):
    """Whether an operand is conjugated before use."""
    none = enum_auto()
    conj = enum_auto()


ComplexTransformTag = {
    ComplexTransform.none: 'cutlass::ComplexTransform::kNone',
    ComplexTransform.conj: 'cutlass::ComplexTransform::kConjugate',
}

# Used for cutlass3x complex kernel collective mainloop builder instantiation
ComplexTransformTag3x = {
    ComplexTransform.none: 'cute::identity',
    ComplexTransform.conj: 'cute::conjugate',
}

# (real, complex) type pairs that correspond to one another.
RealComplexBijection = [
    (DataType.f16, DataType.cf16),
    (DataType.f32, DataType.cf32),
    (DataType.f64, DataType.cf64),
]


def is_complex(data_type):
    """Return True when data_type is one of the mapped complex types."""
    return any(data_type == complex_side for _, complex_side in RealComplexBijection)


def get_complex_from_real(real_type):
    """Return the complex counterpart of real_type, or DataType.invalid."""
    return next((c for r, c in RealComplexBijection if real_type == r), DataType.invalid)


def get_real_from_complex(complex_type):
    """Return the real counterpart of complex_type, or DataType.invalid."""
    return next((r for r, c in RealComplexBijection if complex_type == c), DataType.invalid)


class ComplexMultiplyOp(enum.Enum):
    """Complex multiply algorithm (3-mult Gaussian vs. standard FMA)."""
    multiply_add = enum_auto()
    gaussian = enum_auto()
###################################################################################################


class MathOperation(enum.Enum):
    """MMA/accumulation math variants selectable per instruction."""
    multiply_add = enum_auto()
    multiply_add_saturate = enum_auto()
    multiply_add_mixed_input_upcast = enum_auto()
    xor_popc = enum_auto()
    and_popc = enum_auto()
    multiply_add_fast_bf16 = enum_auto()
    multiply_add_fast_f16 = enum_auto()
    multiply_add_fast_f32 = enum_auto()
    multiply_add_complex_fast_f32 = enum_auto()
    multiply_add_complex = enum_auto()
    multiply_add_complex_gaussian = enum_auto()
    multiply_add_fast_accum = enum_auto()


MathOperationTag = {
    MathOperation.multiply_add: 'cutlass::arch::OpMultiplyAdd',
    MathOperation.multiply_add_saturate: 'cutlass::arch::OpMultiplyAddSaturate',
    MathOperation.multiply_add_mixed_input_upcast: 'cutlass::arch::OpMultiplyAddMixedInputUpcast',
    MathOperation.xor_popc: 'cutlass::arch::OpXorPopc',
    MathOperation.and_popc: 'cutlass::arch::OpAndPopc',
    MathOperation.multiply_add_fast_bf16: 'cutlass::arch::OpMultiplyAddFastBF16',
    MathOperation.multiply_add_fast_f16: 'cutlass::arch::OpMultiplyAddFastF16',
    MathOperation.multiply_add_fast_f32: 'cutlass::arch::OpMultiplyAddFastF32',
    MathOperation.multiply_add_complex_fast_f32: 'cutlass::arch::OpMultiplyAddComplexFastF32',
    MathOperation.multiply_add_complex: 'cutlass::arch::OpMultiplyAddComplex',
    MathOperation.multiply_add_complex_gaussian: 'cutlass::arch::OpMultiplyAddGaussianComplex',
    MathOperation.multiply_add_fast_accum: 'cutlass::arch::OpMultiplyAddFastAccum',
}

###################################################################################################


class LayoutType(enum.Enum):
    """Tensor memory layouts for GEMM and convolution operands."""
    ColumnMajor = enum_auto()
    RowMajor = enum_auto()
    ColumnMajorInterleaved2 = enum_auto()
    RowMajorInterleaved2 = enum_auto()
    ColumnMajorInterleaved32 = enum_auto()
    RowMajorInterleaved32 = enum_auto()
    ColumnMajorInterleaved64 = enum_auto()
    RowMajorInterleaved64 = enum_auto()
    TensorNWC = enum_auto()
    TensorNHWC = enum_auto()
    TensorNDHWC = enum_auto()
    TensorNCHW = enum_auto()
    TensorNGHWC = enum_auto()
    TensorNC32HW32 = enum_auto()
    TensorNC64HW64 = enum_auto()
    TensorC32RSK32 = enum_auto()
    TensorC64RSK64 = enum_auto()
    TensorKCS = enum_auto()
    TensorKCSR = enum_auto()
    TensorKCSRT = enum_auto()


LayoutTag = {
    LayoutType.ColumnMajor: 'cutlass::layout::ColumnMajor',
    LayoutType.RowMajor: 'cutlass::layout::RowMajor',
    LayoutType.ColumnMajorInterleaved2: 'cutlass::layout::ColumnMajorInterleaved<2>',
    LayoutType.RowMajorInterleaved2: 'cutlass::layout::RowMajorInterleaved<2>',
    LayoutType.ColumnMajorInterleaved32: 'cutlass::layout::ColumnMajorInterleaved<32>',
    LayoutType.RowMajorInterleaved32: 'cutlass::layout::RowMajorInterleaved<32>',
    LayoutType.ColumnMajorInterleaved64: 'cutlass::layout::ColumnMajorInterleaved<64>',
    LayoutType.RowMajorInterleaved64: 'cutlass::layout::RowMajorInterleaved<64>',
    LayoutType.TensorNWC: 'cutlass::layout::TensorNWC',
    LayoutType.TensorNHWC: 'cutlass::layout::TensorNHWC',
    LayoutType.TensorNDHWC: 'cutlass::layout::TensorNDHWC',
    LayoutType.TensorNCHW: 'cutlass::layout::TensorNCHW',
    LayoutType.TensorNGHWC: 'cutlass::layout::TensorNGHWC',
    LayoutType.TensorNC32HW32: 'cutlass::layout::TensorNCxHWx<32>',
    LayoutType.TensorC32RSK32: 'cutlass::layout::TensorCxRSKx<32>',
    LayoutType.TensorNC64HW64: 'cutlass::layout::TensorNCxHWx<64>',
    LayoutType.TensorC64RSK64: 'cutlass::layout::TensorCxRSKx<64>',
    LayoutType.TensorKCS: 'cutlass::layout::TensorKCS',
    LayoutType.TensorKCSR: 'cutlass::layout::TensorKCSR',
    LayoutType.TensorKCSRT: 'cutlass::layout::TensorKCSRT'
}

# Layout obtained when an operand is (logically) transposed.
TransposedLayout = {
    LayoutType.ColumnMajor: LayoutType.RowMajor,
    LayoutType.RowMajor: LayoutType.ColumnMajor,
    LayoutType.ColumnMajorInterleaved2: LayoutType.RowMajorInterleaved2,
    LayoutType.RowMajorInterleaved2: LayoutType.ColumnMajorInterleaved2,
    LayoutType.ColumnMajorInterleaved32: LayoutType.RowMajorInterleaved32,
    LayoutType.RowMajorInterleaved32: LayoutType.ColumnMajorInterleaved32,
    LayoutType.ColumnMajorInterleaved64: LayoutType.RowMajorInterleaved64,
    LayoutType.RowMajorInterleaved64: LayoutType.ColumnMajorInterleaved64,
    LayoutType.TensorNHWC: LayoutType.TensorNHWC
}

ShortLayoutTypeNames = {
    LayoutType.ColumnMajor: 'n',
    LayoutType.ColumnMajorInterleaved2: 'n2',
    LayoutType.ColumnMajorInterleaved32: 'n32',
    LayoutType.ColumnMajorInterleaved64: 'n64',
    LayoutType.RowMajor: 't',
    LayoutType.RowMajorInterleaved2: 't2',
    LayoutType.RowMajorInterleaved32: 't32',
    LayoutType.RowMajorInterleaved64: 't64',
    LayoutType.TensorNWC: 'nwc',
    LayoutType.TensorNHWC: 'nhwc',
    LayoutType.TensorNDHWC: 'ndhwc',
    LayoutType.TensorNCHW: 'nchw',
    LayoutType.TensorNGHWC: 'nghwc',
    LayoutType.TensorNC32HW32: 'nc32hw32',
    LayoutType.TensorNC64HW64: 'nc64hw64',
    LayoutType.TensorC32RSK32: 'c32rsk32',
    LayoutType.TensorC64RSK64: 'c64rsk64',
    LayoutType.TensorKCS: 'kcs',
    LayoutType.TensorKCSR: 'kcsr',
    LayoutType.TensorKCSRT: 'kcsrt'
}

# (layout, conjugation) pair -> single-letter name for complex kernels.
ShortComplexLayoutNames = {
    (LayoutType.ColumnMajor, ComplexTransform.none): 'n',
    (LayoutType.ColumnMajor, ComplexTransform.conj): 'c',
    (LayoutType.RowMajor, ComplexTransform.none): 't',
    (LayoutType.RowMajor, ComplexTransform.conj): 'h'
}

###################################################################################################


class KernelScheduleType(enum.Enum):
    """Mainloop schedule selection for CUTLASS 3.x collectives."""
    ScheduleAuto = enum_auto()
    Multistage = enum_auto()
    CpAsyncWarpSpecialized = enum_auto()
    CpAsyncWarpSpecializedPingpong = enum_auto()
    CpAsyncWarpSpecializedCooperative = enum_auto()
    Tma = enum_auto()
    TmaWarpSpecialized = enum_auto()
    TmaWarpSpecializedPingpong = enum_auto()
    TmaWarpSpecializedCooperative = enum_auto()
    TmaWarpSpecializedFP8FastAccum = enum_auto()
    TmaWarpSpecializedCooperativeFP8FastAccum = enum_auto()
    TmaWarpSpecializedPingpongFP8FastAccum = enum_auto()
    ImplicitTmaWarpSpecializedSm90 = enum_auto()


KernelScheduleTag = {
    KernelScheduleType.ScheduleAuto: 'cutlass::gemm::collective::KernelScheduleAuto',
    KernelScheduleType.Multistage: 'cutlass::gemm::KernelMultistage',
    KernelScheduleType.CpAsyncWarpSpecialized: 'cutlass::gemm::KernelCpAsyncWarpSpecialized',
    KernelScheduleType.CpAsyncWarpSpecializedPingpong: 'cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong',
    KernelScheduleType.CpAsyncWarpSpecializedCooperative: 'cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative',
    KernelScheduleType.Tma: 'cutlass::gemm::KernelTma',
    KernelScheduleType.TmaWarpSpecialized: 'cutlass::gemm::KernelTmaWarpSpecialized',
    KernelScheduleType.TmaWarpSpecializedPingpong: 'cutlass::gemm::KernelTmaWarpSpecializedPingpong',
    KernelScheduleType.TmaWarpSpecializedCooperative: 'cutlass::gemm::KernelTmaWarpSpecializedCooperative',
    KernelScheduleType.TmaWarpSpecializedFP8FastAccum: 'cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum',
    KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum: 'cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum',
    KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum: 'cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum',
    KernelScheduleType.ImplicitTmaWarpSpecializedSm90: 'cutlass::conv::KernelImplicitTmaWarpSpecializedSm90',
}

# Suffix appended to procedural kernel names for each schedule.
KernelScheduleSuffixes = {
    KernelScheduleType.ScheduleAuto: '',
    KernelScheduleType.Multistage: '_cpasync',
    KernelScheduleType.CpAsyncWarpSpecialized: '_cpasync_warpspecialized',
    KernelScheduleType.CpAsyncWarpSpecializedPingpong: '_cpasync_warpspecialized_pingpong',
    KernelScheduleType.CpAsyncWarpSpecializedCooperative: '_cpasync_warpspecialized_cooperative',
    KernelScheduleType.Tma: '_unspecialized',
    KernelScheduleType.TmaWarpSpecialized: '_warpspecialized',
    KernelScheduleType.TmaWarpSpecializedPingpong: '_warpspecialized_pingpong',
    KernelScheduleType.TmaWarpSpecializedCooperative: '_warpspecialized_cooperative',
    KernelScheduleType.TmaWarpSpecializedFP8FastAccum: '_warpspecialized_fp8_fastaccum',
    KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum: '_warpspecialized_cooperative_fp8_fastaccum',
    KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum: '_warpspecialized_pingpong_fp8_fastaccum',
    KernelScheduleType.ImplicitTmaWarpSpecializedSm90: '_warpspecialized',
}


class EpilogueScheduleType(enum.Enum):
    """Epilogue schedule selection for CUTLASS 3.x collectives."""
    ScheduleAuto = enum_auto()
    EpilogueTransposed = enum_auto()
    NoSmemWarpSpecialized = enum_auto()
    TmaWarpSpecialized = enum_auto()
    TmaWarpSpecializedCooperative = enum_auto()


EpilogueScheduleTag = {
    EpilogueScheduleType.ScheduleAuto: 'cutlass::epilogue::collective::EpilogueScheduleAuto',
    EpilogueScheduleType.EpilogueTransposed: 'cutlass::gemm::EpilogueTransposed',
    EpilogueScheduleType.NoSmemWarpSpecialized: 'cutlass::epilogue::NoSmemWarpSpecialized',
    EpilogueScheduleType.TmaWarpSpecialized: 'cutlass::epilogue::TmaWarpSpecialized',
    EpilogueScheduleType.TmaWarpSpecializedCooperative: 'cutlass::epilogue::TmaWarpSpecializedCooperative',
}

EpilogueScheduleSuffixes = {
    EpilogueScheduleType.ScheduleAuto: '',
    EpilogueScheduleType.EpilogueTransposed: '',
    EpilogueScheduleType.NoSmemWarpSpecialized: '_epi_nosmem',
    EpilogueScheduleType.TmaWarpSpecialized: '_epi_tma',
    EpilogueScheduleType.TmaWarpSpecializedCooperative: '_epi_tma',
}


class EpilogueFunctor3x(enum.Enum):
    """Epilogue fusion functor for CUTLASS 3.x kernels."""
    LinearCombination = enum_auto()


EpilogueFunctor3xTag = {
    EpilogueFunctor3x.LinearCombination: 'cutlass::epilogue::fusion::LinearCombination',
}


class TileSchedulerType(enum.Enum):
    """CTA tile scheduler selection."""
    Default = enum_auto()
    Persistent = enum_auto()
    StreamK = enum_auto()


TileSchedulerTag = {
    TileSchedulerType.Default: 'void',
    TileSchedulerType.Persistent: 'cutlass::gemm::PersistentScheduler',
    TileSchedulerType.StreamK: 'cutlass::gemm::StreamKScheduler',
}

TileSchedulerSuffixes = {
    TileSchedulerType.Default: '',
    TileSchedulerType.Persistent: '',
    TileSchedulerType.StreamK: '_stream_k',
}

###################################################################################################


class SideMode(enum.Enum):
    """Side of the multiplication a triangular/symmetric operand sits on."""
    Left = enum_auto()
    Right = enum_auto()


SideModeTag = {
    SideMode.Left: 'cutlass::SideMode::kLeft',
    SideMode.Right: 'cutlass::SideMode::kRight'
}

ShortSideModeNames = {
    SideMode.Left: 'ls',
    SideMode.Right: 'rs'
}

###################################################################################################


class FillMode(enum.Enum):
    """Which triangle of a symmetric/triangular matrix is stored."""
    Lower = enum_auto()
    Upper = enum_auto()


FillModeTag = {
    FillMode.Lower: 'cutlass::FillMode::kLower',
    FillMode.Upper: 'cutlass::FillMode::kUpper'
}

ShortFillModeNames = {
    FillMode.Lower: 'l',
    FillMode.Upper: 'u'
}

###################################################################################################


class DiagType(enum.Enum):
    """Whether a triangular matrix has an implicit unit diagonal."""
    NonUnit = enum_auto()
    Unit = enum_auto()


DiagTypeTag = {
    DiagType.NonUnit: 'cutlass::DiagType::kNonUnit',
    DiagType.Unit: 'cutlass::DiagType::kUnit'
}

ShortDiagTypeNames = {
    DiagType.NonUnit: 'nu',
    DiagType.Unit: 'un'
}

###################################################################################################


class OpcodeClass(enum.Enum):
    """Instruction class used by the mainloop."""
    Simt = enum_auto()
    TensorOp = enum_auto()
    WmmaTensorOp = enum_auto()
    SparseTensorOp = enum_auto()


# NOTE(review): SparseTensorOp has no entry in either map below — looks like an
# upstream omission; callers using it would raise KeyError. Left as-is.
OpcodeClassNames = {
    OpcodeClass.Simt: 'simt',
    OpcodeClass.TensorOp: 'tensorop',
    OpcodeClass.WmmaTensorOp: 'wmma_tensorop',
}

OpcodeClassTag = {
    OpcodeClass.Simt: 'cutlass::arch::OpClassSimt',
    OpcodeClass.TensorOp: 'cutlass::arch::OpClassTensorOp',
    OpcodeClass.WmmaTensorOp: 'cutlass::arch::OpClassWmmaTensorOp',
}

###################################################################################################


class OperationKind(enum.Enum):
    """Top-level operation families the generator can emit."""
    Gemm = enum_auto()
    RankK = enum_auto()
    Rank2K = enum_auto()
    Trmm = enum_auto()
    Symm = enum_auto()
    Conv2d = enum_auto()
    Conv3d = enum_auto()


OperationKindNames = {
    OperationKind.Gemm: 'gemm',
    OperationKind.RankK: 'rank_k',
    OperationKind.Rank2K: 'rank_2k',
    OperationKind.Trmm: 'trmm',
    OperationKind.Symm: 'symm',
    OperationKind.Conv2d: 'conv2d',
    OperationKind.Conv3d: 'conv3d'
}


class Target(enum.Enum):
    library = enum_auto()


# Compute capability -> marketing architecture name.
ArchitectureNames = {
    50: 'maxwell',
    60: 'pascal',
    61: 'pascal',
    70: 'volta',
    75: 'turing',
    80: 'ampere',
    89: 'ada',
    90: 'hopper'
}

# Compute capability -> usable shared memory per CTA, in KB.
SharedMemPerCC = {
    70: 96,    # 96KB of SMEM
    72: 96,    # 96KB of SMEM
    75: 64,    # 64KB of SMEM
    80: 163,   # 163KB of SMEM - 1KB reserved for the driver
    86: 99,    # 99KB of SMEM - 1KB reserved for the driver
    87: 163,   # 163KB of SMEM - 1KB reserved for the driver
    89: 99,    # 99KB of SMEM - 1KB reserved for the driver
    90: 227,   # 227KB of SMEM - 1KB reserved for the driver
}

###################################################################################################


def SubstituteTemplate(template, values):
    """Replace every ``${key}`` in *template* with ``values[key]``.

    Substitution repeats until a fixed point is reached, so values may
    themselves contain further ``${...}`` placeholders.
    """
    text = template
    changed = True
    while changed:
        changed = False
        for key, value in values.items():
            regex = "\\$\\{%s\\}" % key
            newtext = re.sub(regex, value, text)
            if newtext != text:
                changed = True
            text = newtext
    return text

###################################################################################################


class GemmKind(enum.Enum):
    """GEMM kernel families."""
    Gemm = enum_auto()
    Sparse = enum_auto()
    Universal = enum_auto()
    Universal3x = enum_auto()
    SparseUniversal3x = enum_auto()
    PlanarComplex = enum_auto()
    PlanarComplexArray = enum_auto()
    Grouped = enum_auto()


GemmKindNames = {
    GemmKind.Gemm: "gemm",
    GemmKind.Sparse: "spgemm",
    GemmKind.Universal: "gemm",
    GemmKind.Universal3x: "gemm",
    GemmKind.SparseUniversal3x: "spgemm",
    GemmKind.PlanarComplex: "gemm_planar_complex",
    GemmKind.PlanarComplexArray: "gemm_planar_complex_array",
    GemmKind.Grouped: "gemm_grouped",
}


class RankKKind(enum.Enum):
    Universal = enum_auto()


RankKKindNames = {
    RankKKind.Universal: "rank_k"
}


class TrmmKind(enum.Enum):
    Universal = enum_auto()


TrmmKindNames = {
    TrmmKind.Universal: "trmm"
}


class SymmKind(enum.Enum):
    Universal = enum_auto()


SymmKindNames = {
    SymmKind.Universal: "symm"
}


class EpilogueFunctor(enum.Enum):
    """Epilogue functor for CUTLASS 2.x kernels."""
    LinearCombination = enum_auto()
    LinearCombinationClamp = enum_auto()


EpilogueFunctorTag = {
    EpilogueFunctor.LinearCombination: 'cutlass::epilogue::thread::LinearCombination',
    EpilogueFunctor.LinearCombinationClamp: 'cutlass::epilogue::thread::LinearCombinationClamp',
}


class SwizzlingFunctor(enum.Enum):
    """Threadblock swizzling strategy."""
    Identity1 = enum_auto()
    Identity2 = enum_auto()
    Identity4 = enum_auto()
    Identity8 = enum_auto()
    Horizontal = enum_auto()
    StridedDgradIdentity1 = enum_auto()
    StridedDgradIdentity4 = enum_auto()
    StridedDgradHorizontal = enum_auto()
    StreamK = enum_auto()


SwizzlingFunctorTag = {
    SwizzlingFunctor.Identity1: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>',
    SwizzlingFunctor.Identity2: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<2>',
    SwizzlingFunctor.Identity4: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>',
    SwizzlingFunctor.Identity8: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>',
    SwizzlingFunctor.Horizontal: 'cutlass::gemm::threadblock::GemmHorizontalThreadblockSwizzle',
    SwizzlingFunctor.StridedDgradIdentity1: 'cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<1>',
    SwizzlingFunctor.StridedDgradIdentity4: 'cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<4>',
    SwizzlingFunctor.StridedDgradHorizontal: 'cutlass::conv::threadblock::StridedDgradHorizontalThreadblockSwizzle',
    SwizzlingFunctor.StreamK: 'cutlass::gemm::threadblock::ThreadblockSwizzleStreamK',
}


class GroupScheduleMode(enum.Enum):
    """Where grouped-GEMM problem scheduling is computed."""
    # FIX: the original read `Device = enum_auto(),` — the stray trailing comma
    # made the member's value a 1-tuple instead of an int.
    Device = enum_auto()
    Host = enum_auto()


GroupScheduleModeTag = {
    GroupScheduleMode.Device: 'cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly',
    GroupScheduleMode.Host: 'cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute'
}

ShortGroupScheduleModeNames = {
    GroupScheduleMode.Device: 'Device',
    GroupScheduleMode.Host: 'Host'
}

###################################################################################################


class ConvKind(enum.IntEnum):
    """Convolution operator: forward, data-grad, or weight-grad."""
    Fprop = 0
    Dgrad = 1
    Wgrad = 2


ConvKindTag = {
    ConvKind.Fprop: 'cutlass::conv::Operator::kFprop',
    ConvKind.Dgrad: 'cutlass::conv::Operator::kDgrad',
    ConvKind.Wgrad: 'cutlass::conv::Operator::kWgrad'
}

ConvKindNames = {
    ConvKind.Fprop: 'fprop',
    ConvKind.Dgrad: 'dgrad',
    ConvKind.Wgrad: 'wgrad',
}


class ConvMode(enum.IntEnum):
    CrossCorrelation = 0
    Convolution = 1


class IteratorAlgorithm(enum.Enum):
    """Conv iterator algorithm variants."""
    Analytic = 0
    Optimized = 1
    FixedChannels = 2
    FewChannels = 3
    FixedStrideDilation = 4


IteratorAlgorithmTag = {
    IteratorAlgorithm.Analytic: 'cutlass::conv::IteratorAlgorithm::kAnalytic',
    IteratorAlgorithm.Optimized: 'cutlass::conv::IteratorAlgorithm::kOptimized',
    IteratorAlgorithm.FixedChannels: 'cutlass::conv::IteratorAlgorithm::kFixedChannels',
    IteratorAlgorithm.FewChannels: 'cutlass::conv::IteratorAlgorithm::kFewChannels',
    IteratorAlgorithm.FixedStrideDilation: 'cutlass::conv::IteratorAlgorithm::kFixedStrideDilation'
}

IteratorAlgorithmNames = {
    IteratorAlgorithm.Analytic: 'analytic',
    IteratorAlgorithm.Optimized: 'optimized',
    IteratorAlgorithm.FixedChannels: 'fixed_channels',
    IteratorAlgorithm.FewChannels: 'few_channels',
    IteratorAlgorithm.FixedStrideDilation: 'fixed_stride_dilation'
}


class StrideSupport(enum.Enum):
    """Stride restrictions a conv kernel supports."""
    Strided = 0
    Unity = 1
    Fixed = 2


StrideSupportTag = {
    StrideSupport.Strided: 'cutlass::conv::StrideSupport::kStrided',
    StrideSupport.Unity: 'cutlass::conv::StrideSupport::kUnity',
    StrideSupport.Fixed: 'cutlass::conv::StrideSupport::kFixed'
}

StrideSupportNames = {
    StrideSupport.Strided: '',
    StrideSupport.Unity: 'unity_stride',
    StrideSupport.Fixed: 'fixed_stride'
}


class GroupMode(enum.Enum):
    """Grouped-convolution mode."""
    NoneGroup = enum_auto()       # dense conv (G=1)
    SingleGroup = enum_auto()     # grouped convolution (single group per CTA)
    MultipleGroup = enum_auto()   # grouped convolution ( multiple groups per CTA)
    Depthwise = enum_auto()       # Depthwise convolution ( C=K=G )


GroupModeTag = {
    GroupMode.NoneGroup: 'cutlass::conv::GroupMode::kNone',
    GroupMode.SingleGroup: 'cutlass::conv::GroupMode::kSingleGroup',
    GroupMode.MultipleGroup: 'cutlass::conv::GroupMode::kMultipleGroup',
    GroupMode.Depthwise: 'cutlass::conv::GroupMode::kDepthwise',
}

GroupModeNames = {
    GroupMode.NoneGroup: '',
    GroupMode.SingleGroup: 'single_group',
    GroupMode.MultipleGroup: 'multiple_group',
    GroupMode.Depthwise: 'depthwise',
}
################################################################################################### # class MathInstruction: def __init__(self, instruction_shape, \ element_a, element_b, element_accumulator, \ opcode_class, math_operation = MathOperation.multiply_add \ ): self.instruction_shape = instruction_shape self.element_a = element_a self.element_b = element_b self.element_accumulator = element_accumulator self.opcode_class = opcode_class self.math_operation = math_operation # class TileDescription: def __init__(self, threadblock_shape, stages, warp_count, math_instruction, min_compute, max_compute, cluster_shape = [1,1,1]): self.threadblock_shape = threadblock_shape self.tile_shape = threadblock_shape self.stages = stages self.warp_count = warp_count self.math_instruction = math_instruction self.minimum_compute_capability = min_compute self.maximum_compute_capability = max_compute self.cluster_shape = cluster_shape def procedural_name(self): if self.minimum_compute_capability >= 90: return "{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{s}".format( tbm = self.threadblock_shape[0], tbn = self.threadblock_shape[1], tbk = self.threadblock_shape[2], cm = self.cluster_shape[0], cn = self.cluster_shape[1], ck = self.cluster_shape[2], s = self.stages) else: return "%dx%d_%dx%d" % (self.threadblock_shape[0], self.threadblock_shape[1], self.threadblock_shape[2], self.stages) # class Direct2dConvFixedStrideDilationTileDescription: def __init__(self, threadblock_output_shape, filter_shape, stages, stride, dilation, warp_count, math_instruction, min_compute, max_compute): self.threadblock_shape = [threadblock_output_shape[0]*threadblock_output_shape[1]*threadblock_output_shape[2], threadblock_output_shape[3], filter_shape[0]*filter_shape[1]] self.threadblock_output_shape = threadblock_output_shape self.filter_shape = filter_shape self.stages = stages self.warp_count = warp_count self.stride = stride self.dilation = dilation self.math_instruction = math_instruction 
self.minimum_compute_capability = min_compute self.maximum_compute_capability = max_compute def procedural_name(self): str_name = "%dx%dx%d_%dx%dx%dx%d_%d_filter%dx%d" % (self.threadblock_shape[0], self.threadblock_shape[1], self.threadblock_shape[2], self.threadblock_output_shape[0], self.threadblock_output_shape[1], self.threadblock_output_shape[2], self.threadblock_output_shape[3], self.stages, self.filter_shape[0], self.filter_shape[1]) # Fixed Strided and dilation if self.stride != [-1, -1] and self.dilation != [-1, -1]: str_name += "_stride%dx%d_dilation%dx%d" % (self.stride[0], self.stride[1], self.dilation[0], self.dilation[1]) return str_name # class Direct2dConvFixedStrideDilationTileDescription: def __init__(self, threadblock_output_shape, filter_shape, stages, stride, dilation, warp_count, math_instruction, min_compute, max_compute): self.threadblock_shape = [threadblock_output_shape[0]*threadblock_output_shape[1]*threadblock_output_shape[2], threadblock_output_shape[3], filter_shape[0]*filter_shape[1]] self.threadblock_output_shape = threadblock_output_shape self.filter_shape = filter_shape self.stages = stages self.warp_count = warp_count self.stride = stride self.dilation = dilation self.math_instruction = math_instruction self.minimum_compute_capability = min_compute self.maximum_compute_capability = max_compute def procedural_name(self): str_name = "%dx%dx%d_%dx%dx%dx%d_%d_filter%dx%d" % (self.threadblock_shape[0], self.threadblock_shape[1], self.threadblock_shape[2], self.threadblock_output_shape[0], self.threadblock_output_shape[1], self.threadblock_output_shape[2], self.threadblock_output_shape[3], self.stages, self.filter_shape[0], self.filter_shape[1]) # Fixed Strided and dilation if self.stride != [-1, -1] and self.dilation != [-1, -1]: str_name += "_stride%dx%d_dilation%dx%d" % (self.stride[0], self.stride[1], self.dilation[0], self.dilation[1]) return str_name # class TensorDescription: def __init__(self, element, layout, alignment = 1, 
complex_transform = ComplexTransform.none): self.element = element self.layout = layout self.alignment = alignment self.complex_transform = complex_transform # class SymmetricTensorDescription: def __init__(self, element, layout, fill_mode, alignment = 1, complex_transform = ComplexTransform.none, side_mode = SideMode.Left): self.element = element self.layout = layout self.fill_mode = fill_mode self.alignment = alignment self.complex_transform = complex_transform self.side_mode = side_mode # class TriangularTensorDescription: def __init__(self, element, layout, side_mode, fill_mode, diag_type, alignment = 1, complex_transform = ComplexTransform.none): self.element = element self.layout = layout self.side_mode = side_mode self.fill_mode = fill_mode self.diag_type = diag_type self.alignment = alignment self.complex_transform = complex_transform # def CalculateSmemUsage(operation): cta_shape = operation.tile_description.threadblock_shape stages = operation.tile_description.stages if operation.operation_kind == OperationKind.Gemm and operation.gemm_kind == GemmKind.Sparse: # Elements represented by 8 bits of metadata (based on 4:8, 2:4 or 1:2 sparsity) if DataTypeSize[operation.A.element] == 32: elements_per_8b_md = 2 elif DataTypeSize[operation.A.element] == 4: elements_per_8b_md = 8 else: elements_per_8b_md = 4 smem_per_stage = DataTypeSize[operation.A.element] * cta_shape[0] * (cta_shape[2] // 2) // 8 + \ DataTypeSize[operation.B.element] * cta_shape[1] * cta_shape[2] // 8 + \ cta_shape[0] * (cta_shape[2] // 2) // elements_per_8b_md else: # Few BLAS3 operations only have A tensor data_type_size_a = DataTypeSize[operation.A.element] data_type_size_b = DataTypeSize[operation.A.element] if operation.is_mixed_input(): data_type_size_b = DataTypeSize[operation.B.element] smem_per_stage = data_type_size_a * cta_shape[0] * cta_shape[2] // 8 + \ data_type_size_b * cta_shape[1] * cta_shape[2] // 8 smem_usage = smem_per_stage * stages return (smem_usage >> 10) class 
GemmUniversalMode(enum.IntEnum): """ Types corresponding to GemmUniversalMode """ Gemm = 0 GemmSplitKParallel = 1 Batched = 2 Array = 3 class SplitKMode(enum.IntEnum): """ Types corresponding to SplitKMode """ NoneSplitK = 0 Serial = 1 Parallel = 2 class TrtLlm_EpilogueTag(enum.Enum): epilogue_op_default = enum.auto() epilogue_op_bias = enum.auto() epilogue_op_silu = enum.auto() epilogue_op_gelu = enum.auto() class TrtLlm_EpilogueFusion(enum.Enum): epilogue_fusion_none = enum.auto() epilogue_fusion_finalize = enum.auto() EpiTagNames = { TrtLlm_EpilogueTag.epilogue_op_default: "lc", # linear combination TrtLlm_EpilogueTag.epilogue_op_bias: "lc_bias", # linear combination with bias addition TrtLlm_EpilogueTag.epilogue_op_silu: "silu", # silu or swiglu TrtLlm_EpilogueTag.epilogue_op_gelu: "gelu" # gelu or geglu } EpiTag = { TrtLlm_EpilogueTag.epilogue_op_default: "tensorrt_llm::cutlass_extensions::EpilogueOpDefault", TrtLlm_EpilogueTag.epilogue_op_bias: "tensorrt_llm::cutlass_extensions::EpilogueOpBias", TrtLlm_EpilogueTag.epilogue_op_silu: "tensorrt_llm::cutlass_extensions::EpilogueOpDefaultSilu", TrtLlm_EpilogueTag.epilogue_op_gelu: "tensorrt_llm::cutlass_extensions::EpilogueOpDefaultFtGelu" } EpiFusion = { TrtLlm_EpilogueFusion.epilogue_fusion_none: "tensorrt_llm::HopperGroupedGemmInput::EpilogueFusion::NONE", TrtLlm_EpilogueFusion.epilogue_fusion_finalize: "tensorrt_llm::HopperGroupedGemmInput::EpilogueFusion::FINALIZE", } EpiFusionSuffixes = { None: "", TrtLlm_EpilogueFusion.epilogue_fusion_none: "EpilogueFusion_NONE", TrtLlm_EpilogueFusion.epilogue_fusion_finalize: "EpilogueFusion_FINALIZE", } ################################################################################ # Quantization Operation and string utils class TrtLlm_QuantOp(enum.Enum): per_column_scale_only = enum.auto() finegrained_scale_only = enum.auto() finegrained_scale_and_zeros = enum.auto() none = enum.auto() QuantOpNames = { TrtLlm_QuantOp.per_column_scale_only: "cs", 
TrtLlm_QuantOp.finegrained_scale_only: "fgs", TrtLlm_QuantOp.finegrained_scale_and_zeros: "fgsz", TrtLlm_QuantOp.none: "noquant" } QuantOpTag = { TrtLlm_QuantOp.per_column_scale_only: "cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY", TrtLlm_QuantOp.finegrained_scale_only: "cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY", TrtLlm_QuantOp.finegrained_scale_and_zeros: "cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS", TrtLlm_QuantOp.none: "void" } ################################################################################ # The activations, biases, scales and zeros are instantiated using CUDA types, # not CUTLASS types. This map materializes the name of the CUDA type. CudaTypeName = { DataType.e4m3: "__nv_fp8_e4m3", DataType.bf16: "__nv_bfloat16", DataType.f16: "half", DataType.f32: "float" } ################################################################################ # A data structure holding all info to instantiate gemm launchers in TRT LLM. class TrtLlm_GemmLauncher: def __init__(self, gemm_kind, arch, act_type, weight_type, scalezero_type, bias_type, output_type, quant_op, epi_tag, cta_shape, warp_shape, stages, cga_shape, mainloop_schedule, epi_schedule, epi_fusion=None): self.gemm_kind = gemm_kind self.arch = arch self.act_type = act_type self.weight_type = weight_type self.scalezero_type = scalezero_type self.bias_type = bias_type self.output_type = output_type self.quant_op = quant_op self.epi_tag = epi_tag self.cta_shape = cta_shape self.warp_shape = warp_shape self.stages = stages self.cga_shape = cga_shape self.mainloop_schedule = mainloop_schedule self.epi_schedule = epi_schedule self.epi_fusion = epi_fusion def __repr__(self): kernel_prefix = "{}_sm{}_{}_{}_{}_{}_{}_{}_{}_{}x{}x{}_{}x{}x{}_{}".format( GemmKindNames[self.gemm_kind], self.arch, DataTypeNames[self.act_type], DataTypeNames[self.weight_type], DataTypeNames[self.scalezero_type], DataTypeNames[self.bias_type], DataTypeNames[self.output_type], QuantOpNames[self.quant_op], 
EpiTagNames[self.epi_tag], self.cta_shape[0], self.cta_shape[1], self.cta_shape[2], self.warp_shape[0], self.warp_shape[1], self.warp_shape[2], self.stages) hopper_suffix = "_{}x{}x{}{}{}{}".format( self.cga_shape[0], self.cga_shape[1], self.cga_shape[2], KernelScheduleSuffixes[self.mainloop_schedule], EpilogueScheduleSuffixes[self.epi_schedule], EpiFusionSuffixes[self.epi_fusion]) if self.arch == 90: return kernel_prefix + hopper_suffix elif self.arch > 90: raise ValueError(f"SM{self.arch} not supported yet.") return kernel_prefix ################################################################################ def tuple_to_cute_shape(shape): return f"cute::Shape<cute::Int<{shape[0]}>, cute::Int<{shape[1]}>, cute::Int<{shape[2]}>>" def instantiate_operation_sm90(operation): act_tag = CudaTypeName[operation.act_type] scale_zero_tag = CudaTypeName[operation.scalezero_type] bias_tag = CudaTypeName[operation.bias_type] out_tag = CudaTypeName[operation.output_type] quant_op = QuantOpTag[operation.quant_op] epi_tag = EpiTag[operation.epi_tag] cute_cta_shape = tuple_to_cute_shape(operation.cta_shape) cute_cga_shape = tuple_to_cute_shape(operation.cga_shape) kernel_sched = KernelScheduleTag[operation.mainloop_schedule] epi_sched = EpilogueScheduleTag[operation.epi_schedule] if operation.gemm_kind == GemmKind.Gemm: if operation.mainloop_schedule in [ KernelScheduleType.TmaWarpSpecializedCooperative, KernelScheduleType.TmaWarpSpecializedPingpong, KernelScheduleType.TmaWarpSpecialized ] and DataTypeSize[operation.act_type] != DataTypeSize[ operation.weight_type]: # Here, we must append MixedInput depending on the schedule, since we know the types are different. # It is a work around since the CUTLASS library did not have the MixedInput schedules at the time of writing. 
kernel_sched += "MixedInput" weight_tag = DataTypeTag[operation.weight_type] instantiation = f"""template void sm90_generic_mixed_gemm_kernelLauncher<{act_tag}, {weight_tag}, {scale_zero_tag}, {bias_tag}, {out_tag}, {quant_op}, {epi_tag}, {cute_cta_shape}, {cute_cga_shape}, {kernel_sched}, {epi_sched}> (const {act_tag}*, const {weight_tag}*, const {scale_zero_tag}*, const {scale_zero_tag}*, const {bias_tag}*, const float, {out_tag}*, int, int, int, const int, tensorrt_llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*); """ elif operation.gemm_kind == GemmKind.Grouped: # Similar to MixedInput above, we must modify the tags for grouped gemm as CUTLASS library does not have the updated schedules assert operation.mainloop_schedule in [ KernelScheduleType.TmaWarpSpecializedCooperative, KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum ] assert operation.epi_schedule == EpilogueScheduleType.NoSmemWarpSpecialized kernel_sched.replace("::Kernel", "::KernelGrouped") epi_sched += "Grouped" weight_tag = CudaTypeName[operation.weight_type] assert operation.epi_fusion is not None epi_fusion = EpiFusion[operation.epi_fusion] instantiation = f"""template void sm90_generic_moe_gemm_kernelLauncher<{act_tag}, {weight_tag}, {out_tag}, {epi_tag}, {epi_fusion}, {cute_cta_shape}, {cute_cga_shape}, false>(HopperGroupedGemmInput, int, int, cudaStream_t, int*, size_t*);""" return instantiation def instantiate_operation_sm80(operation): act_tag = DataTypeTag[operation.dtype] weight_tag = DataTypeTag[operation.dtype] epi_tag = EpiTag[operation.epi_tag] instantiation = f"""template void sm80_generic_fused_moe_gemm_kernelLauncher<{act_tag}, {weight_tag}, {operation.cta_shape[0]}, {operation.cta_shape[1]}, {operation.cta_shape[2]}, {operation.stage}, {epi_tag}>({act_tag} const* A, {weight_tag} const* B, {act_tag} const* biases, bool bias_is_broadcast, {act_tag}* C, int64_t const* total_tokens_including_expert, int64_t num_rows, int64_t gemm_n, int64_t 
gemm_k, int num_experts, int multi_processor_count, cudaStream_t stream, int* kernel_occupancy);""" return instantiation def instantiate_operation(operation): if operation.arch == 80: return instantiate_operation_sm80(operation) elif operation.arch == 90: return instantiate_operation_sm90(operation) def is_gemm_op_valid(op): tile_m, tile_n, _ = op.cta_shape cga_m, cga_n, _ = op.cga_shape if cga_m == 1 and cga_n == 1: return True if cga_m == 2 and cga_n == 1 and tile_m >= 128: return True if cga_m == 1 and cga_n == 2 and tile_n >= 128: return True if cga_m == 2 and cga_n == 2 and tile_m >= 128 and tile_n >= 128: return True return False def is_grouped_gemm_op_valid(op): if not is_gemm_op_valid(op): return False if op.epi_tag != TrtLlm_EpilogueTag.epilogue_op_default: return False if op.epi_schedule != EpilogueScheduleType.NoSmemWarpSpecialized: return False if op.mainloop_schedule not in [ KernelScheduleType.TmaWarpSpecializedCooperative, KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum ]: return False return True def is_op_valid(op): if op.gemm_kind == GemmKind.Gemm: return is_gemm_op_valid(op) if op.gemm_kind == GemmKind.Grouped: return is_grouped_gemm_op_valid(op) ################################################################################ def generate_sm90_mixed_gemm_operations(): arch = 90 # For legacy reasons, we use unsigned types for the weights. The instanitated template # will remap those back to the signed type. 
# Takes the form (activation_type, weight_type, scalezero_type, bias_type, output_type) supported_dtypes = [ (DataType.e4m3, DataType.u4, DataType.f16, DataType.f16, DataType.f16), (DataType.f16, DataType.u4, DataType.f16, DataType.f16, DataType.f16), (DataType.bf16, DataType.u4, DataType.bf16, DataType.bf16, DataType.bf16), (DataType.f16, DataType.u8, DataType.f16, DataType.f16, DataType.f16), (DataType.bf16, DataType.u8, DataType.bf16, DataType.bf16, DataType.bf16) ] quant_ops = [ TrtLlm_QuantOp.per_column_scale_only, TrtLlm_QuantOp.finegrained_scale_only, TrtLlm_QuantOp.finegrained_scale_and_zeros ] epi_tags = [TrtLlm_EpilogueTag.epilogue_op_bias] M_TILES = [64, 128] N_TILES = [16, 32, 64, 128, 256] cta_shapes_mn = product(M_TILES, N_TILES) warp_shape = [4, 1, 1] stages = 0 # auto cga_shapes = product([1, 2], [1, 2], [1]) partial_args = product(supported_dtypes, quant_ops, epi_tags, cta_shapes_mn, cga_shapes) operations = list() for dtype_combo, quant_op, epi_tag, cta_shape_mn, cga_shape in partial_args: max_k_bits = 128 * 8 cta_shape_k = max_k_bits // DataTypeSize[dtype_combo[0]] cta_shape_mnk = cta_shape_mn + (cta_shape_k, ) use_coop = cta_shape_mn[0] == 128 mainloop_schedule = KernelScheduleType.TmaWarpSpecializedCooperative if use_coop else KernelScheduleType.TmaWarpSpecializedPingpong epi_schedule = EpilogueScheduleType.TmaWarpSpecializedCooperative if use_coop else EpilogueScheduleType.TmaWarpSpecialized fpA_intB_operation = TrtLlm_GemmLauncher(GemmKind.Gemm, arch, *dtype_combo, quant_op, epi_tag, cta_shape_mnk, \ warp_shape, stages, cga_shape, mainloop_schedule, epi_schedule) if is_op_valid(fpA_intB_operation): operations.append(fpA_intB_operation) return operations def generate_sm90_grouped_gemm_operations(): arch = 90 supported_dtypes = [ DataType.f16, DataType.bf16, DataType.f32, DataType.e4m3 ] quant_ops = [TrtLlm_QuantOp.none] epi_tags = [TrtLlm_EpilogueTag.epilogue_op_default] M_TILES = [128] # Currently M tile must be 128 for Grouped GEMM N_TILES = 
[16, 32, 64, 128, 256] cta_shapes_mn = list(product(M_TILES, N_TILES)) + [(256, 128)] warp_shape = [0, 0, 0] # ignored except for naming stages = 0 # auto epi_fusions = [ TrtLlm_EpilogueFusion.epilogue_fusion_none, TrtLlm_EpilogueFusion.epilogue_fusion_finalize ] cga_shapes = product([1, 2], [1, 2], [1]) partial_args = product(supported_dtypes, quant_ops, epi_tags, epi_fusions, cta_shapes_mn, cga_shapes) operations = list() for dtype, quant_op, epi_tag, epi_fusion, cta_shape_mn, cga_shape in partial_args: max_k_bits = 128 * 8 cta_shape_k = max_k_bits // DataTypeSize[dtype] cta_shape_mnk = cta_shape_mn + (cta_shape_k, ) mainloop_schedule = KernelScheduleType.TmaWarpSpecializedCooperative if dtype != DataType.e4m3 else KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum epi_schedule = EpilogueScheduleType.NoSmemWarpSpecialized otypes = [dtype] if dtype == DataType.e4m3: otypes = [DataType.f16, DataType.bf16] for otype in otypes: moe_gemm_operation = TrtLlm_GemmLauncher( GemmKind.Grouped, arch, dtype, dtype, dtype, dtype, otype, quant_op, epi_tag, cta_shape_mnk, warp_shape, stages, cga_shape, mainloop_schedule, epi_schedule, epi_fusion) if is_op_valid(moe_gemm_operation): operations.append(moe_gemm_operation) return operations class GemmSm80LauncherConfig: def __init__(self, gemm_kind, arch, dtype, epi_tag, cta_shape, stage): self.gemm_kind = gemm_kind self.arch = arch self.dtype = dtype self.epi_tag = epi_tag self.cta_shape = cta_shape self.stage = stage def generate_sm80_fused_grouped_gemm_operations(): arch = 80 supported_dtypes = [DataType.f16, DataType.bf16] epi_tags = [ TrtLlm_EpilogueTag.epilogue_op_silu, TrtLlm_EpilogueTag.epilogue_op_gelu ] cta_shapes_mnk = [(16, 128, 64), (16, 256, 64), (32, 128, 64), (64, 128, 64), (128, 128, 64)] stages = [2, 3, 4] partial_args = product(supported_dtypes, epi_tags, cta_shapes_mnk, stages) operations = list() for dtype, epi_tag, cta_shape_mnk, stage in partial_args: item = GemmSm80LauncherConfig(GemmKind.Grouped, 
arch, dtype, epi_tag, cta_shape_mnk, stage) operations.append(item) return operations if __name__ == "__main__": fpA_intB_inl = "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.inl" moe_gemm_inl = "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.inl" sm80_moe_gemm_inl = "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl" inl_map = { (GemmKind.Gemm, 90): [fpA_intB_inl], (GemmKind.Grouped, 90): [moe_gemm_inl], (GemmKind.Grouped, 80): [sm80_moe_gemm_inl] } # The goal here is to group kernels with common instantiations together in order to reduce template instantiation overheads. # Template instantiation dominates the time in a compilation unit, so it is the most important factor to improve. operations = [] # operations.extend(generate_sm90_mixed_gemm_operations()) operations.extend(generate_sm90_grouped_gemm_operations()) operations.extend(generate_sm80_fused_grouped_gemm_operations()) op_groups = dict() for op in operations: dict_key = (op.gemm_kind, op.arch, op.cta_shape[0]) op_group = op_groups.get(dict_key, list()) op_group.append(op) op_groups[dict_key] = op_group for key, ops in op_groups.items(): for op in ops: print(instantiate_operation(op))