// candle-flash-attn-v1/kernels/fmha_fprop_kernel_1xN.h
/***************************************************************************************************
* Copyright (c) 2022, Tri Dao.
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include "fmha_kernel.h"
#include "fmha/kernel_traits.h"
#include "fmha/gemm.h"
#include "fmha/utils.h"
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
// Common state for the first batched GEMM (P = Q * K^T): holds the shared-memory tile
// accessors for Q and K plus a two-slot register buffer of Q fragments.
template<typename Kernel_traits>
struct Gemm_Q_K_base {

    // Shared-memory tile types supplied by the kernel traits.
    using Smem_tile_q = typename Kernel_traits::Smem_tile_q;
    using Smem_tile_k = typename Kernel_traits::Smem_tile_k;
    using Smem_tile_o = typename Kernel_traits::Smem_tile_o;

    // Register fragment types exposed by the Q and K shared-memory tiles.
    using Fragment_q = typename Smem_tile_q::Fragment;
    using Fragment_k = typename Smem_tile_k::Fragment;

    // The description of the CTA tile for the 1st batched GEMM.
    using Cta_tile_p = typename Kernel_traits::Cta_tile_p;

    // The MMA tile for the 1st GEMM.
    using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;

    // Bytes reserved for the softmax scratch area: M * WARPS_N entries of two floats each.
    static constexpr int SMEM_BYTES_SOFTMAX = Cta_tile_p::M * Cta_tile_p::WARPS_N * sizeof(float) * 2;

    // Binds the Q and K tile accessors to their shared-memory base pointers for thread `tidx`.
    __device__ inline Gemm_Q_K_base(char * smem_ptr_q, char * smem_ptr_k, const int tidx)
        : smem_q(smem_ptr_q, tidx)
        , smem_k(smem_ptr_k, tidx) {
    }

    // Reads the first (k-step 0) Q fragments into slot 0 of the register buffer.
    __device__ inline void load_q() {
        smem_q.load(frag_q[0], 0);
    }

    // Same operation as load_q(); used after the Q tile has been refreshed.
    __device__ inline void reload_q() {
        load_q();
    }

    // Two-slot register buffer of Q fragments (indexed with `ki & 1` by the derived GEMMs).
    Fragment_q frag_q[2][Mma_tile_p::MMAS_M];
    // Shared-memory tile accessor for Q.
    Smem_tile_q smem_q;
    // Shared-memory tile accessor for K.
    Smem_tile_k smem_k;
};
// Q * K^T GEMM helper, primary template. The K_in_regs == false case has its own
// specialization, so this version is the one that keeps every K fragment in registers:
// load_k() reads all MMAS_K steps once and reload_k() has nothing to do.
template<typename Kernel_traits, bool K_in_regs, typename elem_type_=__half>
struct Gemm_Q_K : public Gemm_Q_K_base<Kernel_traits> {

    using Base = Gemm_Q_K_base<Kernel_traits>;
    using Smem_tile_q = typename Base::Smem_tile_q;
    using Smem_tile_k = typename Base::Smem_tile_k;
    using Smem_tile_o = typename Base::Smem_tile_o;
    using Fragment_k = typename Base::Fragment_k;
    using Mma_tile_p = typename Base::Mma_tile_p;
    using elem_type = elem_type_;

    static constexpr bool SHARE_SMEM_FOR_K_AND_V = Kernel_traits::SHARE_SMEM_FOR_K_AND_V;

    // If V is stored in shared memory, we can't load K using the same shared memory.
    static_assert(Kernel_traits::V_IN_REGS);

    // Byte offsets into the dynamic shared-memory buffer. O starts right after Q, i.e. it
    // overlaps the K / V region; the softmax scratch follows O.
    static constexpr int SMEM_OFFSET_O = Smem_tile_q::BYTES_PER_TILE;
    static constexpr int SMEM_OFFSET_SOFTMAX = SMEM_OFFSET_O + Smem_tile_o::BYTES_PER_TILE;
    static constexpr int SMEM_OFFSET_V = Smem_tile_q::BYTES_PER_TILE + (SHARE_SMEM_FOR_K_AND_V ? 0 : Smem_tile_k::BYTES_PER_TILE);

    // Layout:
    //   Q | K / V
    //     | O | SOFTMAX
    // The total is Q plus the larger of the two overlapping regions.
    static constexpr int SMEM_BYTES = Smem_tile_q::BYTES_PER_TILE
                                    + std::max((SHARE_SMEM_FOR_K_AND_V ? 1 : 2) * Smem_tile_k::BYTES_PER_TILE,
                                               Smem_tile_o::BYTES_PER_TILE + Base::SMEM_BYTES_SOFTMAX);

    // Q lives at the start of `smem_`, K immediately after the Q tile.
    __device__ inline Gemm_Q_K(char * smem_, const int tidx)
        : Base(smem_, smem_ + Smem_tile_q::BYTES_PER_TILE, tidx) {
    }

    // Reads the K fragments of every k-step from shared memory into registers.
    __device__ inline void load_k(){
        #pragma unroll
        for( int step = 0; step < Mma_tile_p::MMAS_K; ++step ) {
            Base::smem_k.load(frag_k[step], step);
        }
    }

    // Accumulates this part of P^T = (Q * K^T)^T into acc_p. While the MMA for k-step `cur`
    // runs, the Q fragments for k-step `cur + 1` are fetched into the other buffer slot.
    template<typename Acc, int M, int N>
    __device__ inline void operator()(Acc (&acc_p)[M][N]){
        constexpr int kLast = Mma_tile_p::MMAS_K - 1;
        #pragma unroll
        for( int cur = 0; cur < kLast; ++cur ) {
            const int next = cur + 1;
            // Trigger the load from shared memory for the next series of Q values.
            Base::smem_q.load(Base::frag_q[next & 1], next);
            // Do the math for the values already in registers.
            fmha::gemm_cl<elem_type>(acc_p, Base::frag_q[cur & 1], frag_k[cur]);
        }
        // Final stage: nothing left to prefetch, only the math for the last k-step.
        fmha::gemm_cl<elem_type>(acc_p, Base::frag_q[kLast & 1], frag_k[kLast]);
    }

    // Nothing to reload: all K fragments stay resident in frag_k.
    __device__ inline void reload_k(){
    }

    // K fragments for every k-step of the 1st GEMM.
    Fragment_k frag_k[Mma_tile_p::MMAS_K][Mma_tile_p::MMAS_N];
};
// Q * K^T GEMM helper, specialization for K_in_regs == false: K fragments are read from
// shared memory one k-step at a time into a two-slot register buffer, alongside Q.
template<typename Kernel_traits, typename elem_type_>
struct Gemm_Q_K<Kernel_traits, false, elem_type_> : public Gemm_Q_K_base<Kernel_traits> {

    using Base = Gemm_Q_K_base<Kernel_traits>;
    using Smem_tile_q = typename Base::Smem_tile_q;
    using Smem_tile_k = typename Base::Smem_tile_k;
    using Smem_tile_v = typename Kernel_traits::Smem_tile_v;
    using Smem_tile_o = typename Base::Smem_tile_o;
    using Fragment_k = typename Base::Fragment_k;
    using Mma_tile_p = typename Base::Mma_tile_p;
    using elem_type = elem_type_;

    static constexpr bool SHARE_SMEM_FOR_K_AND_V = Kernel_traits::SHARE_SMEM_FOR_K_AND_V;
    static constexpr bool V_IN_REGS = Kernel_traits::V_IN_REGS;
    // Sharing the K / V buffer is only allowed when V is kept in registers.
    static_assert(!(SHARE_SMEM_FOR_K_AND_V && !V_IN_REGS));

    // Byte offsets into the dynamic shared-memory buffer; here O and the softmax scratch
    // are placed after the V tile instead of overlapping K / V.
    static constexpr int SMEM_OFFSET_V = Smem_tile_q::BYTES_PER_TILE + (SHARE_SMEM_FOR_K_AND_V ? 0 : Smem_tile_k::BYTES_PER_TILE);
    static_assert(Smem_tile_v::BYTES_PER_TILE == (int) Smem_tile_k::BYTES_PER_TILE);
    static constexpr int SMEM_OFFSET_O = SMEM_OFFSET_V + Smem_tile_v::BYTES_PER_TILE;
    static constexpr int SMEM_OFFSET_SOFTMAX = SMEM_OFFSET_O + Smem_tile_o::BYTES_PER_TILE;

    // If V_IN_REGS and SHARE_SMEM_FOR_K_AND_V:      Q | K/V | O | SOFTMAX
    // If !V_IN_REGS (then !SHARE_SMEM_FOR_K_AND_V): Q | K | V | O | SOFTMAX
    static constexpr int SMEM_BYTES = Smem_tile_q::BYTES_PER_TILE
                                    + (SHARE_SMEM_FOR_K_AND_V ? 1 : 2) * Smem_tile_k::BYTES_PER_TILE
                                    + Smem_tile_o::BYTES_PER_TILE + Base::SMEM_BYTES_SOFTMAX;

    // Q lives at the start of `smem_`, K immediately after the Q tile.
    __device__ inline Gemm_Q_K(char * smem_, const int tidx)
        : Base(smem_, smem_ + Smem_tile_q::BYTES_PER_TILE, tidx) {
    }

    // Reads the K fragments of k-step 0 into slot 0 of the register buffer.
    __device__ inline void load_k(){
        Base::smem_k.load(frag_k[0], 0);
    }

    // Accumulates this part of P^T = (Q * K^T)^T into acc_p. While the MMA for k-step `cur`
    // runs, the Q and K fragments for k-step `cur + 1` are fetched into the other slots.
    template<typename Acc, int M, int N>
    __device__ inline void operator()(Acc (&acc_p)[M][N]){
        constexpr int kLast = Mma_tile_p::MMAS_K - 1;
        #pragma unroll
        for( int cur = 0; cur < kLast; ++cur ) {
            const int next = cur + 1;
            // Trigger the loads from shared memory for the next series of Q and K values.
            Base::smem_q.load(Base::frag_q[next & 1], next);
            Base::smem_k.load(frag_k[next & 1], next);
            // Do the math for the values already in registers.
            fmha::gemm_cl<elem_type>(acc_p, Base::frag_q[cur & 1], frag_k[cur & 1]);
        }
        // Final stage: nothing left to prefetch, only the math for the last k-step.
        fmha::gemm_cl<elem_type>(acc_p, Base::frag_q[kLast & 1], frag_k[kLast & 1]);
    }

    // Slot 0 was overwritten during the GEMM, so fetch the k-step 0 fragments again.
    __device__ inline void reload_k(){
        load_k();
    }

    // Two-slot register buffer of K fragments (indexed with `ki & 1`).
    Fragment_k frag_k[2][Mma_tile_p::MMAS_N];
};
// Returns SMEM_BYTES of the Gemm_Q_K variant selected by Kernel_traits::K_IN_REGS (with the
// default element type).
// NOTE(review): device_1xN_ additionally places its Smem_softmax_sum buffer at offset
// SMEM_BYTES, so the launch presumably has to request more than this value for multi-pass
// runs -- confirm against the launcher.
template<typename Kernel_traits>
constexpr size_t get_dynamic_smem_size(){
    using Gemm1 = Gemm_Q_K<Kernel_traits, Kernel_traits::K_IN_REGS>;
    return Gemm1::SMEM_BYTES;
}
// Processes one column block of the attention matrix (index `loop_step_idx`, Cta_tile_p::N
// columns wide) for the batch/head pair (bidb, bidh), looping over row blocks of
// Cta_tile_p::M rows. Row blocks are distributed over blockIdx.z with stride gridDim.z.
//
// Template flags:
//   Is_dropout     - apply dropout to the probabilities before the 2nd GEMM.
//   Is_causal      - use the causal variant of fmha::Mask and start at row block
//                    loop_step_idx * N / M (adjusted for the gridDim.z striding) instead of 0.
//   Return_softmax - also store the packed probabilities through Gmem_tile_s.
//   Is_first       - first column block: no previous LSE / partial output is read.
//   Is_last        - last column block: the output is always written in its final form.
//
// Parameters:
//   params         - kernel parameters (pointers, strides, scale_bmm1f, dropout settings, ...).
//   bidb, bidh     - batch and head indices.
//   steps          - total number of row blocks (the first `begin` ones are skipped).
//   ph             - random generator state handed to apply_dropout_16bits.
//   loop_step_idx  - index of the column block handled by this call.
//
// Uses the dynamic shared-memory buffer `smem_`, laid out according to the Gemm_Q_K offsets.
// Threads leave early (before any barrier) when binfo.stop_early() holds or when the first
// row block assigned to this blockIdx.z lies past actual_seqlen_q; both conditions are
// uniform across the thread block.
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_softmax, bool Is_first, bool Is_last, typename Params, typename Prng>
inline __device__ void device_1xN_(const Params &params, const int bidb, const int bidh, int steps, Prng &ph, const int loop_step_idx) {

    // Below SM80 only __half is supported: the element type from the traits is asserted
    // to be __half and __half is used directly.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    using elem_type = typename Kernel_traits::elem_type;
#else
    constexpr bool is_fp16_type = std::is_same<typename Kernel_traits::elem_type, __half>::value;
    assert(is_fp16_type);
    using elem_type = __half;
#endif

    // The description of the CTA tile for the 1st batched GEMM.
    using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
    // The description of the CTA tile for the 2nd batched GEMM.
    using Cta_tile_o = typename Kernel_traits::Cta_tile_o;

    // The MMA tile for the 1st GEMM.
    using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
    // The MMA tile for the 2nd GEMM.
    using Mma_tile_o = fmha::Hmma_tile<Cta_tile_o>;

    // The global memory tile to load Q.
    using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;

    // The global memory tile to load K.
    using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;

    // The global memory tile to load V.
    using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;
    // The shared memory tile to swizzle V.
    using Smem_tile_v = typename Kernel_traits::Smem_tile_v;

    // The global memory tile to store O.
    using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o;
    // The global memory tile for the partial (not yet final) output between column blocks.
    using Gmem_tile_o_tmp = fmha::Gmem_tile_o<Cta_tile_o, 4>;
    // The shared memory tile to swizzle O.
    using Smem_tile_o = typename Kernel_traits::Smem_tile_o;

    using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;

    using Gmem_softmax_sum = typename Kernel_traits::Gmem_softmax_sum;

    using Smem_softmax_sum = typename Kernel_traits::Smem_dp_sum;

    using Gemm1 = Gemm_Q_K<Kernel_traits, Kernel_traits::K_IN_REGS, elem_type>;

    using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;

    // Shared memory.
    extern __shared__ char smem_[];

    // The thread index.
    const int tidx = threadIdx.x;

    // How many steps to jump per iteration, which is the same as params.num_splits.
    const int step_stride = gridDim.z;

    const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
    // if( binfo.stop_early() ) return;
    if( binfo.stop_early(loop_step_idx * Cta_tile_p::N) ) return;

    Gemm1 gemm_q_k(smem_, tidx);
    // Allocate the global memory tile loader for Q.
    Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts,
                       params.d, binfo, tidx, true);
    // Allocate the global memory tile loader for O.
    Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts,
                       params.d, binfo, tidx);
    Gmem_tile_o_tmp gmem_o_tmp(params.o_tmp_ptr, params.o_tmp_row_stride_in_elts,
                               params.o_tmp_head_stride_in_elts, params.d, binfo, tidx);
    // Allocate the global memory tile loader for S.
    Gmem_tile_s gmem_s(params, binfo, tidx);
    Gmem_softmax_sum gmem_softmax_lse(params.softmax_lse_ptr, params, tidx);

    // Wind gmem tiles to the correct position.
    static_assert(Cta_tile_p::N % Cta_tile_p::M == 0);
    int begin = Is_causal ? loop_step_idx * Cta_tile_p::N / Cta_tile_p::M : 0;
    // We want begin to be a multiple of gridDim.z
    // This is because the row indices processed by each threadblock must align between the
    // loop steps, otherwise we have a dependency between the blocks.
    // For example, threadblock with blockIdx.z == 1 must process row indices that are
    // k * gridDim.z + 1 for integer k.
    const int begin_mod_z = begin % gridDim.z;
    begin = begin_mod_z <= blockIdx.z ? begin - begin_mod_z : begin + gridDim.z - begin_mod_z;
    // Otherwise we'd be reading out-of-bound memory before the loop
    if ((begin + blockIdx.z) * Cta_tile_p::M >= binfo.actual_seqlen_q) return;
    // Keep the unadjusted step count: it is needed below to position gmem_s.
    const int steps_og = steps;
    steps -= begin;
    gmem_q.move(begin + blockIdx.z);
    gmem_o.move(begin + blockIdx.z);
    gmem_o_tmp.move(begin + blockIdx.z);
    if (Return_softmax) {
        gmem_s.move(begin + blockIdx.z);
    }
    gmem_softmax_lse.move(begin + blockIdx.z);
    // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
    //     printf("begin = %d, steps = %d\n", begin, steps);
    // }

    fmha::Mask<Cta_tile_p, Is_causal> mask(binfo, tidx, loop_step_idx);

    // Allocate the global memory tile loader for K.
    Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts,
                       params.d, binfo, tidx, false);
    // Allocate the global memory tile loader for V.
    Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts,
                       params.d, binfo, tidx, false);
    // The base pointer of smem_v;
    char *smem_v_ = &smem_[Gemm1::SMEM_OFFSET_V];

    // Allocate the shared memory tile loader for V. We use the same as K so be careful!!!
    Smem_tile_v smem_v(smem_v_, tidx);

    // Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
    Smem_tile_o smem_o(&smem_[Gemm1::SMEM_OFFSET_O], tidx);

    // Later column blocks: advance K / V (and S, when returned) to this column block.
    if (!Is_first) {
        gmem_k.move(loop_step_idx);
        gmem_v.move(loop_step_idx);
        if (Return_softmax) { gmem_s.move(loop_step_idx * steps_og); }
    }

    // Trigger the loads for K.
    gmem_k.load();
    // Trigger the loads for Q.
    gmem_q.load();
    // Trigger the loads for V.
    gmem_v.load();

    // Barrier before shared memory is overwritten below; only needed when an earlier call
    // for a previous column block may still be using it.
    if (!Is_first) { __syncthreads(); }

    // LSE values stored for these rows by the previous column block.
    float p_prev_lse[Mma_tile_p::MMAS_M * 2];
    if (!Is_first) {
        gmem_softmax_lse.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_prev_lse));
    }

    // Commit the data for Q and V to shared memory.
    gmem_q.commit(gemm_q_k.smem_q);
    gmem_v.commit(smem_v);

    // const uint32_t scale_bmm1 = reinterpret_cast<const uint32_t&>(params.scale_bmm1);
    // #pragma unroll
    // for(int it=0;it < Gmem_tile_k::LDGS;it++){
    //     gmem_k.fetch_[it] = fmha::hmul8(scale_bmm1, gmem_k.fetch_[it]);
    // }

    // Commit the data for K to shared memory, unless K shares its buffer with V (in that
    // case V has to be read into registers first, see below).
    if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
        gmem_k.commit(gemm_q_k.smem_k);
    }

    __syncthreads();

    // Load the fragments for Q.
    gemm_q_k.load_q();

    // Load the fragments for V. We keep the data in registers during the entire kernel.
    typename Smem_tile_v::Fragment frag_v[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_N];
    #pragma unroll
    for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
        smem_v.load(frag_v[ki], ki);
    }

    // Commit the data for K to shared memory if it has not been done already.
    if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
        // Make sure we are done loading the fragments for V.
        __syncthreads();

        // Commit the data to shared memory for K (overwriting the shared K / V buffer).
        gmem_k.commit(gemm_q_k.smem_k);

        // Make sure the data is in shared memory.
        __syncthreads();
    }

    // Load the fragments for K.
    gemm_q_k.load_k();

    // Create the object to do the softmax.
    Softmax softmax(params, &smem_[Gemm1::SMEM_OFFSET_SOFTMAX], tidx);

    // Scratch for the previous LSE values, located right after the Gemm1 region, i.e. at
    // offset Gemm1::SMEM_BYTES. It is only accessed when !Is_first.
    // NOTE(review): this requires the launch to provide dynamic shared memory beyond
    // Gemm1::SMEM_BYTES for multi-pass runs -- confirm against the launcher.
    Smem_softmax_sum smem_softmax_lse(reinterpret_cast<float *>(&smem_[Gemm1::SMEM_BYTES]), tidx);

    // Load over the entire sequence length.
    for (int l = blockIdx.z; l < steps; l += step_stride) {
        // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (blockIdx.z <= 1)) {
        //     printf("l = %d\n", l);
        // }
        if ((begin + l) * Cta_tile_p::M >= binfo.actual_seqlen_q) break;

        // Declare the accumulators for the 1st gemm.
        fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
        fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);

        // Do this part of P = Q * K^T.
        gemm_q_k(acc_p);

        // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
        //     printf("acc_p=%.6f, %.6f\n", acc_p[0][0].elt(0), acc_p[0][0].elt(1));
        // }

        // Partial output of the previous column blocks for this row block (if any).
        uint4 out[Gmem_tile_o::STGS_PER_LOOP];
        if (!Is_first) { gmem_o_tmp.load(out, 0); }

        // Trigger the load for the next Q values.
        if (l + step_stride < steps) {
            gemm_q_k.smem_q.move_to_next_write_buffer();
            gmem_q.move(step_stride);
            gmem_q.load();
        }

        // Load the mask for that iteration.
        mask.load(begin + l);

        // Convert from the accumulator type to FP32 for Softmax.
        softmax.unpack_noscale(acc_p);

        // Apply the mask.
        softmax.apply_mask(mask);

        if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l < step_stride ) {
            // if we share K and V, it could be that V was not fully read yet but we write into smem for reduction
            __syncthreads();
        }
        // if (!Is_first) {
        //     if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l >= 0)) {
        //         printf("p_prev_lse=%.6f, %.6f\n", p_prev_lse[0], p_prev_lse[1]);
        //     }
        // }
        // Compute the max.
        float p_max[Mma_tile_p::MMAS_M * 2];
        if (!Is_first) {
            smem_softmax_lse.store_pair(p_prev_lse);
            // Seed the running max with the previous LSE. The scores are still unscaled at
            // this point (scale_bmm1f is applied in scale_apply_exp below), hence the division.
            // for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { p_max[mi] = p_prev_lse[mi]; }
            for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { p_max[mi] = p_prev_lse[mi] / params.scale_bmm1f; }
        }

        // Trigger the load for the next LSE values.
        if (l + step_stride < steps) {
            if (!Is_first) {
                gmem_softmax_lse.load_next(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_prev_lse),
                                           step_stride);
            }
        }

        softmax.template reduce_max</*zero_init=*/Is_first>(p_max);

        // if ((threadIdx.x == 0) && (l == 38)) {
        //     printf("loop_step_idx %d, p_max = %.6f, %.6f., p_prev_lse = %.6f, %.6f\n", loop_step_idx, p_max[0], p_max[1], Is_first ? -10000.f : p_prev_lse[0], Is_first ? -10000.f : p_prev_lse[1]);
        // }

        // if (!Is_first) {
        //     if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
        //         printf("after reduce_max=%.6f, %.6f\n", softmax.elt_[0][0], softmax.elt_[0][1]);
        //     }
        // }

        // Compute the exponential value.
        // softmax.apply_exp(p_max);
        softmax.scale_apply_exp(p_max, params.scale_bmm1f);

        // if (!Is_first) {
        //     if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
        //         printf("after apply_exp=%.6f, %.6f\n", softmax.elt_[0][0], softmax.elt_[0][1]);
        //     }
        // }

        // Compute the sum.
        float p_sum[Mma_tile_p::MMAS_M * 2];
        // if (!Is_first) {
        //     int warp = tidx / Cta_tile_p::THREADS_PER_WARP;
        //     int lane = tidx % Cta_tile_p::THREADS_PER_WARP;
        //     for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) {
        //         p_sum[mi] = ((warp == 0) && (lane % 4 == 0)) ? expf(p_prev_lse[mi] - p_max[mi]) : 0;
        //     }
        // }
        // softmax.reduce_sum(p_sum);
        softmax.reduce_sum_before_sync_(p_sum);
        // softmax.template reduce_sum_before_sync_</*zero_init=*/Is_first>(p_sum);

        // float p_sum_log[Mma_tile_p::MMAS_M * 2];
        // for (int mi = 0; mi  < Mma_tile_p::MMAS_M * 2; ++mi) {
        //     float sum = p_sum[mi];
        //     // p_sum_log[mi] = (sum == 0.f || sum != sum) ? INFINITY : p_max[mi] + __logf(sum);
        //     constexpr float kLog2e = M_LOG2E;
        //     p_sum_log[mi] = (sum == 0.f || sum != sum) ? INFINITY : p_max[mi] * kLog2e + __log2f(sum);
        // }
        // // gmem_softmax_lse.store(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_sum));
        // gmem_softmax_lse.store(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_sum_log));
        // gmem_softmax_lse.move();

        // // Finalize softmax on the accumulators of P^T.
        // softmax.scale(p_sum);

        // When the probabilities are returned, the dropout decision is encoded in their sign bit.
        constexpr bool encode_dropout_in_sign_bit = Return_softmax;
        if (Is_dropout) {
            // softmax.template apply_dropout<encode_dropout_in_sign_bit>(ph, params.p_dropout_in_uint);
            // softmax.template apply_dropout<encode_dropout_in_sign_bit>(ph, ph1, params.p_dropout_in_uint);
            // softmax.template apply_dropout_16bits<encode_dropout_in_sign_bit>(ph, ph1, params.p_dropout_in_uint16_t);
            unsigned int warp_idx = threadIdx.x / 32;
            // TODO: this should change after we rearrange the warps (e.g. cutlass branch)
            unsigned int block_col_idx = loop_step_idx * Cta_tile_p::N / 16 + warp_idx;
            // We want to use actual_seqlen_k, not seqlen_k, since seqlen_k could be rounded
            // differently in the fwd and bwd pass. E.g., for d=128 on A100, fwd rounds seqlen_k
            // to multiples of 256 while bwd rounds seqlen_k to multiples of 128.
            unsigned long long philox_subsequence = (begin + l) * (binfo.actual_seqlen_k / 16) + block_col_idx;
            softmax.template apply_dropout_16bits<encode_dropout_in_sign_bit>(ph, params.p_dropout_in_uint16_t, philox_subsequence);
        }

        // Pack the probabilities into the A-operand fragments of the 2nd GEMM.
        using Frag_p = fmha::Fragment_a<fmha::Row>;
        Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];
        static_assert(Mma_tile_o::MMAS_M == Mma_tile_p::MMAS_M);
        static_assert(Mma_tile_o::MMAS_K == Mma_tile_p::MMAS_N);
        softmax.template pack<elem_type>(frag_p);
        if (Return_softmax) {
            gmem_s.store(frag_p, mask);
            gmem_s.move(step_stride);
        }

        // Commit the values for Q into shared memory.
        if (l + step_stride < steps) {
            gmem_q.commit(gemm_q_k.smem_q);
        }

        // With the dropout decision encoded in the sign bit, apply a ReLU to the fragments
        // before they enter the 2nd GEMM.
        if (Is_dropout && encode_dropout_in_sign_bit) {
            #pragma unroll
            for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) {
                #pragma unroll
                for( int mi = 0; mi < Mma_tile_o::MMAS_M; mi++ ) {
                    frag_p[ki][mi].template hrelu_<elem_type>();
                }
            }
        }

        // Declare the accumulators for the 2nd gemm.
        fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N];
        fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_o::WARPS_K>::apply(acc_o);

        // Do this part of O = P^T * V^T.
        #pragma unroll
        for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
            fmha::gemm_cl<elem_type>(acc_o, frag_p[ki], frag_v[ki]);
            // if ((threadIdx.x == 4) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
            //     float2 tmp_p = __half22float2(reinterpret_cast<__half2 &>(frag_p[ki]));
            //     float2 tmp_v = __half22float2(reinterpret_cast<__half2 &>(frag_v[ki]));
            //     printf("Per warp, threadIdx.x = %d, frag_p = %.6f, %.6f, frag_v = %.6f, %.6f, acc_o=%.6f\n", threadIdx.x, tmp_p.x, tmp_p.y, tmp_v.x, tmp_v.y, acc_o[0][0].elt(0));
            // }
        }

        // if ((threadIdx.x % 32 == 16) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
        //     printf("Per warp, threadIdx.x = %d, acc_o=%.6f\n", threadIdx.x, acc_o[0][2].elt(0));
        // }

        // The mapping from tidx to rows changes between the softmax and the
        // O-reduction. So we recalculate the max.
        float p_max_o[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M];
        int rows[Gmem_tile_o::STGS_PER_LOOP];
        for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
            rows[jj] = tidx / Gmem_tile_o::THREADS_PER_ROW + jj * Gmem_tile_o::ROWS_PER_STG;
        }
        softmax.reduce_max_after_sync_(p_max_o, rows);
        static_assert(Mma_tile_o::MMAS_M == 1);
        // Bring the max into the scaled domain used by the LSE.
        for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
            p_max_o[jj][0] *= params.scale_bmm1f;
        }
        // Previous LSE per output row (read back from shared memory); converted in place
        // below into the factor that rescales the previous partial output.
        float p_prev_scale_o[Gmem_tile_o::STGS_PER_LOOP];
        if (!Is_first) {
            smem_softmax_lse.load(p_prev_scale_o, rows);
        }
        // if (!Is_first) {
        //     if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
        //         printf("p_prev_scale_o=%.6f\n", p_prev_scale_o[0]);
        //     }
        // }

        static_assert(Gmem_tile_o::LOOPS == 1);

        // Swizzle the elements and do the final reduction.
        smem_o.store(acc_o, 0);

        // Make sure the data is in shared memory.
        __syncthreads();

        static_assert(Mma_tile_o::MMAS_M == 1);
        float p_sum_o[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M];
        softmax.reduce_sum_after_sync_(p_sum_o, rows);
        if (!Is_first) {
            for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
                // exp(previous LSE - new max): weight of the previous column blocks.
                p_prev_scale_o[jj] = expf(p_prev_scale_o[jj] - p_max_o[jj][0]);
                p_sum_o[jj][0] += p_prev_scale_o[jj];
            }
        }

        // New LSE = max + log(sum); -INFINITY when the sum is zero or NaN.
        float p_sum_log[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M];
        #pragma unroll
        for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
            float sum = p_sum_o[jj][0];
            p_sum_log[jj][0] = (sum == 0.f || sum != sum) ? -INFINITY : p_max_o[jj][0] + __logf(sum);
            // if (sum == 0.f || sum != sum) {
            //     printf("loop_step_idx = %d, l = %d, tidx = %d, sum = %.6f, p_max_o = %.6f\n", loop_step_idx, l, tidx, sum, p_max_o[jj][0]);
            // }
            // if (Is_first) {
            //     if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
            //         printf("p_sum_log=%.6f\n", p_sum_log[jj][0]);
            //     }
            // }
            // Only the first thread of each output row writes the LSE.
            if (tidx % Gmem_tile_o::THREADS_PER_ROW == 0) {
                gmem_softmax_lse.store_row(
                    reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M]>(p_sum_log[jj]), rows[jj]);
            }
        }
        gmem_softmax_lse.move(step_stride);

        // Rescale the previous partial output, then add this column block's contribution
        // from shared memory.
        if (!Is_first) {
            for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
                out[jj] = fmha::fmul4(out[jj], p_prev_scale_o[jj]);
            }
        }
        smem_o.template load</*zero_init=*/Is_first>(out);

        // The output is final when this is the last column block, when the next column block
        // starts at or past actual_seqlen_k, or (causal) when this row block starts before
        // the next column block.
        const bool is_final_write =
            Is_last
            || ((loop_step_idx + 1) * Cta_tile_p::N >= binfo.actual_seqlen_k)
            || ((Is_causal) && ((begin + l) * Cta_tile_p::M < (loop_step_idx + 1) * Cta_tile_p::N));
        #pragma unroll
        for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
            float sum = p_sum_o[jj][0];
            // Normalize by the sum (left unscaled when the sum is zero or NaN); the dropout
            // rescaling is folded in on the final write only.
            float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
            if (Is_dropout && is_final_write) {
                inv_sum *= params.rp_dropout;
            }
            out[jj] = fmha::fmul4(out[jj], inv_sum);
        }

        // if (Is_dropout && Is_last) {
        //     for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
        //         out[jj] = fmha::fmul4(out[jj], params.rp_dropout);
        //     }
        // }

        // Output the values: final results go to O, intermediate ones to the temporary buffer.
        if (is_final_write) {
            gmem_o.template store<elem_type>(out, 0);
            gmem_o.move(step_stride);
        } else {
            gmem_o_tmp.store(out, 0);
        }

        // Move to the next part of the output.
        if (!(Is_first && Is_last)) { gmem_o_tmp.move(step_stride); }
        gemm_q_k.reload_k();

        // Make sure we are reading from the correct buffer.
        gemm_q_k.smem_q.move_to_next_read_buffer();
        // Trigger the load from shared memory for the next series of Q values.
        if (l + step_stride < steps) {
            gemm_q_k.reload_q();
        }
    }  // Outer loop over the sequence length.
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Entry point of the forward pass for one (batch, head) pair: walks the key/value sequence
// in column blocks of Cta_tile_p::N and calls device_1xN_ once per column block, passing the
// Is_first / Is_last flags that match the block's position.
//
// Grid mapping (as read here): blockIdx.x is the batch index, blockIdx.y the head index;
// blockIdx.z / gridDim.z are consumed inside device_1xN_.
//
// Parameters:
//   params - kernel parameters; this function reads seqlen_q, seqlen_k and h.
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_softmax, typename Params>
inline __device__ void device_1xN_loop(const Params &params) {

    // The block index for the batch.
    const int bidb = blockIdx.x;
    // The block index for the head.
    const int bidh = blockIdx.y;
    // The block index (only referenced by the disabled rng_state code below).
    const int bidx = gridDim.x * bidh + bidb;
    // The thread index.
    const int tidx = threadIdx.x;

    // We want the fwd and bwd to generate the same dropout pattern (RNG), without restricting
    // them to have the same number of threads or have to traverse the attention matrix
    // in the same order.
    // In the Philox RNG, we use the offset to store the batch, head, and the lane id
    // (within a warp). We use the subsequence to store the location of the 16 x 16 blocks within
    // the attention matrix. This way, as long as we have the batch, head, and the location of
    // the 16 x 16 block within the attention matrix, we can generate the exact same dropout pattern.
    // auto seeds = at::cuda::philox::unpack(params.philox_args);
    // if (bidx == 0 && tidx == 0) {
    //     params.rng_state[0] = std::get<0>(seeds);
    //     params.rng_state[1] = std::get<1>(seeds);
    // }
    // The first two arguments are fixed to 0 here; the third encodes batch, head and lane.
    Philox ph(0, 0, 0 + (bidb * params.h + bidh) * 32 + tidx % 32);
    constexpr int M = Kernel_traits::Cta_tile_p::M;
    // Number of row blocks (ceil-div of seqlen_q by the tile height).
    const int STEPS = (params.seqlen_q + M - 1) / M;

    constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
    // A sequence that fits in one column block is handled in a single pass. Using `<=`
    // instead of `==` keeps a shorter seqlen_k from taking the multi-pass branch, where
    // max_loop_steps would be 1 and column block 0 would be processed twice.
    if (params.seqlen_k <= blocksize_c) {
        fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, true>(params, bidb, bidh, STEPS, ph, 0);
    } else {
        // Number of column blocks (ceil-div); at least 2 in this branch.
        const int max_loop_steps = (params.seqlen_k + blocksize_c - 1) / blocksize_c;
        // First column block.
        fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, false>(params, bidb, bidh, STEPS, ph, 0);
        // Interior column blocks.
        for (int loop_step_idx = 1; loop_step_idx < max_loop_steps - 1; loop_step_idx++) {
            fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, false>(params, bidb, bidh, STEPS, ph, loop_step_idx);
        }
        // Last column block.
        fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, true>(params, bidb, bidh, STEPS, ph, max_loop_steps - 1);
    }
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha