// candle-flash-attn-v1/kernels/fmha/gmem_tile.h
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include "utils.h"
namespace fmha {
template<
// The dimensions of the tile computed by the CTA.
typename Cta_tile_,
// The number of bits per element.
int BITS_PER_ELEMENT,
// The number of rows of Q, K or V loaded by this tile.
int ROWS_,
// The number of columns.
int COLS,
int BYTES_PER_LDGS_ = 16
>
struct Gmem_tile_qkv {
using Cta_tile = Cta_tile_;
static constexpr int BYTES_PER_ELEMENT = BITS_PER_ELEMENT / 8;
// The size of each LDG.
static constexpr int BYTES_PER_LDG = BYTES_PER_LDGS_;
// The size of a row in bytes.
static constexpr int BYTES_PER_ROW = COLS * BITS_PER_ELEMENT / 8;
// The number of threads to load a "row" of the matrix.
static constexpr int THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDG;
static constexpr int ROWS = ROWS_;
// The number of "rows" loaded per LDG.
static constexpr int ROWS_PER_LDG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW;
// The number of LDGs needed to load a chunk of the Q, K or V matrix.
static constexpr int LDGS = DivUpConstexpr(ROWS, ROWS_PER_LDG);
// Ctor.
template< typename BInfo >
inline __device__ Gmem_tile_qkv(void *ptr_, const uint32_t row_stride_in_elts,
const uint32_t head_stride_in_elts, const int headdim,
const BInfo &binfo, const int tidx, bool use_seqlen_q)
: row_stride_in_bytes(row_stride_in_elts * BYTES_PER_ELEMENT)
, actual_seqlen(use_seqlen_q ? binfo.actual_seqlen_q : binfo.actual_seqlen_k)
, ptr(reinterpret_cast<char *>(ptr_))
, tidx_(tidx)
, col_predicate((tidx % THREADS_PER_ROW) * (BYTES_PER_LDG / BYTES_PER_ELEMENT) < headdim) {
// Compute the position in the sequence (within the CTA for the moment).
int row = tidx / THREADS_PER_ROW;
// Compute the position of the thread in the row.
int col = tidx % THREADS_PER_ROW;
// Store the row as we need it to disable the loads.
// TD [2022-04-16]: To minimize registers, we'll recompute row_ instead of storing it
// row_ = row;
// The row offset in the batched GEMM. For each seq element, we store QKV in that order.
// int64_t row_offset = (int64_t)row * params.qkv_stride_in_bytes;
uint32_t row_offset = (uint32_t)(((use_seqlen_q ? binfo.sum_s_q : binfo.sum_s_k) + row) * row_stride_in_bytes);
// Add the block index.
// row_offset += (int64_t)((binfo.sum_s * NUM_MATS + qkv_offset) * binfo.h + binfo.bidh) * BYTES_PER_ROW;
row_offset += (uint32_t)(binfo.bidh * head_stride_in_elts * BYTES_PER_ELEMENT);
// Assemble the final pointer.
ptr += row_offset + col * BYTES_PER_LDG;
}
// Store data to shared memory.
template< typename Smem_tile >
inline __device__ void commit(Smem_tile &smem_tile) {
smem_tile.store(fetch_);
}
inline __device__ void load() {
int row_ = tidx_ / THREADS_PER_ROW;
const void *ptrs[LDGS];
uint32_t preds[LDGS];
#pragma unroll
for( int ii = 0; ii < LDGS; ++ii ) {
// ptrs[ii] = ptr + (int64_t)ii * ROWS_PER_LDG * row_stride_in_bytes;
ptrs[ii] = ptr + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes;
preds[ii] = col_predicate && ((row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen));
fetch_[ii] = make_uint4(0, 0, 0, 0);
}
// not packing predicates removes restrictions (e.g. FP16 384, 4 warps)
Ldg_functor<uint4, LDGS> fct(fetch_, ptrs);
#pragma unroll
for( int ii = 0; ii < LDGS; ++ii ) {
fct.load(ii, preds[ii]);
}
}
// Store data to memory.
inline __device__ void store(const uint4 (&data)[LDGS]) {
int row_ = tidx_ / THREADS_PER_ROW;
#pragma unroll
for( int ii = 0; ii < LDGS; ++ii ) {
// char *ptr_ = ptr + (int64_t)ii * ROWS_PER_LDG * row_stride_in_bytes;
char *ptr_ = ptr + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes;
if (col_predicate && (row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen)) {
fmha::stg(ptr_, data[ii]);
}
}
}
inline __device__ void move(const int steps = 1) {
// ptr += (int64_t)ROWS * row_stride_in_bytes * steps;
ptr += (uint32_t)ROWS * row_stride_in_bytes * steps;
actual_seqlen -= ROWS * steps;
}
// The stride between rows of the QKV matrix.
// int64_t row_stride_in_bytes;
const uint32_t row_stride_in_bytes;
// The pointer.
char *ptr;
// The fetch registers.
uint4 fetch_[LDGS];
// Keep track of the row the thread is processing as we move the tile.
// int row_;
const int tidx_;
// The length of the sequence loaded by that memory tile.
int actual_seqlen;
const bool col_predicate;
};
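// Worked example (illustrative only, not a configuration defined in this file):
// for a Q/K/V tile with 16-bit elements, COLS = 64, ROWS = 128 and a 128-thread
// CTA, the layout constants above resolve to
//   BYTES_PER_ROW   = 64 * 16 / 8 = 128 bytes,
//   THREADS_PER_ROW = 128 / 16    = 8   (8 threads cover one row with 16B LDGs),
//   ROWS_PER_LDG    = 128 / 8     = 16  (rows covered per round of LDGs),
//   LDGS            = ceil(128 / 16) = 8 vectorized loads issued per thread.
// The sketch below checks this arithmetic at compile time. It is guarded by a
// macro that is never defined in this repo, and Example_cta_tile is a minimal
// stand-in providing only the member this tile reads; it is NOT the real
// fmha Cta_tile.
#if defined(FMHA_GMEM_TILE_QKV_EXAMPLE)
namespace gmem_tile_qkv_example {
struct Example_cta_tile { static constexpr int THREADS_PER_CTA = 128; };
using Example_tile = Gmem_tile_qkv<Example_cta_tile, /*BITS_PER_ELEMENT=*/16, /*ROWS=*/128, /*COLS=*/64>;
static_assert(Example_tile::THREADS_PER_ROW == 8, "8 threads per row");
static_assert(Example_tile::ROWS_PER_LDG == 16, "16 rows per LDG round");
static_assert(Example_tile::LDGS == 8, "8 LDGs per thread");
} // namespace gmem_tile_qkv_example
#endif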
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
typename Cta_tile,
int BYTES_PER_ELEMENT = 2
>
struct Gmem_tile_o {
static_assert(BYTES_PER_ELEMENT == 2 || BYTES_PER_ELEMENT == 4);
// The mma tile.
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
// The size of each element.
// static constexpr int BYTES_PER_ELEMENT = 2;
// The size of each STG.
static constexpr int BYTES_PER_STG = BYTES_PER_ELEMENT * 4;
static constexpr int COLS = Cta_tile::N;
// The size of a row in bytes.
static constexpr int BYTES_PER_ROW = COLS * BYTES_PER_ELEMENT;
// The number of threads to store a "row" of the matrix.
static constexpr int THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_STG;
// The number of "rows" stored per iteration of the loop. The output of 1 MMA.
static constexpr int ROWS = Cta_tile::M;
// The number of "rows" stored per iteration of the loop. The output of 1 MMA.
static constexpr int ROWS_PER_LOOP = ROWS <= 64 ? ROWS : (int)Mma_tile::M_PER_MMA_PER_CTA;
// The number of outer loops for the stores.
static constexpr int LOOPS = ROWS / ROWS_PER_LOOP;
// The number of "rows" stored per STG.
static constexpr int ROWS_PER_STG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW;
// Do we have to guard against partial writes/reads.
static constexpr bool HAS_INCOMPLETE_STG = Cta_tile::M % ROWS_PER_STG != 0;
// The number of STGs needed to store a chunk of the O matrix.
static constexpr int STGS_PER_LOOP = DivUpConstexpr(ROWS_PER_LOOP, ROWS_PER_STG);
// The number of STGs needed to store a chunk of the O matrix in total.
static constexpr int STGS = STGS_PER_LOOP * LOOPS;
// Ctor.
template<typename BInfo>
// inline __device__ Gmem_tile_o(void *ptr, const size_t row_stride_in_elts, const BInfo &binfo, const int tidx)
inline __device__ Gmem_tile_o(void *ptr, const uint32_t row_stride_in_elts,
const uint32_t head_stride_in_elts, const int headdim,
const BInfo &binfo, const int tidx)
: row_stride_in_bytes(row_stride_in_elts * BYTES_PER_ELEMENT)
, actual_seqlen_q(binfo.actual_seqlen_q)
, ptr_(reinterpret_cast<char *>(ptr))
, tidx_(tidx)
, col_predicate((tidx % THREADS_PER_ROW) * (BYTES_PER_STG / BYTES_PER_ELEMENT) < headdim) {
// Compute the position in the sequence (within the CTA for the moment).
int row = tidx / THREADS_PER_ROW;
// Compute the position of the thread in the row.
int col = tidx % THREADS_PER_ROW;
// Store the row as we need it to disable loads.
// row_ = row;
// The row offset in the batched GEMM.
// int64_t row_offset = (int64_t)row * row_stride_in_bytes + binfo.bidx * BYTES_PER_ROW;
uint32_t row_offset = (uint32_t)((binfo.sum_s_q + row) * row_stride_in_bytes);
row_offset += (uint32_t)(binfo.bidh * head_stride_in_elts * BYTES_PER_ELEMENT);
// Assemble the final pointer.
ptr_ += row_offset + col * BYTES_PER_STG;
// Is that thread active on the last STG?
if( HAS_INCOMPLETE_STG ) {
is_active_for_last_stg_ = row + (STGS - 1) * ROWS_PER_STG < Cta_tile::M;
}
}
// Store data to global memory.
template<typename elem_type=__half>
inline __device__ void store(const uint4 (&src)[STGS_PER_LOOP], int mi) {
int row_ = tidx_ / THREADS_PER_ROW;
#pragma unroll
for( int ii = 0; ii < STGS_PER_LOOP; ++ii ) {
int jj = mi * STGS_PER_LOOP + ii;
if ((!col_predicate) || (row_ + jj * ROWS_PER_STG >= this->actual_seqlen_q)) {
break;
}
if (BYTES_PER_ELEMENT == 4) {
if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) {
fmha::stg(this->ptr_ + jj * ROWS_PER_STG * this->row_stride_in_bytes, src[ii]);
}
} else if (BYTES_PER_ELEMENT == 2) {
float x = reinterpret_cast<const float &>(src[ii].x);
float y = reinterpret_cast<const float &>(src[ii].y);
float z = reinterpret_cast<const float &>(src[ii].z);
float w = reinterpret_cast<const float &>(src[ii].w);
uint2 out = fmha::float4_pack<elem_type>(x, y, z, w);
if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) {
fmha::stg(this->ptr_ + jj * ROWS_PER_STG * this->row_stride_in_bytes, out);
}
}
}
}
// Store data to global memory with atomicAdd.
inline __device__ void atomic_add(const uint4 (&src)[STGS_PER_LOOP], int mi) {
static_assert(BYTES_PER_ELEMENT == 4); // Only do atomic add on floats
int row_ = tidx_ / THREADS_PER_ROW;
#pragma unroll
for( int ii = 0; ii < STGS_PER_LOOP; ++ii ) {
int jj = mi * STGS_PER_LOOP + ii;
if ((!col_predicate) || (row_ + jj * ROWS_PER_STG >= this->actual_seqlen_q)) {
break;
}
if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) {
float *ptr_ = reinterpret_cast<float *>(this->ptr_ + jj * ROWS_PER_STG * this->row_stride_in_bytes);
#pragma unroll
for (int kk = 0; kk < 4; ++kk) {
atomicAdd(ptr_ + kk, reinterpret_cast<const float(&)[4]>(src[ii])[kk]);
}
}
}
}
// Load data from global memory.
inline __device__ void load(uint4 (&dst)[STGS_PER_LOOP], int mi) {
static_assert(BYTES_PER_ELEMENT == 4);
int row_ = tidx_ / THREADS_PER_ROW;
#pragma unroll
for( int ii = 0; ii < STGS_PER_LOOP; ++ii ) {
int jj = mi * STGS_PER_LOOP + ii;
if ((!col_predicate) || (row_ + jj * ROWS_PER_STG >= this->actual_seqlen_q)) {
break;
}
if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) {
fmha::ldg(dst[ii], this->ptr_ + jj * ROWS_PER_STG * this->row_stride_in_bytes);
}
}
}
inline __device__ void move(const int steps = 1) {
// row_ += ROWS * steps;
// ptr_ += (int64_t)ROWS * row_stride_in_bytes * steps;
ptr_ += (uint32_t)ROWS * row_stride_in_bytes * steps;
actual_seqlen_q -= ROWS * steps;
}
// The stride between rows of the O matrix.
// int64_t row_stride_in_bytes;
const uint32_t row_stride_in_bytes;
// The pointer.
char *ptr_;
// Is the thread active for the last STG?
int is_active_for_last_stg_;
// The length of the sequence loaded by that memory tile.
int actual_seqlen_q;
const int tidx_;
const bool col_predicate;
};
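// Worked example (illustrative only): for 16-bit outputs (BYTES_PER_ELEMENT = 2),
// Cta_tile::N = 64, Cta_tile::M = 64 and a 128-thread CTA, the constants above
// resolve to
//   BYTES_PER_STG   = 2 * 4  = 8 bytes (one uint2, i.e. 4 packed fp16/bf16 values),
//   BYTES_PER_ROW   = 64 * 2 = 128 bytes,
//   THREADS_PER_ROW = 128 / 8 = 16,
//   ROWS_PER_STG    = 128 / 16 = 8,
//   ROWS_PER_LOOP   = 64 (since M <= 64), LOOPS = 1,
//   STGS_PER_LOOP   = ceil(64 / 8) = 8, STGS = 8,
//   HAS_INCOMPLETE_STG = (64 % 8 != 0) = false,
// so each call to store() converts up to 8 uint4 accumulator fragments into
// 8-byte packed values and issues one predicated STG per fragment.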
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Cta_tile, int BYTES_PER_ELEMENT >
struct Gmem_tile_mma_sd {
// The mma tile.
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
// Each STG stores 8 elements.
static constexpr int BYTES_PER_STG = BYTES_PER_ELEMENT * 8;
// The number of MMAs in the M dimension.
static constexpr int MMAS_M = Mma_tile::MMAS_M;
// The number of MMAs in the N dimension.
static constexpr int MMAS_N = Mma_tile::MMAS_N;
// The number of rows computed per MMA per thread block.
static constexpr int M_PER_MMA_PER_CTA = Mma_tile::M_PER_MMA_PER_CTA;
// The number of cols computed per MMA per thread block.
static constexpr int N_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA;
// The number of threads per block.
static constexpr int THREADS_PER_CTA = Cta_tile::THREADS_PER_CTA;
// The size of each "row" in bytes, i.e. how many bytes the whole CTA stores per STG pass.
static constexpr int BYTES_PER_ROW = THREADS_PER_CTA * BYTES_PER_STG;
// The distance between elements stored per loop (in bytes).
static constexpr int LOOP_STRIDE_BYTES = MMAS_M * MMAS_N * BYTES_PER_ROW;
// The type of elements stored per STG.
using Type = typename fmha::Uint_from_size_in_bytes<BYTES_PER_STG>::Type;
// Ctor.
template<typename Params>
inline __device__ Gmem_tile_mma_sd(void *ptr, const Params &params, const int bidb, const int bidh, const int tidx)
: ptr_(static_cast<char *>(ptr)) {
// The block index.
// size_t bidx = bidb * params.h + bidh;
uint32_t bidx = bidb * params.h + bidh;
// The distance between two blocks (in bytes).
// const size_t block_stride_bytes = params.seqlen_q * params.seqlen_k * BYTES_PER_ELEMENT;
const uint32_t block_stride_bytes = params.seqlen_q * params.seqlen_k * BYTES_PER_ELEMENT;
// Set store location for each thread at the beginning of the loop
ptr_ += bidx * block_stride_bytes + tidx * BYTES_PER_STG;
}
// Store to global memory.
inline __device__ void store(const Type &data, const int mi, const int ni) {
// size_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW;
uint32_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW;
fmha::stg(ptr_ + offset, data);
}
// Load from global memory.
inline __device__ void load(Type &data, const int mi, const int ni) {
// size_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW;
uint32_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW;
fmha::ldg(data, ptr_ + offset);
}
// Move to the next tile.
inline __device__ void move(const int steps = 1) {
ptr_ += LOOP_STRIDE_BYTES * steps;
}
// The pointer in global memory.
char *ptr_;
};
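// Worked example (illustrative only): with 16-bit elements (as used by
// Gmem_tile_mma_s below) and a 128-thread CTA, BYTES_PER_STG = 2 * 8 = 16 bytes
// (one uint4), so one CTA-wide pass writes BYTES_PER_ROW = 128 * 16 = 2048 bytes
// per (mi, ni) pair, and move() advances the pointer by
// LOOP_STRIDE_BYTES = MMAS_M * MMAS_N * 2048 bytes per step. The per-block base
// offset is (bidb * params.h + bidh) * seqlen_q * seqlen_k * BYTES_PER_ELEMENT,
// i.e. storage for one seqlen_q x seqlen_k matrix per (batch, head).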
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Cta_tile, typename Base = Gmem_tile_mma_sd<Cta_tile, sizeof(uint16_t)> >
struct Gmem_tile_mma_s : public Base {
// The number of mmas in the vertical dimension.
static constexpr int M = Base::MMAS_M;
// The number of mmas in the horizontal dimension.
static constexpr int N = Base::MMAS_N;
// The type of the vectors stored by each STG.
using Type = typename Base::Type;
// Ctor.
template< typename Params, typename Block_info >
inline __device__ Gmem_tile_mma_s(const Params &params, const Block_info& binfo, const int tidx)
: Base(params.s_ptr, params, binfo.bidb, binfo.bidh, tidx) {
}
// Store to global memory.
template<typename Mask, typename Fragment>
inline __device__ void store(const Fragment (&frag)[N][M], const Mask& mask){
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
uint4 dst;
dst.x = frag[ni][mi].reg(0);
dst.y = frag[ni][mi].reg(2);
dst.z = frag[ni][mi].reg(1);
dst.w = frag[ni][mi].reg(3);
if( mask.any_valid(mi, ni) ) {
Base::store(dst, mi, ni);
}
}
}
}
// Load from global memory.
template<typename Mask>
inline __device__ void load(uint4 (&regs)[M][N], const Mask &mask) {
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
regs[mi][ni] = make_uint4(0, 0, 0, 0);
if( mask.any_valid(mi, ni) ) {
Base::load(regs[mi][ni], mi, ni);
}
}
}
}
};
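// Note on store()/load() above: for each (mi, ni) MMA tile, a thread's four
// 32-bit fragment registers are repacked into one uint4 (16 bytes, i.e. 8 fp16
// values for the 2-byte element type selected by the default Base), with
// reg(1) and reg(2) swapped in the packed order. MMA tiles that the mask marks
// as having no valid element are skipped on store and returned as zeros on load.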
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The dimensions of the tile computed by the CTA.
typename Cta_tile
>
struct Gmem_summary_stats {
// The Mma tile.
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
// The number of MMAs in M/N dimensions.
static constexpr int MMAS_M = Mma_tile::MMAS_M;
// The size of each element.
static constexpr int BYTES_PER_ELEMENT = 4;
static constexpr int BYTES_PER_MMA = (Cta_tile::THREADS_PER_WARP / 4) * 2 * BYTES_PER_ELEMENT;
static constexpr int ROWS = Cta_tile::M;
// Ctor.
template<typename Params>
inline __device__ Gmem_summary_stats(void *ptr, const Params &params, const int tidx)
: ptr_(reinterpret_cast<char *>(ptr)), tidx_(tidx) {
// The block index for the batch.
const int bidb = blockIdx.x;
// The block index for the head.
const int bidh = blockIdx.y;
// The block index.
// size_t bidx = bidb * params.h + bidh;
uint32_t bidx = bidb * params.h + bidh;
// Extract the position in the warp.
int warp = tidx / Cta_tile::THREADS_PER_WARP;
int lane = tidx % Cta_tile::THREADS_PER_WARP;
// The distance between two blocks (in bytes).
// size_t block_stride_bytes = params.seqlen_q * BYTES_PER_ELEMENT;
uint32_t block_stride_bytes = params.seqlen_q * BYTES_PER_ELEMENT;
// Set store location for each thread at the beginning of the loop
ptr_row_ = ptr_ + bidx * block_stride_bytes;
ptr_ += bidx * block_stride_bytes + (lane / 4) * BYTES_PER_ELEMENT;
}
// Store data to global memory.
inline __device__ void store(const uint32_t (&data)[MMAS_M * 2]) {
int warp = tidx_ / Cta_tile::THREADS_PER_WARP;
int lane = tidx_ % Cta_tile::THREADS_PER_WARP;
if ((warp == 0) && (lane % 4 == 0)) {
#pragma unroll
for (int mi = 0; mi < MMAS_M; ++mi) {
// TODO: Not sure if it's right for MMAS_M > 1
fmha::stg(ptr_ + mi * BYTES_PER_MMA + 0 * BYTES_PER_ELEMENT, data[mi * 2 + 0]);
fmha::stg(ptr_ + mi * BYTES_PER_MMA + 8 * BYTES_PER_ELEMENT, data[mi * 2 + 1]);
}
}
}
// Store data to global memory.
inline __device__ void store_row(const uint32_t (&data)[MMAS_M], const int row) {
#pragma unroll
for (int mi = 0; mi < MMAS_M; ++mi) {
// TODO: Not sure if it's right for MMAS_M > 1
fmha::stg(ptr_row_ + mi * BYTES_PER_MMA + row * BYTES_PER_ELEMENT, data[mi]);
}
}
// Load from global memory.
inline __device__ void load(uint32_t (&data)[MMAS_M * 2]) {
#pragma unroll
for (int mi = 0; mi < MMAS_M; ++mi) {
// TODO: Not sure if it's right for MMAS_M > 1
fmha::ldg(data[mi * 2 + 0], ptr_ + mi * BYTES_PER_MMA + 0 * BYTES_PER_ELEMENT);
fmha::ldg(data[mi * 2 + 1], ptr_ + mi * BYTES_PER_MMA + 8 * BYTES_PER_ELEMENT);
}
}
// Load from global memory.
inline __device__ void load_next(uint32_t (&data)[MMAS_M * 2], int move_steps=1) {
char *ptr_next = ptr_ + move_steps * ROWS * BYTES_PER_ELEMENT;
#pragma unroll
for (int mi = 0; mi < MMAS_M; ++mi) {
// TODO: Not sure if it's right for MMAS_M > 1
fmha::ldg(data[mi * 2 + 0], ptr_next + mi * BYTES_PER_MMA + 0 * BYTES_PER_ELEMENT);
fmha::ldg(data[mi * 2 + 1], ptr_next + mi * BYTES_PER_MMA + 8 * BYTES_PER_ELEMENT);
}
}
// Load per-row data from global memory.
template <int N>
inline __device__ void load_row(uint32_t (&data)[N], const int row[N]) {
#pragma unroll
for (int ni = 0; ni < N; ++ni) {
fmha::ldg(data[ni], ptr_row_ + row[ni] * BYTES_PER_ELEMENT);
}
}
// Move the pointer to the next location.
inline __device__ void move() {
ptr_ += ROWS * BYTES_PER_ELEMENT;
ptr_row_ += ROWS * BYTES_PER_ELEMENT;
}
// Move the pointer to the next location.
inline __device__ void move(const int steps) {
ptr_ += ROWS * BYTES_PER_ELEMENT * steps;
ptr_row_ += ROWS * BYTES_PER_ELEMENT * steps;
}
// The pointer.
char *ptr_;
char *ptr_row_;
const int tidx_;
};
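// Worked example (illustrative only): with 32 threads per warp,
// BYTES_PER_MMA = (32 / 4) * 2 * 4 = 64 bytes, i.e. 16 float statistics per MMA
// along M, matching the 16 rows one MMA covers. In store(), only the first lane
// of each quad in warp 0 writes, and the two STGs at +0 and +8 elements are
// consistent with the m16n8 accumulator layout in which each quad owns rows
// r and r + 8 of the 16-row block; the values held here are the per-row softmax
// summary statistics kept for the backward pass.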
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha