inline __device__ void gemm_cl()

in candle-flash-attn-v1/kernels/fmha/gemm.h [276:368]


// Warp-level multiply-accumulate routed through a CUTLASS tensor-op warp MMA.
//
//   acc : M x N grid of accumulator fragments; read on entry and overwritten with the result.
//   a   : M operand-A fragments (A is described to CUTLASS as row-major).
//   b   : N operand-B fragments (B is described to CUTLASS as column-major).
//
// NOTE(review): elem_type, Acc, A, B, M and N are not declared in this excerpt; presumably they
// are template parameters introduced immediately above this signature -- confirm.
inline __device__ void gemm_cl(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N]) {
    // Warp tile extents: (16 * M) x (16 * N) x 16.
    using TileShape = cutlass::gemm::GemmShape<16 * M, 16 * N, 16>;

    // Tensor-core instruction shape, selected from the architecture being compiled for.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    using MmaShape = cutlass::gemm::GemmShape<16, 8, 16>;
#elif defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
    using MmaShape = cutlass::gemm::GemmShape<16, 8, 8>;
#else
    // TD [2022-06-02] We don't support Volta (SM70) yet.
    using MmaShape = cutlass::gemm::GemmShape<8, 8, 4>;
    assert(0);
#endif

    using Elt     = typename HalfTypeToCutlassType<elem_type>::Type;
    using AccElt  = float;
    using LayoutA = cutlass::layout::RowMajor;
    using LayoutB = cutlass::layout::ColumnMajor;
    using LayoutC = cutlass::layout::RowMajor;

    using WarpOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
        TileShape, MmaShape, Elt, LayoutA, Elt, LayoutB, AccElt, LayoutC,
        cutlass::arch::OpMultiplyAdd, 1, true>::Type;

    // Operand fragments are the per-instruction ones; the accumulator fragment is the warp-wide one.
    using InstrFragA = typename WarpOp::ArchMmaOperator::FragmentA;
    using InstrFragB = typename WarpOp::ArchMmaOperator::FragmentB;
    using AccFrag    = typename WarpOp::FragmentC;

    // Number of instruction-sized K slices needed to cover the tile's K extent.
    constexpr int kKSteps = TileShape::kK / MmaShape::kK;
    constexpr bool kIsK16 = MmaShape::kK == 16;
    // 32-bit words copied per fragment entry for each K slice.
    constexpr int kRegsPerStep = kIsK16 ? 4 : 2;
    // Elements copied back per accumulator fragment.
    constexpr int kAccElts = 8;

    // The caller-provided fragments must hold exactly as much storage as CUTLASS expects.
    static_assert(InstrFragA::kStorageElements * kKSteps == a[0].NUM_REGS);
    static_assert(InstrFragB::kStorageElements * kKSteps * 16 / MmaShape::kN == b[0].NUM_REGS);
    static_assert(AccFrag::kStorageElements == M * N * acc[0][0].NUM_REGS);

    // Declared by value: this copy-initializes c_frag from acc's storage, so the MMA below updates
    // the copy rather than acc. That is why the result is written back explicitly at the end.
    AccFrag c_frag = reinterpret_cast<AccFrag (&)>(acc);

    // Both staging arrays use the A-fragment type; kRegsPerStep words are written into every entry.
    // NOTE(review): for B this presumably works because one entry spans 16 / MmaShape::kN
    // instruction tiles (see the static_assert on b above) -- confirm before changing the type.
    InstrFragA a_frag[kKSteps][M];
    InstrFragA b_frag[kKSteps][N];

    // Repack the caller's register arrays into one staging fragment per K slice.
    #pragma unroll
    for (int step = 0; step < kKSteps; step++) {
        #pragma unroll
        for (int m = 0; m < M; m++) {
            uint32_t *dst = a_frag[step][m].raw_data();
            #pragma unroll
            for (int r = 0; r < kRegsPerStep; r++) {
                dst[r] = a[m].regs_[step * kRegsPerStep + r];
            }
        }
        #pragma unroll
        for (int n = 0; n < N; n++) {
            uint32_t *dst = b_frag[step][n].raw_data();
            #pragma unroll
            for (int r = 0; r < kRegsPerStep; r++) {
                // TD [2022-06-02] For some reason the order for frag_b is different:
                // when K != 16 the register index is transposed relative to A's.
                const int src = kIsK16 ? step * kRegsPerStep + r : r * kRegsPerStep + step;
                dst[r] = b[n].regs_[src];
            }
        }
    }

    // One warp MMA per K slice, with c_frag passed as both destination and source accumulator.
    WarpOp warp_mma;
    #pragma unroll
    for (int step = 0; step < kKSteps; step++) {
        warp_mma(c_frag,
                 reinterpret_cast<const typename WarpOp::FragmentA (&)>(a_frag[step]),
                 reinterpret_cast<const typename WarpOp::FragmentB (&)>(b_frag[step]),
                 c_frag);
    }

    // Write the result from the local copy back into the caller's accumulators.
    #pragma unroll
    for (int m = 0; m < M; m++) {
        #pragma unroll
        for (int n = 0; n < N; n++) {
            #pragma unroll
            for (int i = 0; i < kAccElts; i++) {
                acc[m][n].elt(i) = c_frag[(m * N + n) * kAccElts + i];
            }
        }
    }
}