#################################################################################################
#
# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

"""
Data types and tags used for emitting CUTLASS C++ kernels
"""

import enum
import re
from itertools import product

# The following block implements enum.auto() for Python 3.5 variants that don't include it,
# such as the default 3.5.2 on Ubuntu 16.04.
#
# https://codereview.stackexchange.com/questions/177309/reimplementing-pythons-enum-auto-for-compatibility

try:
  from enum import auto as enum_auto
except ImportError:
  __cutlass_library_auto_enum = 0
  def enum_auto() -> int:
    global __cutlass_library_auto_enum
    i = __cutlass_library_auto_enum
    __cutlass_library_auto_enum += 1
    return i

###################################################################################################

#
class GeneratorTarget(enum.Enum):
  Library = enum_auto()
#
GeneratorTargetNames = {
  GeneratorTarget.Library: 'library'
}
#

###################################################################################################

#
class DataType(enum.Enum):
  void = enum_auto()  # primarily used to disable C tensor for epilogues
  b1 = enum_auto()
  u4 = enum_auto()
  u8 = enum_auto()
  u16 = enum_auto()
  u32 = enum_auto()
  u64 = enum_auto()
  s4 = enum_auto()
  s8 = enum_auto()
  s16 = enum_auto()
  s32 = enum_auto()
  s64 = enum_auto()
  e4m3 = enum_auto()
  e5m2 = enum_auto()
  f16 = enum_auto()
  bf16 = enum_auto()
  f32 = enum_auto()
  tf32 = enum_auto()
  f64 = enum_auto()
  cf16 = enum_auto()
  cbf16 = enum_auto()
  cf32 = enum_auto()
  ctf32 = enum_auto()
  cf64 = enum_auto()
  cs4 = enum_auto()
  cs8 = enum_auto()
  cs16 = enum_auto()
  cs32 = enum_auto()
  cs64 = enum_auto()
  cu4 = enum_auto()
  cu8 = enum_auto()
  cu16 = enum_auto()
  cu32 = enum_auto()
  cu64 = enum_auto()
  invalid = enum_auto()

#
ShortDataTypeNames = {
  DataType.s32: 'i',
  DataType.e4m3: 'e4m3',
  DataType.e5m2: 'e5m2',
  DataType.f16: 'h',
  DataType.f32: 's',
  DataType.f64: 'd',
  DataType.cf32: 'c',
  DataType.cf64: 'z',
}

#
DataTypeNames = {
  DataType.void: "void",
  DataType.b1: "b1",
  DataType.u4: "u4",
  DataType.u8: "u8",
  DataType.u16: "u16",
  DataType.u32: "u32",
  DataType.u64: "u64",
  DataType.s4: "s4",
  DataType.s8: "s8",
  DataType.s16: "s16",
  DataType.s32: "s32",
  DataType.s64: "s64",
  DataType.e4m3: 'e4m3',
  DataType.e5m2: 'e5m2',
  DataType.f16: "f16",
  DataType.bf16: "bf16",
  DataType.f32: "f32",
  DataType.tf32: "tf32",
  DataType.f64: "f64",
  DataType.cf16: "cf16",
  DataType.cbf16: "cbf16",
  DataType.cf32: "cf32",
  DataType.ctf32: "ctf32",
  DataType.cf64: "cf64",
  DataType.cu4: "cu4",
  DataType.cu8: "cu8",
  DataType.cu16: "cu16",
  DataType.cu32: "cu32",
  DataType.cu64: "cu64",
  DataType.cs4: "cs4",
  DataType.cs8: "cs8",
  DataType.cs16: "cs16",
  DataType.cs32: "cs32",
  DataType.cs64: "cs64",
}

DataTypeTag = {
  DataType.void: "void",
  DataType.b1: "cutlass::uint1b_t",
  DataType.u4: "cutlass::uint4b_t",
  DataType.u8: "uint8_t",
  DataType.u16: "uint16_t",
  DataType.u32: "uint32_t",
  DataType.u64: "uint64_t",
  DataType.s4: "cutlass::int4b_t",
  DataType.s8: "int8_t",
  DataType.s16: "int16_t",
  DataType.s32: "int32_t",
  DataType.s64: "int64_t",
  DataType.e4m3: 'cutlass::float_e4m3_t',
  DataType.e5m2: 'cutlass::float_e5m2_t',
  DataType.f16: "cutlass::half_t",
  DataType.bf16: "cutlass::bfloat16_t",
  DataType.f32: "float",
  DataType.tf32: "cutlass::tfloat32_t",
  DataType.f64: "double",
  DataType.cf16: "cutlass::complex<cutlass::half_t>",
  DataType.cbf16: "cutlass::complex<cutlass::bfloat16_t>",
  DataType.cf32: "cutlass::complex<float>",
  DataType.ctf32: "cutlass::complex<cutlass::tfloat32_t>",
  DataType.cf64: "cutlass::complex<double>",
  DataType.cu4: "cutlass::complex<cutlass::uint4b_t>",
  DataType.cu8: "cutlass::complex<cutlass::uint8_t>",
  DataType.cu16: "cutlass::complex<cutlass::uint16_t>",
  DataType.cu32: "cutlass::complex<cutlass::uint32_t>",
  DataType.cu64: "cutlass::complex<cutlass::uint64_t>",
  DataType.cs4: "cutlass::complex<cutlass::int4b_t>",
  DataType.cs8: "cutlass::complex<cutlass::int8_t>",
  DataType.cs16: "cutlass::complex<cutlass::int16_t>",
  DataType.cs32: "cutlass::complex<cutlass::int32_t>",
  DataType.cs64: "cutlass::complex<cutlass::int64_t>",
}

DataTypeSize = {
  DataType.void: 0,
  DataType.b1: 1,
  DataType.u4: 4,
  DataType.u8: 8,
  DataType.u16: 16,
  DataType.u32: 32,
  DataType.u64: 64,
  DataType.s4: 4,
  DataType.s8: 8,
  DataType.s16: 16,
  DataType.s32: 32,
  DataType.s64: 64,
  DataType.e4m3: 8,
  DataType.e5m2: 8,
  DataType.f16: 16,
  DataType.bf16: 16,
  DataType.f32: 32,
  DataType.tf32: 32,
  DataType.f64: 64,
  DataType.cf16: 32,
  DataType.cbf16: 32,
  DataType.cf32: 64,
  DataType.ctf32: 64,
  DataType.cf64: 128,
  DataType.cu4: 8,
  DataType.cu8: 16,
  DataType.cu16: 32,
  DataType.cu32: 64,
  DataType.cu64: 128,
  DataType.cs4: 8,
  DataType.cs8: 16,
  DataType.cs16: 32,
  DataType.cs32: 64,
  DataType.cs64: 128,
}
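
# Illustrative lookups (not executed): these maps are how the emitters turn an abstract
# DataType into C++ source text, a short name, and a bit width. For example:
#
#   DataTypeTag[DataType.f16]    # -> "cutlass::half_t"
#   DataTypeNames[DataType.f16]  # -> "f16"
#   DataTypeSize[DataType.f16]   # -> 16 (bits)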

###################################################################################################
#
class BlasMode(enum.Enum):
  symmetric = enum_auto()
  hermitian = enum_auto()

#
BlasModeTag = {
  BlasMode.symmetric: 'cutlass::BlasMode::kSymmetric',
  BlasMode.hermitian: 'cutlass::BlasMode::kHermitian',
}

#
class ComplexTransform(enum.Enum):
  none = enum_auto()
  conj = enum_auto()

#
ComplexTransformTag = {
  ComplexTransform.none: 'cutlass::ComplexTransform::kNone',
  ComplexTransform.conj: 'cutlass::ComplexTransform::kConjugate',
}

# Used for cutlass3x complex kernel collective mainloop builder instantiation
ComplexTransformTag3x = {
  ComplexTransform.none: 'cute::identity',
  ComplexTransform.conj: 'cute::conjugate',
}

#
RealComplexBijection = [
  (DataType.f16, DataType.cf16),
  (DataType.f32, DataType.cf32),
  (DataType.f64, DataType.cf64),
]

#
def is_complex(data_type):
  for r, c in RealComplexBijection:
    if data_type == c:
      return True
  return False

#
def get_complex_from_real(real_type):
  for r, c in RealComplexBijection:
    if real_type == r:
      return c
  return DataType.invalid

#
def get_real_from_complex(complex_type):
  for r, c in RealComplexBijection:
    if complex_type == c:
      return r
  return DataType.invalid

#
class ComplexMultiplyOp(enum.Enum):
  multiply_add = enum_auto()
  gaussian = enum_auto()

###################################################################################################

#
class MathOperation(enum.Enum):
  multiply_add = enum_auto()
  multiply_add_saturate = enum_auto()
  multiply_add_mixed_input_upcast = enum_auto()
  xor_popc = enum_auto()
  and_popc = enum_auto()
  multiply_add_fast_bf16 = enum_auto()
  multiply_add_fast_f16 = enum_auto()
  multiply_add_fast_f32 = enum_auto()
  multiply_add_complex_fast_f32 = enum_auto()
  multiply_add_complex = enum_auto()
  multiply_add_complex_gaussian = enum_auto()
  multiply_add_fast_accum = enum_auto()

#
MathOperationTag = {
  MathOperation.multiply_add: 'cutlass::arch::OpMultiplyAdd',
  MathOperation.multiply_add_saturate: 'cutlass::arch::OpMultiplyAddSaturate',
  MathOperation.multiply_add_mixed_input_upcast: 'cutlass::arch::OpMultiplyAddMixedInputUpcast',
  MathOperation.xor_popc: 'cutlass::arch::OpXorPopc',
  MathOperation.and_popc: 'cutlass::arch::OpAndPopc',
  MathOperation.multiply_add_fast_bf16: 'cutlass::arch::OpMultiplyAddFastBF16',
  MathOperation.multiply_add_fast_f16: 'cutlass::arch::OpMultiplyAddFastF16',
  MathOperation.multiply_add_fast_f32: 'cutlass::arch::OpMultiplyAddFastF32',
  MathOperation.multiply_add_complex_fast_f32: 'cutlass::arch::OpMultiplyAddComplexFastF32',
  MathOperation.multiply_add_complex: 'cutlass::arch::OpMultiplyAddComplex',
  MathOperation.multiply_add_complex_gaussian: 'cutlass::arch::OpMultiplyAddGaussianComplex',
  MathOperation.multiply_add_fast_accum: 'cutlass::arch::OpMultiplyAddFastAccum',
}

###################################################################################################

#
class LayoutType(enum.Enum):
  ColumnMajor = enum_auto()
  RowMajor = enum_auto()
  ColumnMajorInterleaved2 = enum_auto()
  RowMajorInterleaved2 = enum_auto()
  ColumnMajorInterleaved32 = enum_auto()
  RowMajorInterleaved32 = enum_auto()
  ColumnMajorInterleaved64 = enum_auto()
  RowMajorInterleaved64 = enum_auto()
  TensorNWC = enum_auto()
  TensorNHWC = enum_auto()
  TensorNDHWC = enum_auto()
  TensorNCHW = enum_auto()
  TensorNGHWC = enum_auto()
  TensorNC32HW32 = enum_auto()
  TensorNC64HW64 = enum_auto()
  TensorC32RSK32 = enum_auto()
  TensorC64RSK64 = enum_auto()
  TensorKCS = enum_auto()
  TensorKCSR = enum_auto()
  TensorKCSRT = enum_auto()

#
LayoutTag = {
  LayoutType.ColumnMajor: 'cutlass::layout::ColumnMajor',
  LayoutType.RowMajor: 'cutlass::layout::RowMajor',
  LayoutType.ColumnMajorInterleaved2: 'cutlass::layout::ColumnMajorInterleaved<2>',
  LayoutType.RowMajorInterleaved2: 'cutlass::layout::RowMajorInterleaved<2>',
  LayoutType.ColumnMajorInterleaved32: 'cutlass::layout::ColumnMajorInterleaved<32>',
  LayoutType.RowMajorInterleaved32: 'cutlass::layout::RowMajorInterleaved<32>',
  LayoutType.ColumnMajorInterleaved64: 'cutlass::layout::ColumnMajorInterleaved<64>',
  LayoutType.RowMajorInterleaved64: 'cutlass::layout::RowMajorInterleaved<64>',
  LayoutType.TensorNWC: 'cutlass::layout::TensorNWC',
  LayoutType.TensorNHWC: 'cutlass::layout::TensorNHWC',
  LayoutType.TensorNDHWC: 'cutlass::layout::TensorNDHWC',
  LayoutType.TensorNCHW: 'cutlass::layout::TensorNCHW',
  LayoutType.TensorNGHWC: 'cutlass::layout::TensorNGHWC',
  LayoutType.TensorNC32HW32: 'cutlass::layout::TensorNCxHWx<32>',
  LayoutType.TensorC32RSK32: 'cutlass::layout::TensorCxRSKx<32>',
  LayoutType.TensorNC64HW64: 'cutlass::layout::TensorNCxHWx<64>',
  LayoutType.TensorC64RSK64: 'cutlass::layout::TensorCxRSKx<64>',
  LayoutType.TensorKCS: 'cutlass::layout::TensorKCS',
  LayoutType.TensorKCSR: 'cutlass::layout::TensorKCSR',
  LayoutType.TensorKCSRT: 'cutlass::layout::TensorKCSRT'
}

#
TransposedLayout = {
  LayoutType.ColumnMajor: LayoutType.RowMajor,
  LayoutType.RowMajor: LayoutType.ColumnMajor,
  LayoutType.ColumnMajorInterleaved2: LayoutType.RowMajorInterleaved2,
  LayoutType.RowMajorInterleaved2: LayoutType.ColumnMajorInterleaved2,
  LayoutType.ColumnMajorInterleaved32: LayoutType.RowMajorInterleaved32,
  LayoutType.RowMajorInterleaved32: LayoutType.ColumnMajorInterleaved32,
  LayoutType.ColumnMajorInterleaved64: LayoutType.RowMajorInterleaved64,
  LayoutType.RowMajorInterleaved64: LayoutType.ColumnMajorInterleaved64,
  LayoutType.TensorNHWC: LayoutType.TensorNHWC
}

#
ShortLayoutTypeNames = {
  LayoutType.ColumnMajor: 'n',
  LayoutType.ColumnMajorInterleaved2: 'n2',
  LayoutType.ColumnMajorInterleaved32: 'n32',
  LayoutType.ColumnMajorInterleaved64: 'n64',
  LayoutType.RowMajor: 't',
  LayoutType.RowMajorInterleaved2: 't2',
  LayoutType.RowMajorInterleaved32: 't32',
  LayoutType.RowMajorInterleaved64: 't64',
  LayoutType.TensorNWC: 'nwc',
  LayoutType.TensorNHWC: 'nhwc',
  LayoutType.TensorNDHWC: 'ndhwc',
  LayoutType.TensorNCHW: 'nchw',
  LayoutType.TensorNGHWC: 'nghwc',
  LayoutType.TensorNC32HW32: 'nc32hw32',
  LayoutType.TensorNC64HW64: 'nc64hw64',
  LayoutType.TensorC32RSK32: 'c32rsk32',
  LayoutType.TensorC64RSK64: 'c64rsk64',
  LayoutType.TensorKCS: 'kcs',
  LayoutType.TensorKCSR: 'kcsr',
  LayoutType.TensorKCSRT: 'kcsrt'
}

#
ShortComplexLayoutNames = {
  (LayoutType.ColumnMajor, ComplexTransform.none): 'n',
  (LayoutType.ColumnMajor, ComplexTransform.conj): 'c',
  (LayoutType.RowMajor, ComplexTransform.none): 't',
  (LayoutType.RowMajor, ComplexTransform.conj): 'h'
}

###################################################################################################
class KernelScheduleType(enum.Enum):
  ScheduleAuto = enum_auto()
  Multistage = enum_auto()
  CpAsyncWarpSpecialized = enum_auto()
  CpAsyncWarpSpecializedPingpong = enum_auto()
  CpAsyncWarpSpecializedCooperative = enum_auto()
  Tma = enum_auto()
  TmaWarpSpecialized = enum_auto()
  TmaWarpSpecializedPingpong = enum_auto()
  TmaWarpSpecializedCooperative = enum_auto()
  TmaWarpSpecializedFP8FastAccum = enum_auto()
  TmaWarpSpecializedCooperativeFP8FastAccum = enum_auto()
  TmaWarpSpecializedPingpongFP8FastAccum = enum_auto()
  ImplicitTmaWarpSpecializedSm90 = enum_auto()
#
KernelScheduleTag = {
  KernelScheduleType.ScheduleAuto: 'cutlass::gemm::collective::KernelScheduleAuto',
  KernelScheduleType.Multistage: 'cutlass::gemm::KernelMultistage',
  KernelScheduleType.CpAsyncWarpSpecialized: 'cutlass::gemm::KernelCpAsyncWarpSpecialized',
  KernelScheduleType.CpAsyncWarpSpecializedPingpong: 'cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong',
  KernelScheduleType.CpAsyncWarpSpecializedCooperative: 'cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative',
  KernelScheduleType.Tma: 'cutlass::gemm::KernelTma',
  KernelScheduleType.TmaWarpSpecialized: 'cutlass::gemm::KernelTmaWarpSpecialized',
  KernelScheduleType.TmaWarpSpecializedPingpong: 'cutlass::gemm::KernelTmaWarpSpecializedPingpong',
  KernelScheduleType.TmaWarpSpecializedCooperative: 'cutlass::gemm::KernelTmaWarpSpecializedCooperative',
  KernelScheduleType.TmaWarpSpecializedFP8FastAccum: 'cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum',
  KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum: 'cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum',
  KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum: 'cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum',
  KernelScheduleType.ImplicitTmaWarpSpecializedSm90: 'cutlass::conv::KernelImplicitTmaWarpSpecializedSm90',
}

#
KernelScheduleSuffixes = {
  KernelScheduleType.ScheduleAuto: '',
  KernelScheduleType.Multistage: '_cpasync',
  KernelScheduleType.CpAsyncWarpSpecialized: '_cpasync_warpspecialized',
  KernelScheduleType.CpAsyncWarpSpecializedPingpong: '_cpasync_warpspecialized_pingpong',
  KernelScheduleType.CpAsyncWarpSpecializedCooperative: '_cpasync_warpspecialized_cooperative',
  KernelScheduleType.Tma: '_unspecialized',
  KernelScheduleType.TmaWarpSpecialized: '_warpspecialized',
  KernelScheduleType.TmaWarpSpecializedPingpong: '_warpspecialized_pingpong',
  KernelScheduleType.TmaWarpSpecializedCooperative: '_warpspecialized_cooperative',
  KernelScheduleType.TmaWarpSpecializedFP8FastAccum: '_warpspecialized_fp8_fastaccum',
  KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum: '_warpspecialized_cooperative_fp8_fastaccum',
  KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum: '_warpspecialized_pingpong_fp8_fastaccum',
  KernelScheduleType.ImplicitTmaWarpSpecializedSm90: '_warpspecialized',
}

class EpilogueScheduleType(enum.Enum):
  ScheduleAuto = enum_auto()
  EpilogueTransposed = enum_auto()
  NoSmemWarpSpecialized = enum_auto()
  TmaWarpSpecialized = enum_auto()
  TmaWarpSpecializedCooperative = enum_auto()
#
EpilogueScheduleTag = {
  EpilogueScheduleType.ScheduleAuto: 'cutlass::epilogue::collective::EpilogueScheduleAuto',
  EpilogueScheduleType.EpilogueTransposed: 'cutlass::gemm::EpilogueTransposed',
  EpilogueScheduleType.NoSmemWarpSpecialized: 'cutlass::epilogue::NoSmemWarpSpecialized',
  EpilogueScheduleType.TmaWarpSpecialized: 'cutlass::epilogue::TmaWarpSpecialized',
  EpilogueScheduleType.TmaWarpSpecializedCooperative: 'cutlass::epilogue::TmaWarpSpecializedCooperative',
}

#
EpilogueScheduleSuffixes = {
  EpilogueScheduleType.ScheduleAuto: '',
  EpilogueScheduleType.EpilogueTransposed: '',
  EpilogueScheduleType.NoSmemWarpSpecialized: '_epi_nosmem',
  EpilogueScheduleType.TmaWarpSpecialized: '_epi_tma',
  EpilogueScheduleType.TmaWarpSpecializedCooperative: '_epi_tma',
}

class EpilogueFunctor3x(enum.Enum):
  LinearCombination = enum_auto()
#
EpilogueFunctor3xTag = {
  EpilogueFunctor3x.LinearCombination: 'cutlass::epilogue::fusion::LinearCombination',
}

class TileSchedulerType(enum.Enum):
  Default = enum_auto()
  Persistent = enum_auto()
  StreamK = enum_auto()
#
TileSchedulerTag = {
  TileSchedulerType.Default: 'void',
  TileSchedulerType.Persistent: 'cutlass::gemm::PersistentScheduler',
  TileSchedulerType.StreamK: 'cutlass::gemm::StreamKScheduler',
}

#
TileSchedulerSuffixes = {
  TileSchedulerType.Default: '',
  TileSchedulerType.Persistent: '',
  TileSchedulerType.StreamK: '_stream_k',
}

###################################################################################################

#
class SideMode(enum.Enum):
  Left = enum_auto()
  Right = enum_auto()

#
SideModeTag = {
  SideMode.Left: 'cutlass::SideMode::kLeft',
  SideMode.Right: 'cutlass::SideMode::kRight'
}

#
ShortSideModeNames = {
  SideMode.Left: 'ls',
  SideMode.Right: 'rs'
}

###################################################################################################

#
class FillMode(enum.Enum):
  Lower = enum_auto()
  Upper = enum_auto()

#
FillModeTag = {
  FillMode.Lower: 'cutlass::FillMode::kLower',
  FillMode.Upper: 'cutlass::FillMode::kUpper'
}

#
ShortFillModeNames = {
  FillMode.Lower: 'l',
  FillMode.Upper: 'u'
}

###################################################################################################

#
class DiagType(enum.Enum):
  NonUnit = enum_auto()
  Unit = enum_auto()

#
DiagTypeTag = {
  DiagType.NonUnit: 'cutlass::DiagType::kNonUnit',
  DiagType.Unit: 'cutlass::DiagType::kUnit'
}

#
ShortDiagTypeNames = {
  DiagType.NonUnit: 'nu',
  DiagType.Unit: 'un'
}

###################################################################################################

#
class OpcodeClass(enum.Enum):
  Simt = enum_auto()
  TensorOp = enum_auto()
  WmmaTensorOp = enum_auto()
  SparseTensorOp = enum_auto()

OpcodeClassNames = {
  OpcodeClass.Simt: 'simt',
  OpcodeClass.TensorOp: 'tensorop',
  OpcodeClass.WmmaTensorOp: 'wmma_tensorop',
}

OpcodeClassTag = {
  OpcodeClass.Simt: 'cutlass::arch::OpClassSimt',
  OpcodeClass.TensorOp: 'cutlass::arch::OpClassTensorOp',
  OpcodeClass.WmmaTensorOp: 'cutlass::arch::OpClassWmmaTensorOp',
}

###################################################################################################

#
class OperationKind(enum.Enum):
  Gemm = enum_auto()
  RankK = enum_auto()
  Rank2K = enum_auto()
  Trmm = enum_auto()
  Symm = enum_auto()
  Conv2d = enum_auto()
  Conv3d = enum_auto()

#
OperationKindNames = {
  OperationKind.Gemm: 'gemm'
  , OperationKind.RankK: 'rank_k'
  , OperationKind.Rank2K: 'rank_2k'
  , OperationKind.Trmm: 'trmm'
  , OperationKind.Symm: 'symm'
  , OperationKind.Conv2d: 'conv2d'
  , OperationKind.Conv3d: 'conv3d'
}

#
class Target(enum.Enum):
  library = enum_auto()
#
ArchitectureNames = {
  50: 'maxwell',
  60: 'pascal',
  61: 'pascal',
  70: 'volta',
  75: 'turing',
  80: 'ampere',
  89: 'ada',
  90: 'hopper'
}

#
SharedMemPerCC = {
  70:  96, #  96KB of SMEM
  72:  96, #  96KB of SMEM
  75:  64, #  64KB of SMEM
  80: 163, # 163KB of SMEM - 1KB reserved for the driver
  86:  99, #  99KB of SMEM - 1KB reserved for the driver
  87: 163, # 163KB of SMEM - 1KB reserved for the driver
  89:  99, #  99KB of SMEM - 1KB reserved for the driver
  90: 227, # 227KB of SMEM - 1KB reserved for the driver
}

###################################################################################################

#
def SubstituteTemplate(template, values):
  text = template
  changed = True
  while changed:
    changed = False
    for key, value in values.items():
      regex = "\\$\\{%s\\}" % key
      newtext = re.sub(regex, value, text)
      if newtext != text:
        changed = True
      text = newtext
  return text
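
# Illustrative usage (not executed): placeholders of the form ${key} are substituted
# repeatedly until the text stops changing, so values may themselves contain placeholders.
#
#   SubstituteTemplate("using C = ${epilogue}<${element}>;",
#                      {'epilogue': 'LinearCombination', 'element': 'float'})
#   # -> "using C = LinearCombination<float>;"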

###################################################################################################

#
class GemmKind(enum.Enum):
  Gemm = enum_auto()
  Sparse = enum_auto()
  Universal = enum_auto()
  Universal3x = enum_auto()
  SparseUniversal3x = enum_auto()
  PlanarComplex = enum_auto()
  PlanarComplexArray = enum_auto()
  Grouped = enum_auto()
#
GemmKindNames = {
  GemmKind.Gemm: "gemm",
  GemmKind.Sparse: "spgemm",
  GemmKind.Universal: "gemm",
  GemmKind.Universal3x: "gemm",
  GemmKind.SparseUniversal3x: "spgemm",
  GemmKind.PlanarComplex: "gemm_planar_complex",
  GemmKind.PlanarComplexArray: "gemm_planar_complex_array",
  GemmKind.Grouped: "gemm_grouped",
}

#
class RankKKind(enum.Enum):
  Universal = enum_auto()

#
RankKKindNames = {
  RankKKind.Universal: "rank_k"
}

#
class TrmmKind(enum.Enum):
  Universal = enum_auto()

#
TrmmKindNames = {
  TrmmKind.Universal: "trmm"
}

#
class SymmKind(enum.Enum):
  Universal = enum_auto()

#
SymmKindNames = {
  SymmKind.Universal: "symm"
}

#
class EpilogueFunctor(enum.Enum):
  LinearCombination = enum_auto()
  LinearCombinationClamp = enum_auto()

#
EpilogueFunctorTag = {
  EpilogueFunctor.LinearCombination: 'cutlass::epilogue::thread::LinearCombination',
  EpilogueFunctor.LinearCombinationClamp: 'cutlass::epilogue::thread::LinearCombinationClamp',
}

#
class SwizzlingFunctor(enum.Enum):
  Identity1 = enum_auto()
  Identity2 = enum_auto()
  Identity4 = enum_auto()
  Identity8 = enum_auto()
  Horizontal = enum_auto()
  StridedDgradIdentity1 = enum_auto()
  StridedDgradIdentity4 = enum_auto()
  StridedDgradHorizontal = enum_auto()
  StreamK = enum_auto()

#
SwizzlingFunctorTag = {
  SwizzlingFunctor.Identity1: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>',
  SwizzlingFunctor.Identity2: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<2>',
  SwizzlingFunctor.Identity4: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>',
  SwizzlingFunctor.Identity8: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>',
  SwizzlingFunctor.Horizontal: 'cutlass::gemm::threadblock::GemmHorizontalThreadblockSwizzle',
  SwizzlingFunctor.StridedDgradIdentity1: 'cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<1>',
  SwizzlingFunctor.StridedDgradIdentity4: 'cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<4>',
  SwizzlingFunctor.StridedDgradHorizontal: 'cutlass::conv::threadblock::StridedDgradHorizontalThreadblockSwizzle',
  SwizzlingFunctor.StreamK: 'cutlass::gemm::threadblock::ThreadblockSwizzleStreamK',
}

#
class GroupScheduleMode(enum.Enum):
  Device = enum_auto()
  Host = enum_auto()

#
GroupScheduleModeTag = {
  GroupScheduleMode.Device: 'cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly',
  GroupScheduleMode.Host: 'cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute'
}

#
ShortGroupScheduleModeNames = {
  GroupScheduleMode.Device: 'Device',
  GroupScheduleMode.Host: 'Host'
}

###################################################################################################

#
class ConvKind(enum.IntEnum):
  Fprop = 0
  Dgrad = 1
  Wgrad = 2

#
ConvKindTag = {
  ConvKind.Fprop: 'cutlass::conv::Operator::kFprop',
  ConvKind.Dgrad: 'cutlass::conv::Operator::kDgrad',
  ConvKind.Wgrad: 'cutlass::conv::Operator::kWgrad'
}

ConvKindNames = {
  ConvKind.Fprop: 'fprop',
  ConvKind.Dgrad: 'dgrad',
  ConvKind.Wgrad: 'wgrad',
}

class ConvMode(enum.IntEnum):
  CrossCorrelation = 0
  Convolution = 1

#
class IteratorAlgorithm(enum.Enum):
  Analytic = 0
  Optimized = 1
  FixedChannels = 2
  FewChannels = 3
  FixedStrideDilation = 4

#
IteratorAlgorithmTag = {
  IteratorAlgorithm.Analytic: 'cutlass::conv::IteratorAlgorithm::kAnalytic',
  IteratorAlgorithm.Optimized: 'cutlass::conv::IteratorAlgorithm::kOptimized',
  IteratorAlgorithm.FixedChannels: 'cutlass::conv::IteratorAlgorithm::kFixedChannels',
  IteratorAlgorithm.FewChannels: 'cutlass::conv::IteratorAlgorithm::kFewChannels',
  IteratorAlgorithm.FixedStrideDilation: 'cutlass::conv::IteratorAlgorithm::kFixedStrideDilation'
}

IteratorAlgorithmNames = {
  IteratorAlgorithm.Analytic: 'analytic',
  IteratorAlgorithm.Optimized: 'optimized',
  IteratorAlgorithm.FixedChannels: 'fixed_channels',
  IteratorAlgorithm.FewChannels: 'few_channels',
  IteratorAlgorithm.FixedStrideDilation: 'fixed_stride_dilation'
}

#
class StrideSupport(enum.Enum):
  Strided = 0
  Unity = 1
  Fixed = 2

#
StrideSupportTag = {
  StrideSupport.Strided: 'cutlass::conv::StrideSupport::kStrided',
  StrideSupport.Unity: 'cutlass::conv::StrideSupport::kUnity',
  StrideSupport.Fixed: 'cutlass::conv::StrideSupport::kFixed'
}

StrideSupportNames = {
  StrideSupport.Strided: '',
  StrideSupport.Unity: 'unity_stride',
  StrideSupport.Fixed: 'fixed_stride'
}

#
class GroupMode(enum.Enum):
  NoneGroup = enum_auto()         # dense conv (G=1)
  SingleGroup = enum_auto()       # grouped convolution (single group per CTA)
  MultipleGroup = enum_auto()     # grouped convolution (multiple groups per CTA)
  Depthwise = enum_auto()         # depthwise convolution (C=K=G)

#
GroupModeTag = {
  GroupMode.NoneGroup: 'cutlass::conv::GroupMode::kNone',
  GroupMode.SingleGroup: 'cutlass::conv::GroupMode::kSingleGroup',
  GroupMode.MultipleGroup: 'cutlass::conv::GroupMode::kMultipleGroup',
  GroupMode.Depthwise: 'cutlass::conv::GroupMode::kDepthwise',
}

GroupModeNames = {
  GroupMode.NoneGroup: '',
  GroupMode.SingleGroup: 'single_group',
  GroupMode.MultipleGroup: 'multiple_group',
  GroupMode.Depthwise: 'depthwise',
}

###################################################################################################

#
class MathInstruction:
  def __init__(self,
      instruction_shape,
      element_a, element_b, element_accumulator,
      opcode_class, math_operation = MathOperation.multiply_add
    ):

    self.instruction_shape = instruction_shape
    self.element_a = element_a
    self.element_b = element_b
    self.element_accumulator = element_accumulator
    self.opcode_class = opcode_class
    self.math_operation = math_operation
#
class TileDescription:

  def __init__(self, threadblock_shape, stages, warp_count, math_instruction, min_compute, max_compute, cluster_shape = [1,1,1]):
    self.threadblock_shape = threadblock_shape
    self.tile_shape = threadblock_shape
    self.stages = stages
    self.warp_count = warp_count
    self.math_instruction = math_instruction
    self.minimum_compute_capability = min_compute
    self.maximum_compute_capability = max_compute
    self.cluster_shape = cluster_shape

  def procedural_name(self):
    if self.minimum_compute_capability >= 90:
      return "{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{s}".format(
        tbm = self.threadblock_shape[0],
        tbn = self.threadblock_shape[1],
        tbk = self.threadblock_shape[2],
        cm = self.cluster_shape[0],
        cn = self.cluster_shape[1],
        ck = self.cluster_shape[2],
        s = self.stages)
    else:
      return "%dx%d_%dx%d" % (self.threadblock_shape[0], self.threadblock_shape[1], self.threadblock_shape[2], self.stages)

#
class Direct2dConvFixedStrideDilationTileDescription:
  def __init__(self, threadblock_output_shape, filter_shape, stages, stride, dilation, warp_count, math_instruction, min_compute, max_compute):
    self.threadblock_shape = [threadblock_output_shape[0]*threadblock_output_shape[1]*threadblock_output_shape[2], threadblock_output_shape[3], filter_shape[0]*filter_shape[1]]
    self.threadblock_output_shape = threadblock_output_shape
    self.filter_shape = filter_shape
    self.stages = stages
    self.warp_count = warp_count
    self.stride = stride
    self.dilation = dilation
    self.math_instruction = math_instruction
    self.minimum_compute_capability = min_compute
    self.maximum_compute_capability = max_compute

  def procedural_name(self):
    str_name = "%dx%dx%d_%dx%dx%dx%d_%d_filter%dx%d" % (self.threadblock_shape[0],
                                      self.threadblock_shape[1],
                                      self.threadblock_shape[2],
                                      self.threadblock_output_shape[0],
                                      self.threadblock_output_shape[1],
                                      self.threadblock_output_shape[2],
                                      self.threadblock_output_shape[3],
                                      self.stages,
                                      self.filter_shape[0],
                                      self.filter_shape[1])
    # Fixed stride and dilation
    if self.stride != [-1, -1] and self.dilation != [-1, -1]:
      str_name += "_stride%dx%d_dilation%dx%d" % (self.stride[0],
                                                  self.stride[1],
                                                  self.dilation[0],
                                                  self.dilation[1])
    return str_name


#
class TensorDescription:
  def __init__(self, element, layout, alignment = 1, complex_transform = ComplexTransform.none):
    self.element = element
    self.layout = layout
    self.alignment = alignment
    self.complex_transform = complex_transform

#
class SymmetricTensorDescription:
  def __init__(self, element, layout, fill_mode, alignment = 1, complex_transform = ComplexTransform.none, side_mode = SideMode.Left):
    self.element = element
    self.layout = layout
    self.fill_mode = fill_mode
    self.alignment = alignment
    self.complex_transform = complex_transform
    self.side_mode = side_mode

#
class TriangularTensorDescription:
  def __init__(self, element, layout, side_mode, fill_mode, diag_type, alignment = 1, complex_transform = ComplexTransform.none):
    self.element = element
    self.layout = layout
    self.side_mode = side_mode
    self.fill_mode = fill_mode
    self.diag_type = diag_type
    self.alignment = alignment
    self.complex_transform = complex_transform

#
def CalculateSmemUsage(operation):
  cta_shape = operation.tile_description.threadblock_shape
  stages = operation.tile_description.stages

  if operation.operation_kind == OperationKind.Gemm and operation.gemm_kind == GemmKind.Sparse:
    # Elements represented by 8 bits of metadata (based on 4:8, 2:4 or 1:2 sparsity)
    if DataTypeSize[operation.A.element] == 32:
      elements_per_8b_md = 2
    elif DataTypeSize[operation.A.element] == 4:
      elements_per_8b_md = 8
    else:
      elements_per_8b_md = 4

    smem_per_stage = DataTypeSize[operation.A.element] * cta_shape[0] * (cta_shape[2] // 2) // 8 + \
                     DataTypeSize[operation.B.element] * cta_shape[1] * cta_shape[2] // 8 + \
                     cta_shape[0] * (cta_shape[2] // 2) // elements_per_8b_md
  else:
    # A few BLAS3 operations only have an A tensor; default B's element size to A's,
    # then override it below for mixed-input GEMMs.
    data_type_size_a = DataTypeSize[operation.A.element]
    data_type_size_b = DataTypeSize[operation.A.element]
    if operation.is_mixed_input():
      data_type_size_b = DataTypeSize[operation.B.element]

    smem_per_stage = data_type_size_a * cta_shape[0] * cta_shape[2] // 8 + \
                     data_type_size_b * cta_shape[1] * cta_shape[2] // 8

  smem_usage = smem_per_stage * stages
  return (smem_usage >> 10)
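
# Worked example (not executed), for the dense path above: with f16 A and B operands,
# a 128x128x64 threadblock, and 3 stages:
#   smem_per_stage = 16*128*64/8 + 16*128*64/8 = 32768 bytes
#   smem_usage     = 3 * 32768               = 98304 bytes -> returned as 96 (KiB)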


class GemmUniversalMode(enum.IntEnum):
  """
  Types corresponding to GemmUniversalMode
  """
  Gemm = 0
  GemmSplitKParallel = 1
  Batched = 2
  Array = 3


class SplitKMode(enum.IntEnum):
  """
  Types corresponding to SplitKMode
  """
  NoneSplitK = 0
  Serial = 1
  Parallel = 2


################################################################################
# Epilogue Tag enum and string utils
class TrtLlm_EpilogueTag(enum.Enum):
    epilogue_op_default = enum.auto()
    epilogue_op_bias = enum.auto()
    epilogue_op_silu = enum.auto()
    epilogue_op_gelu = enum.auto()


class TrtLlm_EpilogueFusion(enum.Enum):
    epilogue_fusion_none = enum.auto()
    epilogue_fusion_finalize = enum.auto()


EpiTagNames = {
    TrtLlm_EpilogueTag.epilogue_op_default: "lc",  # linear combination
    TrtLlm_EpilogueTag.epilogue_op_bias:
    "lc_bias",  # linear combination with bias addition
    TrtLlm_EpilogueTag.epilogue_op_silu: "silu",  # silu or swiglu
    TrtLlm_EpilogueTag.epilogue_op_gelu: "gelu"  # gelu or geglu
}

EpiTag = {
    TrtLlm_EpilogueTag.epilogue_op_default:
    "tensorrt_llm::cutlass_extensions::EpilogueOpDefault",
    TrtLlm_EpilogueTag.epilogue_op_bias:
    "tensorrt_llm::cutlass_extensions::EpilogueOpBias",
    TrtLlm_EpilogueTag.epilogue_op_silu:
    "tensorrt_llm::cutlass_extensions::EpilogueOpDefaultSilu",
    TrtLlm_EpilogueTag.epilogue_op_gelu:
    "tensorrt_llm::cutlass_extensions::EpilogueOpDefaultFtGelu"
}

EpiFusion = {
    TrtLlm_EpilogueFusion.epilogue_fusion_none:
    "tensorrt_llm::HopperGroupedGemmInput::EpilogueFusion::NONE",
    TrtLlm_EpilogueFusion.epilogue_fusion_finalize:
    "tensorrt_llm::HopperGroupedGemmInput::EpilogueFusion::FINALIZE",
}

EpiFusionSuffixes = {
    None: "",
    TrtLlm_EpilogueFusion.epilogue_fusion_none: "EpilogueFusion_NONE",
    TrtLlm_EpilogueFusion.epilogue_fusion_finalize: "EpilogueFusion_FINALIZE",
}


################################################################################
# Quantization Operation and string utils
class TrtLlm_QuantOp(enum.Enum):
    per_column_scale_only = enum.auto()
    finegrained_scale_only = enum.auto()
    finegrained_scale_and_zeros = enum.auto()
    none = enum.auto()


QuantOpNames = {
    TrtLlm_QuantOp.per_column_scale_only: "cs",
    TrtLlm_QuantOp.finegrained_scale_only: "fgs",
    TrtLlm_QuantOp.finegrained_scale_and_zeros: "fgsz",
    TrtLlm_QuantOp.none: "noquant"
}

QuantOpTag = {
    TrtLlm_QuantOp.per_column_scale_only:
    "cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY",
    TrtLlm_QuantOp.finegrained_scale_only:
    "cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY",
    TrtLlm_QuantOp.finegrained_scale_and_zeros:
    "cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS",
    TrtLlm_QuantOp.none: "void"
}

################################################################################
# The activations, biases, scales and zeros are instantiated using CUDA types,
# not CUTLASS types. This map materializes the name of the CUDA type.
CudaTypeName = {
    DataType.e4m3: "__nv_fp8_e4m3",
    DataType.bf16: "__nv_bfloat16",
    DataType.f16: "half",
    DataType.f32: "float"
}


################################################################################
# A data structure holding all info to instantiate gemm launchers in TRT LLM.
class TrtLlm_GemmLauncher:

    def __init__(self,
                 gemm_kind,
                 arch,
                 act_type,
                 weight_type,
                 scalezero_type,
                 bias_type,
                 output_type,
                 quant_op,
                 epi_tag,
                 cta_shape,
                 warp_shape,
                 stages,
                 cga_shape,
                 mainloop_schedule,
                 epi_schedule,
                 epi_fusion=None):
        self.gemm_kind = gemm_kind
        self.arch = arch
        self.act_type = act_type
        self.weight_type = weight_type
        self.scalezero_type = scalezero_type
        self.bias_type = bias_type
        self.output_type = output_type
        self.quant_op = quant_op
        self.epi_tag = epi_tag
        self.cta_shape = cta_shape
        self.warp_shape = warp_shape
        self.stages = stages
        self.cga_shape = cga_shape
        self.mainloop_schedule = mainloop_schedule
        self.epi_schedule = epi_schedule
        self.epi_fusion = epi_fusion

    def __repr__(self):
        kernel_prefix = "{}_sm{}_{}_{}_{}_{}_{}_{}_{}_{}x{}x{}_{}x{}x{}_{}".format(
            GemmKindNames[self.gemm_kind], self.arch,
            DataTypeNames[self.act_type], DataTypeNames[self.weight_type],
            DataTypeNames[self.scalezero_type], DataTypeNames[self.bias_type],
            DataTypeNames[self.output_type], QuantOpNames[self.quant_op],
            EpiTagNames[self.epi_tag], self.cta_shape[0], self.cta_shape[1],
            self.cta_shape[2], self.warp_shape[0], self.warp_shape[1],
            self.warp_shape[2], self.stages)

        hopper_suffix = "_{}x{}x{}{}{}{}".format(
            self.cga_shape[0], self.cga_shape[1], self.cga_shape[2],
            KernelScheduleSuffixes[self.mainloop_schedule],
            EpilogueScheduleSuffixes[self.epi_schedule],
            EpiFusionSuffixes[self.epi_fusion])

        if self.arch == 90:
            return kernel_prefix + hopper_suffix
        elif self.arch > 90:
            raise ValueError(f"SM{self.arch} not supported yet.")
        return kernel_prefix
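
# Illustrative (not executed): for a grouped SM90 f16 config, the name renders roughly as
#   "gemm_grouped_sm90_f16_f16_f16_f16_f16_noquant_lc_128x16x64_0x0x0_0"
#   + "_1x1x1_warpspecialized_cooperative_epi_nosmemEpilogueFusion_NONE"
# i.e. kernel_prefix plus hopper_suffix, assembled from the name and suffix maps above.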


################################################################################
def tuple_to_cute_shape(shape):
    return f"cute::Shape<cute::Int<{shape[0]}>, cute::Int<{shape[1]}>, cute::Int<{shape[2]}>>"


def instantiate_operation_sm90(operation):
    act_tag = CudaTypeName[operation.act_type]
    scale_zero_tag = CudaTypeName[operation.scalezero_type]
    bias_tag = CudaTypeName[operation.bias_type]
    out_tag = CudaTypeName[operation.output_type]

    quant_op = QuantOpTag[operation.quant_op]
    epi_tag = EpiTag[operation.epi_tag]

    cute_cta_shape = tuple_to_cute_shape(operation.cta_shape)
    cute_cga_shape = tuple_to_cute_shape(operation.cga_shape)

    kernel_sched = KernelScheduleTag[operation.mainloop_schedule]
    epi_sched = EpilogueScheduleTag[operation.epi_schedule]

    if operation.gemm_kind == GemmKind.Gemm:
        if operation.mainloop_schedule in [
                KernelScheduleType.TmaWarpSpecializedCooperative,
                KernelScheduleType.TmaWarpSpecializedPingpong,
                KernelScheduleType.TmaWarpSpecialized
        ] and DataTypeSize[operation.act_type] != DataTypeSize[
                operation.weight_type]:
            # Here, we must append MixedInput to the schedule tag, since we know the input types differ.
            # It is a workaround because the CUTLASS library did not have the MixedInput schedules at the time of writing.
            kernel_sched += "MixedInput"

        weight_tag = DataTypeTag[operation.weight_type]
        instantiation = f"""template void sm90_generic_mixed_gemm_kernelLauncher<{act_tag}, {weight_tag}, {scale_zero_tag}, {bias_tag}, {out_tag}, {quant_op}, {epi_tag}, {cute_cta_shape}, {cute_cga_shape}, {kernel_sched}, {epi_sched}> (const {act_tag}*, const {weight_tag}*, const {scale_zero_tag}*, const {scale_zero_tag}*, const {bias_tag}*, const float, {out_tag}*, int, int, int, const int, tensorrt_llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*);
"""
    elif operation.gemm_kind == GemmKind.Grouped:
        # Similar to MixedInput above, we must modify the tags for grouped GEMM, as the CUTLASS library does not have the updated schedules.
        assert operation.mainloop_schedule in [
            KernelScheduleType.TmaWarpSpecializedCooperative,
            KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum
        ]
        assert operation.epi_schedule == EpilogueScheduleType.NoSmemWarpSpecialized
        kernel_sched = kernel_sched.replace("::Kernel", "::KernelGrouped")
        epi_sched += "Grouped"

        weight_tag = CudaTypeName[operation.weight_type]
        assert operation.epi_fusion is not None
        epi_fusion = EpiFusion[operation.epi_fusion]

        instantiation = f"""template void sm90_generic_moe_gemm_kernelLauncher<{act_tag}, {weight_tag}, {out_tag}, {epi_tag}, {epi_fusion}, {cute_cta_shape}, {cute_cga_shape}, false>(HopperGroupedGemmInput, int, int, cudaStream_t, int*, size_t*);"""
    else:
        raise ValueError(f"Unsupported gemm_kind for SM90 instantiation: {operation.gemm_kind}")
    return instantiation


def instantiate_operation_sm80(operation):
    act_tag = DataTypeTag[operation.dtype]
    weight_tag = DataTypeTag[operation.dtype]
    epi_tag = EpiTag[operation.epi_tag]

    instantiation = f"""template void sm80_generic_fused_moe_gemm_kernelLauncher<{act_tag}, {weight_tag}, {operation.cta_shape[0]}, {operation.cta_shape[1]}, {operation.cta_shape[2]}, {operation.stage}, {epi_tag}>({act_tag} const* A, {weight_tag} const* B, {act_tag} const* biases, bool bias_is_broadcast, {act_tag}* C, int64_t const* total_tokens_including_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, int multi_processor_count, cudaStream_t stream, int* kernel_occupancy);"""
    return instantiation


def instantiate_operation(operation):
    if operation.arch == 80:
        return instantiate_operation_sm80(operation)
    elif operation.arch == 90:
        return instantiate_operation_sm90(operation)


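# The checks below only allow a cluster (CGA) dimension of 2 when the corresponding
# CTA tile dimension is at least 128; a 1x1 cluster is always accepted.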
def is_gemm_op_valid(op):
    tile_m, tile_n, _ = op.cta_shape
    cga_m, cga_n, _ = op.cga_shape

    if cga_m == 1 and cga_n == 1:
        return True

    if cga_m == 2 and cga_n == 1 and tile_m >= 128:
        return True

    if cga_m == 1 and cga_n == 2 and tile_n >= 128:
        return True

    if cga_m == 2 and cga_n == 2 and tile_m >= 128 and tile_n >= 128:
        return True

    return False


def is_grouped_gemm_op_valid(op):
    if not is_gemm_op_valid(op):
        return False

    if op.epi_tag != TrtLlm_EpilogueTag.epilogue_op_default:
        return False

    if op.epi_schedule != EpilogueScheduleType.NoSmemWarpSpecialized:
        return False

    if op.mainloop_schedule not in [
            KernelScheduleType.TmaWarpSpecializedCooperative,
            KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum
    ]:
        return False

    return True


def is_op_valid(op):
    if op.gemm_kind == GemmKind.Gemm:
        return is_gemm_op_valid(op)
    if op.gemm_kind == GemmKind.Grouped:
        return is_grouped_gemm_op_valid(op)


################################################################################
def generate_sm90_mixed_gemm_operations():
    arch = 90

    # For legacy reasons, we use unsigned types for the weights. The instantiated template
    # will remap those back to the signed type.
    # Takes the form (activation_type, weight_type, scalezero_type, bias_type, output_type)
    supported_dtypes = [
        (DataType.e4m3, DataType.u4, DataType.f16, DataType.f16, DataType.f16),
        (DataType.f16, DataType.u4, DataType.f16, DataType.f16, DataType.f16),
        (DataType.bf16, DataType.u4, DataType.bf16, DataType.bf16,
         DataType.bf16),
        (DataType.f16, DataType.u8, DataType.f16, DataType.f16, DataType.f16),
        (DataType.bf16, DataType.u8, DataType.bf16, DataType.bf16,
         DataType.bf16)
    ]

    quant_ops = [
        TrtLlm_QuantOp.per_column_scale_only,
        TrtLlm_QuantOp.finegrained_scale_only,
        TrtLlm_QuantOp.finegrained_scale_and_zeros
    ]

    epi_tags = [TrtLlm_EpilogueTag.epilogue_op_bias]

    M_TILES = [64, 128]
    N_TILES = [16, 32, 64, 128, 256]
    cta_shapes_mn = product(M_TILES, N_TILES)

    warp_shape = [4, 1, 1]
    stages = 0  # auto

    cga_shapes = product([1, 2], [1, 2], [1])

    partial_args = product(supported_dtypes, quant_ops, epi_tags, cta_shapes_mn,
                           cga_shapes)

    operations = list()
    for dtype_combo, quant_op, epi_tag, cta_shape_mn, cga_shape in partial_args:
        max_k_bits = 128 * 8
        cta_shape_k = max_k_bits // DataTypeSize[dtype_combo[0]]
        cta_shape_mnk = cta_shape_mn + (cta_shape_k, )

        use_coop = cta_shape_mn[0] == 128
        mainloop_schedule = KernelScheduleType.TmaWarpSpecializedCooperative if use_coop else KernelScheduleType.TmaWarpSpecializedPingpong
        epi_schedule = EpilogueScheduleType.TmaWarpSpecializedCooperative if use_coop else EpilogueScheduleType.TmaWarpSpecialized

        fpA_intB_operation = TrtLlm_GemmLauncher(GemmKind.Gemm, arch, *dtype_combo, quant_op, epi_tag, cta_shape_mnk, \
                                                 warp_shape, stages, cga_shape, mainloop_schedule, epi_schedule)

        if is_op_valid(fpA_intB_operation):
            operations.append(fpA_intB_operation)

    return operations


def generate_sm90_grouped_gemm_operations():
    arch = 90
    supported_dtypes = [
        DataType.f16, DataType.bf16, DataType.f32, DataType.e4m3
    ]
    quant_ops = [TrtLlm_QuantOp.none]
    epi_tags = [TrtLlm_EpilogueTag.epilogue_op_default]
    M_TILES = [128]  # Currently M tile must be 128 for Grouped GEMM
    N_TILES = [16, 32, 64, 128, 256]
    cta_shapes_mn = list(product(M_TILES, N_TILES)) + [(256, 128)]

    warp_shape = [0, 0, 0]  # ignored except for naming
    stages = 0  # auto

    epi_fusions = [
        TrtLlm_EpilogueFusion.epilogue_fusion_none,
        TrtLlm_EpilogueFusion.epilogue_fusion_finalize
    ]

    cga_shapes = product([1, 2], [1, 2], [1])

    partial_args = product(supported_dtypes, quant_ops, epi_tags, epi_fusions,
                           cta_shapes_mn, cga_shapes)

    operations = list()
    for dtype, quant_op, epi_tag, epi_fusion, cta_shape_mn, cga_shape in partial_args:
        max_k_bits = 128 * 8
        cta_shape_k = max_k_bits // DataTypeSize[dtype]
        cta_shape_mnk = cta_shape_mn + (cta_shape_k, )

        mainloop_schedule = KernelScheduleType.TmaWarpSpecializedCooperative if dtype != DataType.e4m3 else KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum
        epi_schedule = EpilogueScheduleType.NoSmemWarpSpecialized

        otypes = [dtype]
        if dtype == DataType.e4m3:
            otypes = [DataType.f16, DataType.bf16]

        for otype in otypes:
            moe_gemm_operation = TrtLlm_GemmLauncher(
                GemmKind.Grouped, arch, dtype, dtype, dtype, dtype, otype,
                quant_op, epi_tag, cta_shape_mnk, warp_shape, stages, cga_shape,
                mainloop_schedule, epi_schedule, epi_fusion)

            if is_op_valid(moe_gemm_operation):
                operations.append(moe_gemm_operation)
    return operations



class GemmSm80LauncherConfig:

    def __init__(self, gemm_kind, arch, dtype, epi_tag, cta_shape, stage):
        self.gemm_kind = gemm_kind
        self.arch = arch
        self.dtype = dtype
        self.epi_tag = epi_tag
        self.cta_shape = cta_shape
        self.stage = stage


def generate_sm80_fused_grouped_gemm_operations():
    arch = 80
    supported_dtypes = [DataType.f16, DataType.bf16]
    epi_tags = [
        TrtLlm_EpilogueTag.epilogue_op_silu, TrtLlm_EpilogueTag.epilogue_op_gelu
    ]
    cta_shapes_mnk = [(16, 128, 64), (16, 256, 64), (32, 128, 64),
                      (64, 128, 64), (128, 128, 64)]

    stages = [2, 3, 4]

    partial_args = product(supported_dtypes, epi_tags, cta_shapes_mnk, stages)

    operations = list()
    for dtype, epi_tag, cta_shape_mnk, stage in partial_args:
        item = GemmSm80LauncherConfig(GemmKind.Grouped, arch, dtype, epi_tag,
                                      cta_shape_mnk, stage)
        operations.append(item)
    return operations


if __name__ == "__main__":
    fpA_intB_inl = "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.inl"
    moe_gemm_inl = "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.inl"
    sm80_moe_gemm_inl = "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl"

    inl_map = {
        (GemmKind.Gemm, 90): [fpA_intB_inl],
        (GemmKind.Grouped, 90): [moe_gemm_inl],
        (GemmKind.Grouped, 80): [sm80_moe_gemm_inl]
    }

    # The goal here is to group kernels with common instantiations together in order to reduce template instantiation overheads.
    # Template instantiation dominates the compile time of a translation unit, so it is the most important factor to improve.

    operations = []
    # operations.extend(generate_sm90_mixed_gemm_operations())
    operations.extend(generate_sm90_grouped_gemm_operations())
    operations.extend(generate_sm80_fused_grouped_gemm_operations())

    op_groups = dict()
    for op in operations:
        dict_key = (op.gemm_kind, op.arch, op.cta_shape[0])
        op_group = op_groups.get(dict_key, list())
        op_group.append(op)
        op_groups[dict_key] = op_group

    for key, ops in op_groups.items():
        for op in ops:
            print(instantiate_operation(op))
