candle-flash-attn-v1/kernels/fmha/gemm.h (269 lines of code) (raw):

/****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
* ******************************************************************************/

#pragma once

#include "utils.h"

#include "cutlass/cutlass.h"
#include "cutlass/gemm/warp/default_mma_tensor_op.h"
#include "cutlass/layout/layout.h"
#include "cutlass/arch/mma.h"
#include "cutlass/array.h"
#include "cutlass/numeric_types.h"

namespace fmha {

////////////////////////////////////////////////////////////////////////////////////////////////////

// Compile-time sizing of a fragment: NUM_ELTS_ elements of BITS_PER_ELT_ bits each, packed into
// 32-bit words ("registers"). A positive ALIGNMENT_ forces the alignment; otherwise the alignment
// is the storage size in bytes, capped at 16.
template< typename Data_type_, int NUM_ELTS_, int BITS_PER_ELT_, int ALIGNMENT_ >
struct Fragment_base_ {

    // Element type stored in the fragment.
    using Data_type = Data_type_;
    // Input type; defaults to the element type.
    using Input_type_ = Data_type_;

    // True when elements are at least 8 bits wide (the fragment stores an element array).
    static constexpr bool HAS_ELTS = BITS_PER_ELT_ >= 8;
    // Element count and element width in bits.
    static constexpr int NUM_ELTS = NUM_ELTS_;
    static constexpr int BITS_PER_ELT = BITS_PER_ELT_;

    // One register is 4 bytes, i.e. 32 bits.
    static constexpr int BYTES_PER_REG = 4;
    static constexpr int BITS_PER_REG = 8 * BYTES_PER_REG;

    // Registers needed for all elements: total bits divided by 32 via DivUpConstexpr (utils.h).
    static constexpr int NUM_REGS = DivUpConstexpr(NUM_ELTS * BITS_PER_ELT, BITS_PER_REG);
    // Bytes of register storage (NUM_REGS * 4).
    static constexpr int SIZE_IN_BYTES = BYTES_PER_REG * NUM_REGS;
    // Forced alignment when ALIGNMENT_ > 0, otherwise min(storage bytes, 16) via MinConstexpr.
    static constexpr int ALIGNMENT = (ALIGNMENT_ > 0)
        ? ALIGNMENT_
        : MinConstexpr(NUM_REGS * BYTES_PER_REG, 16);
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// A fragment of NUM_ELTS_ elements stored in NUM_REGS 32-bit words (`regs_`). The element view
// (elt / elt_as) and the register view (reg) alias the same storage via reinterpret_cast.
template<
    // Element type.
    typename Data_type_,
    // Element count.
    int NUM_ELTS_,
    // Alignment in bytes to force, or 0 to let Fragment_base_ choose it.
    int ALIGNMENT_ = 0,
    // Base class providing the compile-time sizes.
    typename Base_ = Fragment_base_<Data_type_, NUM_ELTS_, 8 * sizeof(Data_type_), ALIGNMENT_>
>
struct alignas(static_cast<int>(Base_::ALIGNMENT)) Fragment : public Base_ {

    // Backing storage: one 32-bit word per register.
    uint32_t regs_[Base_::NUM_REGS];

    // Bytes moved by a single load/store of the whole fragment.
    static constexpr int BYTES_PER_LOAD_STORE = Base_::NUM_REGS * sizeof(uint32_t);

    // Read-only access to register `ii`.
    inline __device__ const uint32_t& reg(int ii) const {
        return this->regs_[ii];
    }

    // Mutable access to register `ii`.
    inline __device__ uint32_t& reg(int ii) {
        return this->regs_[ii];
    }

    // Zero every register. Written as an explicit PTX `mov` because, per the original author,
    // that seemed to produce better SASS than a plain assignment.
    inline __device__ void clear() {
        #pragma unroll
        for( int r = 0; r < Base_::NUM_REGS; ++r ) {
            asm volatile("mov.u32 %0, 0; \n" : "=r"(this->reg(r)) : );
        }
    }

    // Read-only access to element `ii`, viewing `regs_` as an array of Data_type_.
    inline __device__ const Data_type_& elt(int ii) const {
        const Data_type_ *elts = reinterpret_cast<const Data_type_*>(&this->regs_[0]);
        return elts[ii];
    }

    // Mutable access to element `ii`, viewing `regs_` as an array of Data_type_.
    inline __device__ Data_type_& elt(int ii) {
        Data_type_ *elts = reinterpret_cast<Data_type_*>(&this->regs_[0]);
        return elts[ii];
    }

    // Read-only access to entry `ii`, viewing `regs_` as an array of Cast_type.
    template< typename Cast_type >
    inline __device__ const Cast_type& elt_as(int ii) const {
        const Cast_type *elts = reinterpret_cast<const Cast_type*>(&this->regs_[0]);
        return elts[ii];
    }

    // Mutable access to entry `ii`, viewing `regs_` as an array of Cast_type.
    template< typename Cast_type >
    inline __device__ Cast_type& elt_as(int ii) {
        Cast_type *elts = reinterpret_cast<Cast_type*>(&this->regs_[0]);
        return elts[ii];
    }

    // Element-wise `+=` of another fragment of the same type, over NUM_ELTS_ elements, using
    // Data_type_'s own arithmetic (integer addition for uint16_t fragments).
    // NOTE(review): upstream TODO (TD 2022-04-09) asks whether this should iterate NUM_REGS instead
    // of NUM_ELTS, and whether the addition should be __half2 arithmetic instead.
    inline __device__ void add(const Fragment &other) {
        #pragma unroll
        for( int e = 0; e < NUM_ELTS_; ++e ) {
            this->elt(e) += other.elt(e);
        }
    }

    // Register-wise product with another fragment via fmha::hmul2 (utils.h; presumably a packed
    // 2 x 16-bit multiply -- confirm there).
    inline __device__ void hmul(const Fragment &other) {
        #pragma unroll
        for( int r = 0; r < Base_::NUM_REGS; ++r ) {
            this->reg(r) = fmha::hmul2(this->reg(r), other.reg(r));
        }
    }

