in kernels/fmha/gemm.h [276:368]
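// Warp-level GEMM on top of CUTLASS tensor-op MMA: repacks the fmha register
// fragments a/b into CUTLASS per-instruction fragments, issues the warp MMA,
// and writes the accumulator back into acc. M, N, Acc, A, B and elem_type come
// from the enclosing template declaration (not shown in this excerpt).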
inline __device__ void gemm_cl(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N]) {
using Shape = cutlass::gemm::GemmShape<16 * M, 16 * N, 16>;
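// Pick the tensor core instruction shape for the target architecture:
// m16n8k16 on SM80+ (Ampere), m16n8k8 on SM75 (Turing). SM70 is not supported.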
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
#elif defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
#else
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;
// TD [2022-06-02] We don't support Volta (SM70) yet.
assert(0);
#endif
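// Half-precision operands (fp16/bf16 via HalfTypeToCutlassType) with fp32 accumulation;
// A is row-major, B is column-major. DefaultMmaTensorOp instantiates the warp-scoped
// tensor-op MMA for this tile and instruction shape.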
using Element = typename HalfTypeToCutlassType<elem_type>::Type;
using ElementC = float;
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using WarpMma = typename cutlass::gemm::warp::DefaultMmaTensorOp<
Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC,
cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd, 1, true>::Type;
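// Number of MMA instructions along K to cover the warp tile's K = 16
// (1 on SM80 with k16 instructions, 2 on SM75 with k8 instructions).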
constexpr int kIters = Shape::kK / InstructionShape::kK;
// using FragmentA = typename WarpMma::FragmentA;
// using FragmentB = typename WarpMma::FragmentB;
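// Use the per-instruction (ArchMmaOperator) operand fragments rather than the
// warp-level FragmentA/FragmentB; FragmentC is the warp-scoped accumulator fragment.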
using FragmentA = typename WarpMma::ArchMmaOperator::FragmentA;
using FragmentB = typename WarpMma::ArchMmaOperator::FragmentB;
using FragmentC = typename WarpMma::FragmentC;
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y) == 0) {
// printf("FragmentA::kStorageElements = %d\n", FragmentA::kStorageElements);
// printf("Archmma::FragmentA::kStorageElements = %d\n", WarpMma::ArchMmaOperator::FragmentA::kStorageElements);
// printf("FragmentB::kStorageElements = %d\n", FragmentB::kStorageElements);
// printf("Archmma::FragmentB::kStorageElements = %d\n", WarpMma::ArchMmaOperator::FragmentB::kStorageElements);
// printf("FragmentC::kStorageElements = %d\n", FragmentC::kStorageElements);
// printf("Archmma::FragmentC::kStorageElements = %d\n", WarpMma::ArchMmaOperator::FragmentC::kStorageElements);
// }
// static_assert(FragmentA::kStorageElements == M * a[0].NUM_REGS);
// static_assert(FragmentB::kStorageElements == N * b[0].NUM_REGS);
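// Sanity-check that the fmha fragment register counts line up with the CUTLASS
// fragment storage sizes before reinterpreting between them.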
static_assert(FragmentA::kStorageElements * kIters == a[0].NUM_REGS);
static_assert(FragmentB::kStorageElements * kIters * 16 / InstructionShape::kN == b[0].NUM_REGS);
static_assert(FragmentC::kStorageElements == M * N * acc[0][0].NUM_REGS);
// const FragmentA a_cl = reinterpret_cast<const FragmentA (&)>(a);
// const FragmentB b_cl = reinterpret_cast<const FragmentB (&)>(b);
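// Load the current accumulator values into a CUTLASS accumulator fragment.
// Note this copy-initializes c_cl; it is not a reference into acc.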
FragmentC c_cl = reinterpret_cast<FragmentC (&)>(acc);
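// Per-K-step staging fragments for the A and B operands. b_cl uses FragmentA as its
// element type, presumably so that each entry holds the two n8 instruction-level B
// fragments a 16-wide N column needs per K step (twice FragmentB's storage, which is
// exactly FragmentA's kRegs 32-bit registers).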
FragmentA a_cl[kIters][M];
FragmentA b_cl[kIters][N];
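// 32-bit registers per instruction-level A fragment: 4 for k16 (SM80), 2 for k8 (SM75).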
constexpr int kRegs = InstructionShape::kK == 16 ? 4 : 2;
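// Repack the fmha A fragments into one per-instruction fragment per K step.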
#pragma unroll
for (int iter = 0; iter < kIters; iter++) {
#pragma unroll
for (int mi = 0; mi < M; mi++) {
uint32_t *a_ptr = a_cl[iter][mi].raw_data();
#pragma unroll
for (int ki = 0; ki < kRegs; ki++) {
a_ptr[ki] = a[mi].regs_[iter * kRegs + ki];
}
}
}
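// Repack the fmha B fragments the same way; their register order differs (see below).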
#pragma unroll
for (int iter = 0; iter < kIters; iter++) {
#pragma unroll
for (int ni = 0; ni < N; ni++) {
uint32_t *b_ptr = b_cl[iter][ni].raw_data();
#pragma unroll
for (int ki = 0; ki < kRegs; ki++) {
// b_ptr[ki] = b[ni].regs_[iter * kRegs + ki];
// TD [2022-06-02] For some reason the order for frag_b is different: on SM75 (kK == 8)
// the iter and ki indices are swapped relative to frag_a.
b_ptr[ki] = b[ni].regs_[InstructionShape::kK == 16 ? iter * kRegs + ki : ki * kRegs + iter];
}
}
}
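// Construct the warp MMA and issue it once per K step, accumulating into c_cl.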
WarpMma mma_op;
// mma_op(c_cl, a_cl, b_cl, c_cl);
#pragma unroll
for (int iter = 0; iter < kIters; iter++) {
mma_op(c_cl, reinterpret_cast<const typename WarpMma::FragmentA (&)>(a_cl[iter]),
reinterpret_cast<const typename WarpMma::FragmentB (&)>(b_cl[iter]), c_cl);
}
// c_cl is a copy of acc, not a reference into it, so the MMA results have to be
// written back into acc explicitly.
#pragma unroll
for (int mi = 0; mi < M; mi++) {
#pragma unroll
for (int ni = 0; ni < N; ni++) {
#pragma unroll
for (int i = 0; i < 8; i++) {  // 8 fp32 accumulator elements per thread per 16x16 tile
acc[mi][ni].elt(i) = c_cl[mi * N * 8 + ni * 8 + i];
}
}
}
}