    // In-place register-wise ReLU via fmha::hrelu2<elem_type> (utils.h).
    template <typename elem_type>
    inline __device__ void hrelu_() {
        #pragma unroll
        for( int r = 0; r < Base_::NUM_REGS; ++r ) {
            this->reg(r) = fmha::hrelu2<elem_type>(this->reg(r));
        }
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// MMA operand A: 8 x 16-bit elements (4 registers). `Layout` is not used by the fragment itself.
template< typename Layout >
struct Fragment_a : public Fragment<uint16_t, 8> {
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// MMA operand B: 8 x 16-bit elements (4 registers). `Layout` is not used by the fragment itself.
template< typename Layout >
struct Fragment_b : public Fragment<uint16_t, 8> {
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Float accumulator: 8 f32 values, the C/D operands of the two m16n8k16 MMAs issued by mma()
// (elements 0-3 for the first, 4-7 for the second).
struct Fragment_accumulator : public Fragment<float, 8> {

    // The base class.
    using Base = Fragment<float, 8>;

    // Element-wise sum with any fragment exposing elt(), over the first Base::NUM_ELTS elements.
    template< typename Other_fragment_ >
    inline __device__ void add(const Other_fragment_ &other) {
        constexpr int kElts = Base::NUM_ELTS;
        for( int e = 0; e < kElts; ++e ) {
            this->elt(e) = this->elt(e) + other.elt(e);
        }
    }

    // Scale every element by `other`.
    inline __device__ void mul_(const float other) {
        constexpr int kElts = Base::NUM_ELTS;
        for( int e = 0; e < kElts; ++e ) {
            this->elt(e) *= other;
        }
    }

    // acc += a * b over a 16x16x16 tile, issued as two mma.sync m16n8k16 instructions
    // (f16 inputs, f32 accumulation). Slice n8 = 0 updates elt(0..3) with b.reg(0..1);
    // slice n8 = 1 updates elt(4..7) with b.reg(2..3). Both use all four A registers.
    template< typename Layout_a, typename Layout_b >
    inline __device__ void mma(const Fragment_a<Layout_a> &a, const Fragment_b<Layout_b> &b) {
        #pragma unroll
        for( int n8 = 0; n8 < 2; ++n8 ) {
            asm volatile(
                "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n"
                " {%0, %1, %2, %3}, \n"
                " {%4, %5, %6, %7}, \n"
                " {%8, %9}, \n"
                " {%0, %1, %2, %3}; \n"
                : "+f"(elt(4 * n8 + 0)), "+f"(elt(4 * n8 + 1)),
                  "+f"(elt(4 * n8 + 2)), "+f"(elt(4 * n8 + 3))
                : "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3)),
                  "r"(b.reg(2 * n8 + 0)), "r"(b.reg(2 * n8 + 1)));
        }
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Call clear() on every fragment of an M x N array.
template< typename Frag_type, int M, int N >
inline __device__ void clear(Frag_type (&frag)[M][N]) {
    #pragma unroll
    for( int row = 0; row < M; ++row ) {
        #pragma unroll
        for( int col = 0; col < N; ++col ) {
            frag[row][col].clear();
        }
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Zeroes an M x N accumulator array. The primary template is empty (it has no apply()); only the
// float specialization below provides one.
template< typename Accumulator_type, int WARPS_K >
struct Clear_accumulator {
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Float accumulators: delegates to fmha::clear. WARPS_K and the bool argument are not used here.
template< int WARPS_K >
struct Clear_accumulator<float, WARPS_K> {
    template< typename Acc, int M, int N >
    static inline __device__ void apply(Acc (&acc)[M][N], bool = false) {
        fmha::clear(acc);
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Calls acc[mi][ni].mma(a[mi], b[ni]) for every (mi, ni) pair of the M x N accumulator grid.
template<typename Acc, typename A, typename B, int M, int N>
inline __device__ void gemm(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N]) {
    #pragma unroll
    for( int row = 0; row < M; ++row ) {
        const A &a_row = a[row];
        #pragma unroll
        for( int col = 0; col < N; ++col ) {
            acc[row][col].mma(a_row, b[col]);
        }
    }
}
////////////////////////////////////////////////////////////////////////////////////////////////
/// Statically maps half types => cutlass data types
/////////////////////////////////////////////////////////////////////////////////////////////////

// Identity mapping for any type without a specialization below.
template <typename Type_>
struct HalfTypeToCutlassType {
    using Type = Type_;
};

/// Statically maps __half => cutlass::half_t
template <>
struct HalfTypeToCutlassType<__half> {
    using Type = cutlass::half_t;
};

// __nv_bfloat16 => cutlass::bfloat16_t; only declared when compiling device code for SM80+.
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
template <>
struct HalfTypeToCutlassType<__nv_bfloat16> {
    using Type = cutlass::bfloat16_t;
};
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

// Warp-level GEMM through CUTLASS's DefaultMmaTensorOp: acc is updated with a * b over a warp tile
// of (16 * M) x (16 * N) with K = 16.
//
// Template parameters:
//   elem_type - 16-bit input type, mapped through HalfTypeToCutlassType (__half -> half_t;
//               __nv_bfloat16 -> bfloat16_t when compiling for SM80+; other types unchanged).
//   Acc       - accumulator fragment type (e.g. Fragment_accumulator) with NUM_ELTS/NUM_REGS,
//               elt() and a word-per-register layout (checked by static_asserts below).
//   A, B      - operand fragment types exposing NUM_REGS and regs_ (e.g. Fragment_a/Fragment_b).
//   M, N      - number of 16-wide tiles along M and N.
// Instruction shape: m16n8k16 on SM80+, m16n8k8 on SM75 (two K steps). Volta is not supported and
// hits assert(0).
template<typename elem_type, typename Acc, typename A, typename B, int M, int N>
inline __device__ void gemm_cl(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N]) {
    using Shape = cutlass::gemm::GemmShape<16 * M, 16 * N, 16>;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
#elif defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
    using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
#else
    using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;
    // TD [2022-06-02] We don't support Volta (SM70) yet.
    assert(0);
#endif

    using Element = typename HalfTypeToCutlassType<elem_type>::Type;
    using ElementC = float;
    using LayoutA = cutlass::layout::RowMajor;
    using LayoutB = cutlass::layout::ColumnMajor;

    using WarpMma = typename cutlass::gemm::warp::DefaultMmaTensorOp<
        Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC,
        cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd, 1, true>::Type;

    // Number of instruction K steps needed to cover K = 16.
    constexpr int kIters = Shape::kK / InstructionShape::kK;

    // Per-instruction operand fragments and the warp-level accumulator fragment.
    using FragmentA = typename WarpMma::ArchMmaOperator::FragmentA;
    using FragmentB = typename WarpMma::ArchMmaOperator::FragmentB;
    using FragmentC = typename WarpMma::FragmentC;

    // The fmha fragment sizes must match the CUTLASS fragment sizes. The sizes are read from the
    // template parameters (A::NUM_REGS, ...) rather than through the reference parameters
    // (a[0].NUM_REGS, ...), so these constant expressions never name a function parameter.
    static_assert(FragmentA::kStorageElements * kIters == A::NUM_REGS,
                  "A fragment size does not match the CUTLASS instruction fragment");
    static_assert(FragmentB::kStorageElements * kIters * 16 / InstructionShape::kN == B::NUM_REGS,
                  "B fragment size does not match the CUTLASS instruction fragment");
    static_assert(FragmentC::kStorageElements == M * N * Acc::NUM_REGS,
                  "Accumulator size does not match the CUTLASS warp fragment");
    // The write-back below indexes accumulators by element and relies on acc being a dense array
    // of one 32-bit word per element (this is also what the reinterpret_cast into FragmentC needs).
    static_assert(Acc::NUM_ELTS == Acc::NUM_REGS,
                  "gemm_cl expects one accumulator element per register");
    static_assert(sizeof(Acc) == Acc::NUM_REGS * sizeof(uint32_t),
                  "gemm_cl expects accumulator fragments without padding");
    constexpr int kAccElts = Acc::NUM_ELTS;

    // Seed the CUTLASS accumulator from acc. This is a by-value copy: the MMAs below update c_cl,
    // not acc, which is why the results are written back explicitly at the end.
    FragmentC c_cl = reinterpret_cast<FragmentC (&)>(acc);

    // Repacked operands, one set per K step. b_cl deliberately uses FragmentA as its storage type.
    // Per the static_asserts above, one 16-wide N tile of B spans 16 / InstructionShape::kN
    // instruction B fragments. That is as many registers as one instruction A fragment, assuming A
    // and B have equal NUM_REGS (as Fragment_a and Fragment_b do).
    FragmentA a_cl[kIters][M];
    FragmentA b_cl[kIters][N];

    // Registers consumed per K step: 4 for k16 instructions, 2 for k8.
    constexpr int kRegs = InstructionShape::kK == 16 ? 4 : 2;

    // A: K step `iter` takes the contiguous registers [iter * kRegs, (iter + 1) * kRegs).
    #pragma unroll
    for (int iter = 0; iter < kIters; iter++) {
        #pragma unroll
        for (int mi = 0; mi < M; mi++) {
            uint32_t *a_ptr = a_cl[iter][mi].raw_data();
            #pragma unroll
            for (int ki = 0; ki < kRegs; ki++) {
                a_ptr[ki] = a[mi].regs_[iter * kRegs + ki];
            }
        }
    }

    // B: with k16 instructions the registers are taken in order. With k8 instructions, K step
    // `iter` takes registers iter and kRegs + iter (TD [2022-06-02]: frag_b order differs from
    // frag_a). Fragment_accumulator::mma feeds b.reg(0..1) to one n8 MMA and b.reg(2..3) to the
    // other, so each register pair is one n8 slice. Per the PTX m16n8k16 B layout, the first
    // register of a pair holds the low K half, which is what K step `iter` needs.
    #pragma unroll
    for (int iter = 0; iter < kIters; iter++) {
        #pragma unroll
        for (int ni = 0; ni < N; ni++) {
            uint32_t *b_ptr = b_cl[iter][ni].raw_data();
            #pragma unroll
            for (int ki = 0; ki < kRegs; ki++) {
                b_ptr[ki] = b[ni].regs_[InstructionShape::kK == 16 ? iter * kRegs + ki : ki * kRegs + iter];
            }
        }
    }

    WarpMma mma_op;
    #pragma unroll
    for (int iter = 0; iter < kIters; iter++) {
        mma_op(c_cl,
               reinterpret_cast<const typename WarpMma::FragmentA (&)>(a_cl[iter]),
               reinterpret_cast<const typename WarpMma::FragmentB (&)>(b_cl[iter]),
               c_cl);
    }

    // Write the updated accumulators back into acc (c_cl is a copy, see above), using the same
    // dense [mi][ni][elt] order in which c_cl was initialized from acc.
    #pragma unroll
    for (int mi = 0; mi < M; mi++) {
        #pragma unroll
        for (int ni = 0; ni < N; ni++) {
            #pragma unroll
            for (int i = 0; i < kAccElts; i++) {
                acc[mi][ni].elt(i) = c_cl[(mi * N + ni) * kAccElts + i];
            }
        }
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<
    // The number of rows in the CTA tile.
    int M_,
    // The number of cols in the CTA tile.
    int N_,
    // The number of elements in the the K dimension of the GEMM loop.
    int K_,
    // The number of rows of warps.
    int WARPS_M_,
    // The number of cols of warps.
    int WARPS_N_,
    // The number of warps in the K dimension of the GEMM loop.
    int WARPS_K_>
struct Cta_tile_ {

    static constexpr int M = M_, N = N_, K = K_;
    // The number of warps.
    static constexpr int WARPS_M = WARPS_M_, WARPS_N = WARPS_N_, WARPS_K = WARPS_K_;
    // The number of warps per CTA.
    static constexpr int WARPS_PER_CTA = WARPS_M * WARPS_N * WARPS_K;
    // The number of threads per warp.
    static constexpr int THREADS_PER_WARP = 32;
    // The number of threads per CTA.
    static constexpr int THREADS_PER_CTA = WARPS_PER_CTA * THREADS_PER_WARP;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Cta_tile>
struct Hmma_tile {
    // The number of elements computed with a single warp-MMA.
    static constexpr int M_PER_MMA = 16, N_PER_MMA = 16, K_PER_MMA = 16;

    // The number of elements computed with a single CTA-MMA.
    static constexpr int M_PER_MMA_PER_CTA = M_PER_MMA * Cta_tile::WARPS_M,
        N_PER_MMA_PER_CTA = N_PER_MMA * Cta_tile::WARPS_N,
        K_PER_MMA_PER_CTA = K_PER_MMA * Cta_tile::WARPS_K;

    // The number of MMAs needed to compute the GEMM.
    static constexpr int MMAS_M = DivUpConstexpr(Cta_tile::M, M_PER_MMA_PER_CTA),
        MMAS_N = DivUpConstexpr(Cta_tile::N, N_PER_MMA_PER_CTA),
        MMAS_K = DivUpConstexpr(Cta_tile::K, K_PER_MMA_PER_CTA);

    // // The number of elements computed per warp.
    // static constexpr int M_PER_WARP = MMAS_M * M_PER_MMA,
    //     N_PER_WARP = MMAS_N * N_PER_MMA,
    //     K_PER_WARP = MMAS_K * K_PER_MMA;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

using A_type = uint16_t;
using B_type = uint16_t;
using C_type = uint16_t;
using Accumulator_type = float;
using Epilogue_type = float;

constexpr int BITS_PER_ELEMENT_A = sizeof(A_type) * 8;
constexpr int BITS_PER_ELEMENT_B = sizeof(B_type) * 8;
constexpr int BITS_PER_ELEMENT_C = sizeof(C_type) * 8;

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int M, int N, int K, int WARPS_M, int WARPS_N, int WARPS_K>
using Cta_tile_extd = Cta_tile_<M, N, K, WARPS_M, WARPS_N, WARPS_K>;

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Cta_tile_>
using Cta_tile_with_k_with_padding = Cta_tile_extd<Cta_tile_::M,
                                                   Cta_tile_::N,
                                                   Next_power_of_two<Cta_tile_::K>::VALUE,
                                                   Cta_tile_::WARPS_M,
                                                   Cta_tile_::WARPS_N,
                                                   Cta_tile_::WARPS_K>;

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace fmha