// candle-flash-attn-v1/kernels/fmha/smem_tile.h
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include "utils.h"
#include "utils.h"
#include "gemm.h"
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The description of the tile computed by this CTA.
typename Cta_tile,
// The number of rows in the 2D shared memory buffer.
int M_,
// The number of cols.
int N_,
// The size in bits of each element.
int BITS_PER_ELEMENT_,
// The number of bytes per STS.
int BYTES_PER_STS_ = 16,
// The number of buffers. (Used in multistage and double buffer cases.)
int BUFFERS_PER_TILE_ = 1,
// Do we enable the fast path for LDS.128 and friends.
int ENABLE_LDS_FAST_PATH_ = 0,
// The number of rows that are used for the XOR swizzling to allow fast STS/LDS.
int ROWS_PER_XOR_PATTERN_ = 8,
// The number of cols that are used for the XOR swizzling to allow fast STS/LDS.
int COLS_PER_XOR_PATTERN_ = 1,
// Whether or not to use predicates
bool USE_PREDICATES_ = true
>
struct Smem_tile_without_skews {
// The size in bits of each element.
enum { BITS_PER_ELEMENT = BITS_PER_ELEMENT_ };
// The size in bytes of a single STS.
enum { BYTES_PER_STS = BYTES_PER_STS_ };
// The number of elements per STS.
enum { ELEMENTS_PER_STS = BYTES_PER_STS * 8 / BITS_PER_ELEMENT };
// To support arbitrary N, we pad some values to a power-of-2.
enum { N_WITH_PADDING = Next_power_of_two<N_>::VALUE };
// The number of bytes per row without packing of rows.
enum { BYTES_PER_ROW_BEFORE_PACKING = N_WITH_PADDING * BITS_PER_ELEMENT / 8 };
// The number of bytes per row -- we want at least 128B per row.
enum { BYTES_PER_ROW = Max<BYTES_PER_ROW_BEFORE_PACKING, 128>::VALUE };
// The number of rows in shared memory (two rows may be packed into a single one).
enum { ROWS = M_ * BYTES_PER_ROW_BEFORE_PACKING / BYTES_PER_ROW };
// The number of threads per row, ignoring the CTA size.
enum { THREADS_PER_ROW_UNBOUNDED = BYTES_PER_ROW / BYTES_PER_STS };
// The number of threads per row, bounded by the CTA size.
enum { THREADS_PER_ROW = Min<Cta_tile::THREADS_PER_CTA, THREADS_PER_ROW_UNBOUNDED>::VALUE };
// The number of STS per row.
enum { STS_PER_ROW = BYTES_PER_ROW / THREADS_PER_ROW / BYTES_PER_STS };
// It must be at least one.
static_assert(STS_PER_ROW >= 1, "");
// The number of rows written with a single STS.
enum { ROWS_PER_STS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW };
// Make sure we write to at least one row per STS. Thanks Dr. Obvious ;)
static_assert(ROWS_PER_STS >= 1, "");
// The number of STS needed to store all rows.
enum { STS_PER_COL = Div_up<ROWS, ROWS_PER_STS>::VALUE };
// The number of STS in total.
enum { STS = STS_PER_COL * STS_PER_ROW };
// TD [2022-06-02] In the case of Q (16 x 64) in the backward pass with 256 threads,
// we only need to store 16 * 64 * 2 = 2KB instead of 4KB.
static constexpr bool PARTIAL_STORE = ROWS_PER_STS > ROWS;
static constexpr int STORING_THREADS = PARTIAL_STORE ? ROWS * THREADS_PER_ROW : Cta_tile::THREADS_PER_CTA;
// The size of one buffer in bytes in shared memory.
// enum { BYTES_PER_BUFFER = STS * BYTES_PER_STS * Cta_tile::THREADS_PER_CTA };
enum { BYTES_PER_BUFFER = STS * BYTES_PER_STS * STORING_THREADS };
// The number of buffers.
enum { BUFFERS_PER_TILE = BUFFERS_PER_TILE_ };
// The size in bytes of total buffers.
enum { BYTES_PER_TILE = BYTES_PER_BUFFER * BUFFERS_PER_TILE };
// The boundary for smem_read_offset and smem_write_offset increment.
enum { BYTES_PER_TILE_INC_BOUNDARY = BYTES_PER_TILE - BYTES_PER_BUFFER };
// Do we enable the LDS.128 fast path?
enum { ENABLE_LDS_FAST_PATH = ENABLE_LDS_FAST_PATH_ };
static_assert(ENABLE_LDS_FAST_PATH == 0);
// The number of rows that are used for the XOR swizzling to allow fast STS/LDS.
enum { ROWS_PER_XOR_PATTERN = ROWS_PER_XOR_PATTERN_ };
// The number of cols that are used for the XOR swizzling to allow fast STS/LDS.
enum { COLS_PER_XOR_PATTERN = COLS_PER_XOR_PATTERN_ * 16 / BYTES_PER_STS };
// Whether or not to use predicates
enum { USE_PREDICATES = USE_PREDICATES_ };
// The type of elements that are stored in shared memory by each thread.
using Store_type = typename Uint_from_size_in_bytes<BYTES_PER_STS>::Type;
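// Worked example (illustrative, not asserted by the code): the 16 x 64 16-bit Q tile mentioned
// in the TD note above, with BYTES_PER_STS = 16 and a 256-thread CTA, gives
//   BYTES_PER_ROW_BEFORE_PACKING = 64 * 16 / 8 = 128, BYTES_PER_ROW = 128, ROWS = 16,
//   THREADS_PER_ROW = 8, STS_PER_ROW = 1, ROWS_PER_STS = 32, STS_PER_COL = 1, STS = 1,
//   PARTIAL_STORE = true, STORING_THREADS = 16 * 8 = 128,
//   BYTES_PER_BUFFER = 1 * 16 * 128 = 2048 B (the 2KB figure from the TD note).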
// Ctor.
inline __device__ Smem_tile_without_skews(void *smem, int tidx)
: smem_(__nvvm_get_smem_pointer(smem)), tidx_(tidx) {
// The row written by a thread. See doc/mma_smem_layout.xlsx.
int smem_write_row = tidx / THREADS_PER_ROW;
// The XOR pattern.
int smem_write_xor = smem_write_row % ROWS_PER_XOR_PATTERN * COLS_PER_XOR_PATTERN;
// Compute the column and apply the XOR pattern.
int smem_write_col = (tidx % THREADS_PER_ROW) ^ smem_write_xor;
// The offset.
this->smem_write_offset_ = smem_write_row*BYTES_PER_ROW + smem_write_col*BYTES_PER_STS;
// TODO: Why not merge it with the read offset?
// this->smem_read_buffer_ = __shfl_sync(0xffffffff, 0, 0);
// this->smem_write_buffer_ = __shfl_sync(0xffffffff, 0, 0);
}
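// Illustrative write-offset computation (assuming the same 16 x 64 16-bit tile, for which
// THREADS_PER_ROW = 8, ROWS_PER_XOR_PATTERN = 8 and COLS_PER_XOR_PATTERN = 1): thread tidx = 35
// gets smem_write_row = 35 / 8 = 4, smem_write_xor = 4, smem_write_col = 3 ^ 4 = 7, so
// smem_write_offset_ = 4 * 128 + 7 * 16 = 624 bytes.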
// Compute the store pointers.
template< int N >
inline __device__ void compute_store_pointers(uint32_t (&ptrs)[N]) {
#pragma unroll
for( int ii = 0; ii < N; ++ii ) {
// Decompose the STS into row/col.
int row = ii / STS_PER_ROW;
int col = ii % STS_PER_ROW;
// Assemble the offset.
int offset = smem_write_offset_ + row*ROWS_PER_STS*BYTES_PER_ROW;
// Take the column into account.
if( STS_PER_ROW > 1 ) {
offset += col*THREADS_PER_ROW*BYTES_PER_STS;
}
// Apply the XOR pattern if needed.
if( ROWS_PER_STS < ROWS_PER_XOR_PATTERN ) {
const int m = row * ROWS_PER_STS % ROWS_PER_XOR_PATTERN;
offset ^= m * COLS_PER_XOR_PATTERN * BYTES_PER_STS;
}
// Assemble the final pointer :)
// ptrs[ii] = smem_ + offset + smem_write_buffer_;
// smem_write_buffer_ is already merged with smem_write_offset_
ptrs[ii] = smem_ + offset;
}
}
inline __device__ void debug_reset() {
for( int buffer = 0; buffer < BYTES_PER_TILE; buffer += BYTES_PER_BUFFER) {
for( int row = 0; row < ROWS; ++row ) {
for( int col = 0; col < BYTES_PER_ROW; col += 4 ) {
if( threadIdx.x == 0 ) {
uint32_t val = 0x0;
sts(val, smem_ + row*BYTES_PER_ROW + col + buffer);
}
}
}
}
}
// Print the content of the tile (only for debug ;)).
inline __device__ void debug_print() const {
for( int buffer = 0; buffer < BYTES_PER_TILE; buffer += BYTES_PER_BUFFER) {
for( int row = 0; row < ROWS; ++row ) {
for( int col = 0; col < BYTES_PER_ROW; col += 4 ) {
if( threadIdx.x == 0 ) {
uint32_t val;
lds(val, smem_ + row*BYTES_PER_ROW + col + buffer);
printf("block=(x=%2d, y=%2d, z=%2d) (smem_=%2d, buffer=%2d, row=%2d, byte=%4d)=0x%08x\n",
blockIdx.x,
blockIdx.y,
blockIdx.z,
smem_,
buffer,
row,
col,
val);
}
}
}
}
}
// Move the read offset to next buffer.
inline __device__ void move_to_next_read_buffer() {
// if( BUFFERS_PER_TILE > 1 && smem_read_buffer_ >= BYTES_PER_TILE_INC_BOUNDARY ) {
// this->smem_read_buffer_ -= BYTES_PER_TILE_INC_BOUNDARY;
// } else if( BUFFERS_PER_TILE > 1 ) {
// this->smem_read_buffer_ += BYTES_PER_BUFFER;
// }
if( BUFFERS_PER_TILE > 1 && smem_read_offset_ >= BYTES_PER_TILE_INC_BOUNDARY ) {
this->smem_read_offset_ -= BYTES_PER_TILE_INC_BOUNDARY;
} else if( BUFFERS_PER_TILE > 1 ) {
this->smem_read_offset_ += BYTES_PER_BUFFER;
}
}
// Move the read offset to next buffer. TODO: Remove this member function!!!
inline __device__ void move_next_read_buffer() {
this->move_to_next_read_buffer();
}
// Move the read offset to next N buffer (circular-buffer).
inline __device__ void move_to_next_read_buffer(int N) {
if( BUFFERS_PER_TILE > 1 ) {
// this->smem_read_buffer_ += N * BYTES_PER_BUFFER;
// this->smem_read_buffer_ -= smem_read_buffer_ >= BYTES_PER_TILE ? BYTES_PER_TILE : 0;
this->smem_read_offset_ += N * BYTES_PER_BUFFER;
this->smem_read_offset_ -= smem_read_offset_ >= BYTES_PER_TILE ? BYTES_PER_TILE : 0;
}
}
// Move the read offset to next N buffer (circular-buffer). TODO: Remove this member function!!!
inline __device__ void move_next_read_buffer(int N) {
this->move_to_next_read_buffer(N);
}
// Move the write offset to next buffer.
inline __device__ void move_to_next_write_buffer() {
// if( BUFFERS_PER_TILE > 1 && smem_write_buffer_ >= BYTES_PER_TILE_INC_BOUNDARY ) {
// this->smem_write_buffer_ -= BYTES_PER_TILE_INC_BOUNDARY;
// } else if( BUFFERS_PER_TILE > 1 ) {
// this->smem_write_buffer_ += BYTES_PER_BUFFER;
// }
if( BUFFERS_PER_TILE > 1 && smem_write_offset_ >= BYTES_PER_TILE_INC_BOUNDARY ) {
this->smem_write_offset_ -= BYTES_PER_TILE_INC_BOUNDARY;
} else if( BUFFERS_PER_TILE > 1 ) {
this->smem_write_offset_ += BYTES_PER_BUFFER;
}
}
// Move the write offset to next buffer. TODO: Remove that member function!
inline __device__ void move_next_write_buffer() {
this->move_to_next_write_buffer();
}
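// Illustration of the double-buffer case (BUFFERS_PER_TILE = 2): BYTES_PER_TILE_INC_BOUNDARY
// equals BYTES_PER_BUFFER, so each call to move_to_next_*_buffer() either adds BYTES_PER_BUFFER
// (moving from buffer 0 to buffer 1) or subtracts it (moving back), making the read/write
// offsets ping-pong between the two buffers while keeping the in-buffer offset unchanged.
// With a single buffer both branches are compiled out and the offsets never move.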
// Move the read offset.
inline __device__ void move_read_offset(int delta) {
this->smem_read_offset_ += delta;
}
// Move the write offset.
inline __device__ void move_write_offset(int delta) {
this->smem_write_offset_ += delta;
}
// Store to the tile in shared memory.
template< int N >
inline __device__ void store(const Store_type (&data)[N], uint64_t = 0) {
uint32_t smem_ptrs[N];
this->compute_store_pointers(smem_ptrs);
// Trying to reduce the shared mem for Q from 4KB per buffer to 2KB per buffer.
if (!PARTIAL_STORE || (tidx_ / THREADS_PER_ROW < ROWS)) {
sts(smem_ptrs, data);
}
}
// Store to the tile in shared memory.
template< int N, int M >
inline __device__ void store(const Store_type (&data)[N], uint32_t (&preds)[M], uint64_t = 0) {
uint32_t smem_ptrs[N];
this->compute_store_pointers(smem_ptrs);
sts(smem_ptrs, data, preds);
}
// Store to the tile in shared memory.
template< int N >
inline __device__ void store(const Store_type (&data)[N], uint32_t preds, uint64_t = 0) {
this->store(data, preds);
}
// Store to the tile in shared memory.
template< int N >
inline __device__ void store(const void* (&gmem_ptrs)[N], uint32_t preds, uint64_t = 0) {
uint32_t tmp[1] = { preds };
this->store(gmem_ptrs, tmp);
}
// The shared memory pointer.
const uint32_t smem_;
// The read offset. Reserve 4 offsets if needed.
int smem_read_offset_;
// The write offset.
int smem_write_offset_;
// The buffer base offset for read.
// int smem_read_buffer_;
// The buffer base offset for write.
// int smem_write_buffer_;
const int tidx_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The layout of the tile.
typename Layout,
// The size of the STS.
int BYTES_PER_STS = 16,
// The number of buffers per tile.
int BUFFERS_PER_TILE = 1,
// Whether or not to use predicates
bool USE_PREDICATES = true
>
struct Smem_tile_a {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int MMAS_K, int MMAS_K_WITH_PADDING >
struct Compute_reset_mask {
// The potential mask.
enum { HALF = MMAS_K_WITH_PADDING / 2 };
// The remainder.
enum { MOD = MMAS_K % HALF };
// The final value.
enum { VALUE = (MMAS_K == MOD ? 0 : HALF) | Compute_reset_mask<MOD, HALF>::VALUE };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int MMAS_K_WITH_PADDING >
struct Compute_reset_mask<0, MMAS_K_WITH_PADDING> {
enum { VALUE = 0 };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int MMAS_K >
struct Compute_reset_mask<MMAS_K, MMAS_K> {
enum { VALUE = MMAS_K - 1 };
};
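// Illustrative evaluations of the recursion above: Compute_reset_mask<1, 4>::VALUE == 1,
// <2, 4>::VALUE == 2, <3, 4>::VALUE == 3 (e.g. for <3, 4>: HALF = 2, MOD = 1, giving 2 | 1 = 3).
// The value is used below (reset_read_offset) as a multiple of BYTES_PER_LDS * 2 to XOR
// smem_read_offset_ back towards its starting column after walking the swizzle ladder in load().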
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
struct Rows_per_xor_pattern_a {
// The size in bits.
enum { N_IN_BITS = N * fmha::BITS_PER_ELEMENT_A };
// The number of rows.
enum { VALUE = N_IN_BITS <= 256 ? 2 : (N_IN_BITS <= 512 ? 4 : 8) };
};
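// With 16-bit A elements (BITS_PER_ELEMENT_A == 16) this picks 2 rows per XOR pattern when
// N <= 16 (256 bits per row), 4 rows when N <= 32 (512 bits), and 8 rows otherwise.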
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
struct Rows_per_xor_pattern_row_a : public Rows_per_xor_pattern_a<N> {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE,
// How many rows to use for the XOR pattern to avoid bank conflicts?
int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_row_a<Cta_tile::K>::VALUE
>
struct Smem_tile_row_a : public Smem_tile_without_skews<Cta_tile,
Cta_tile::M,
Cta_tile::K,
fmha::BITS_PER_ELEMENT_A,
BYTES_PER_STS,
BUFFERS_PER_TILE,
0,
ROWS_PER_XOR_PATTERN_,
1> {
// The MMA tile.
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
// The base class.
using Base = Smem_tile_without_skews<Cta_tile,
Cta_tile::M,
Cta_tile::K,
fmha::BITS_PER_ELEMENT_A,
BYTES_PER_STS,
BUFFERS_PER_TILE,
0,
ROWS_PER_XOR_PATTERN_,
1>;
// The fragment.
using Fragment = Fragment_a<Row>;
// When we use padding to reach a power of two, special care has to be taken.
using Cta_tile_with_padding = Cta_tile_with_k_with_padding<Cta_tile>;
// The MMA tile for the padded CTA tile.
using Mma_tile_with_padding = fmha::Hmma_tile<Cta_tile_with_padding>;
// The size of a single LDS in bytes.
enum { BYTES_PER_LDS = 16 };
// Ctor.
inline __device__ Smem_tile_row_a(void *smem, int tidx) : Base(smem, tidx) {
// For documentation on the layout, see doc/mma_smem_layout.xlsx.
// The number of warps.
const int WARPS_M = Cta_tile::WARPS_M;
const int WARPS_N = Cta_tile::WARPS_N;
const int WARPS_K = Cta_tile::WARPS_K;
static_assert(WARPS_M == 1);
static_assert(WARPS_N == 4 || WARPS_N == 8);
static_assert(WARPS_K == 1);
static_assert(Base::ROWS_PER_XOR_PATTERN == 2 || Base::ROWS_PER_XOR_PATTERN == 4 || Base::ROWS_PER_XOR_PATTERN == 8);
// The row and column read by the thread.
int smem_read_row = (tidx & 0x0f);
constexpr int ROWS_PER_PACKING = Base::BYTES_PER_ROW / Base::BYTES_PER_ROW_BEFORE_PACKING;
int smem_read_col = ((smem_read_row / ROWS_PER_PACKING) % Base::ROWS_PER_XOR_PATTERN) * Base::COLS_PER_XOR_PATTERN;
smem_read_col ^= (tidx & 0x10) / 16;
// The shared memory offset.
this->smem_read_offset_ = smem_read_row*Base::BYTES_PER_ROW_BEFORE_PACKING + smem_read_col*BYTES_PER_LDS;
}
// Rewind smem_read_offset for last LDS phase in main loop.
inline __device__ void reverse_smem_read_offset(int ki = 0) {
// Undo the pointer increment for the next ni.
// Should match the load function below for ki = 0.
if( Mma_tile_with_padding::MMAS_K >= 2 ) {
this->smem_read_offset_ ^= BYTES_PER_LDS * 2;
}
}
// Load from shared memory.
inline __device__ void load(Fragment (&a)[Mma_tile::MMAS_M], int ki) {
#pragma unroll
for( int mi = 0; mi < Mma_tile::MMAS_M; ++mi ) {
// Jump by as many matrix rows as needed (a row in smem may pack multiple matrix rows).
int offset = mi * Mma_tile::M_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING;
// Load using LDSM.M88.4.
uint4 tmp;
// ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset);
ldsm(tmp, this->smem_ + this->smem_read_offset_ + offset);
// Store the value into the fragment.
a[mi].reg(0) = tmp.x;
a[mi].reg(1) = tmp.y;
a[mi].reg(2) = tmp.z;
a[mi].reg(3) = tmp.w;
}
// Move the offset to the next position. See doc/mma_smem_layout.xlsx.
static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented");
if( Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15 ) {
this->smem_read_offset_ ^= 31 * BYTES_PER_LDS * 2;
} else if( Mma_tile_with_padding::MMAS_K >= 16 && ki % 8 == 7 ) {
this->smem_read_offset_ ^= 15 * BYTES_PER_LDS * 2;
} else if( Mma_tile_with_padding::MMAS_K >= 8 && ki % 4 == 3 ) {
this->smem_read_offset_ ^= 7 * BYTES_PER_LDS * 2;
} else if( Mma_tile_with_padding::MMAS_K >= 4 && ki % 2 == 1 ) {
this->smem_read_offset_ ^= 3 * BYTES_PER_LDS * 2;
} else if( Mma_tile_with_padding::MMAS_K >= 2 ) {
this->smem_read_offset_ ^= 1 * BYTES_PER_LDS * 2;
}
}
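// Illustration of the ladder above for Mma_tile_with_padding::MMAS_K == 4 (BYTES_PER_LDS * 2 == 32):
// successive calls with ki = 0, 1, 2, 3 XOR the offset by 32, 96, 32, 96 bytes, so the four LDSM
// reads land on the four 32-byte column slots of the packed row and the offset returns to its
// starting value after a full pass over the padded K dimension.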
// Reset the read offset.
inline __device__ void reset_read_offset() {
// The number of MMAs in the K dimension.
enum { MMAS_K = Mma_tile::MMAS_K };
// The number of MMAs in the K dimension when we include padding.
enum { MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K };
// Assemble the mask.
enum { MASK = Compute_reset_mask<MMAS_K, MMAS_K_WITH_PADDING>::VALUE };
// Reset the read offset.
this->smem_read_offset_ ^= MASK * BYTES_PER_LDS * 2;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE
>
struct Smem_tile_a<Cta_tile, Row, BYTES_PER_STS, BUFFERS_PER_TILE>
: public Smem_tile_row_a<Cta_tile,
BYTES_PER_STS,
BUFFERS_PER_TILE> {
// The base class.
using Base = Smem_tile_row_a<Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
// Ctor.
inline __device__ Smem_tile_a(void *smem, int tidx) : Base(smem, tidx) {
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The layout of the tile.
typename Layout,
// The size of the STS.
int BYTES_PER_STS = 16,
// The number of buffers per tile.
int BUFFERS_PER_TILE = 1,
// Whether or not to use predicates
bool USE_PREDICATES = true
>
struct Smem_tile_b {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
struct Rows_per_xor_pattern_b {
// The size in bits.
enum { N_IN_BITS = N * fmha::BITS_PER_ELEMENT_B };
// The number of rows.
enum { VALUE = N_IN_BITS <= 256 ? 2 : (N_IN_BITS <= 512 ? 4 : 8) };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
struct Rows_per_xor_pattern_col_b : public Rows_per_xor_pattern_b<N> {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE,
// How many rows to use for the XOR pattern to avoid bank conflicts?
int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_col_b<Cta_tile::K>::VALUE
>
struct Smem_tile_col_b : public Smem_tile_without_skews<Cta_tile,
Cta_tile::N,
Cta_tile::K,
fmha::BITS_PER_ELEMENT_B,
BYTES_PER_STS,
BUFFERS_PER_TILE,
0,
ROWS_PER_XOR_PATTERN_,
1> {
// The MMA tile.
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
// The base class.
using Base = Smem_tile_without_skews<Cta_tile,
Cta_tile::N,
Cta_tile::K,
fmha::BITS_PER_ELEMENT_B,
BYTES_PER_STS,
BUFFERS_PER_TILE,
0,
ROWS_PER_XOR_PATTERN_,
1>;
// The fragment.
using Fragment = Fragment_b< Col>;
// When we use padding to reach a power of two, special care has to be taken.
using Cta_tile_with_padding = Cta_tile_with_k_with_padding< Cta_tile>;
// The MMA tile for the padded CTA tile.
using Mma_tile_with_padding = fmha::Hmma_tile<Cta_tile_with_padding>;
// The size of a single LDS in bytes.
enum { BYTES_PER_LDS = 16 };
// The number of STS per thread
enum { STS_PER_THREAD_ = Base::ROWS * Base::THREADS_PER_ROW / Cta_tile::THREADS_PER_CTA };
// The number of STS per thread must be at least 1.
enum { STS_PER_THREAD = Max<1, STS_PER_THREAD_>::VALUE };
// Ctor.
inline __device__ Smem_tile_col_b(void *smem, int tidx) : Base(smem, tidx) {
// For documentation on the layout, see doc/mma_smem_layout.xlsx.
// The number of warps.
const int WARPS_M = Cta_tile::WARPS_M;
const int WARPS_N = Cta_tile::WARPS_N;
const int WARPS_K = Cta_tile::WARPS_K;
static_assert(Base::ROWS_PER_XOR_PATTERN == 2 || Base::ROWS_PER_XOR_PATTERN == 4 || Base::ROWS_PER_XOR_PATTERN == 8);
static_assert(WARPS_M == 1);
static_assert(WARPS_N == 4 || WARPS_N == 8);
static_assert(WARPS_K == 1);
// The masks to select the warps.
const int WARP_MASK_N = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::N;
// The divisor for the warps.
const int WARP_DIV_N = WARPS_M * 1 * Cta_tile::THREADS_PER_WARP;
// The row and column read by the thread.
int smem_read_row = (tidx & WARP_MASK_N) / WARP_DIV_N * Mma_tile::N_PER_MMA +
(tidx & 0x07) +
(tidx & 0x10) / 2;
constexpr int ROWS_PER_PACKING = Base::BYTES_PER_ROW / Base::BYTES_PER_ROW_BEFORE_PACKING;
int smem_read_col = ((smem_read_row / ROWS_PER_PACKING) % Base::ROWS_PER_XOR_PATTERN) * Base::COLS_PER_XOR_PATTERN;
smem_read_col ^= (tidx & 0x08) / 8;
// The shared memory offset.
this->smem_read_offset_ = smem_read_row*Base::BYTES_PER_ROW_BEFORE_PACKING + smem_read_col*BYTES_PER_LDS;
}
// Rewind smem_read_offset for last LDS phase in main loop.
inline __device__ void reverse_smem_read_offset(int ki = 0) {
// Undo the pointer increment for the next ni.
// Should match the load function below for ki = 0.
if( Mma_tile_with_padding::MMAS_K >= 2 ) {
this->smem_read_offset_ ^= BYTES_PER_LDS * 2;
}
}
// Load from shared memory.
inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) {
#pragma unroll
for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {
// Jump by as many matrix rows as needed (a row in smem may pack multiple matrix rows).
int offset = ni * Mma_tile::N_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING;
// Load using LDSM.M88.4.
uint4 tmp;
// ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset);
ldsm(tmp, this->smem_ + this->smem_read_offset_ + offset);
// Store the value into the fragment.
b[ni].reg(0) = tmp.x;
b[ni].reg(1) = tmp.y;
b[ni].reg(2) = tmp.z;
b[ni].reg(3) = tmp.w;
}
// Move the offset to the next position. See doc/mma_smem_layout.xlsx.
static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented");
if( Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15 ) {
this->smem_read_offset_ ^= 31 * BYTES_PER_LDS * 2;
} else if( Mma_tile_with_padding::MMAS_K >= 16 && ki % 8 == 7 ) {
this->smem_read_offset_ ^= 15 * BYTES_PER_LDS * 2;
} else if( Mma_tile_with_padding::MMAS_K >= 8 && ki % 4 == 3 ) {
this->smem_read_offset_ ^= 7 * BYTES_PER_LDS * 2;
} else if( Mma_tile_with_padding::MMAS_K >= 4 && ki % 2 == 1 ) {
this->smem_read_offset_ ^= 3 * BYTES_PER_LDS * 2;
} else if( Mma_tile_with_padding::MMAS_K >= 2 ) {
this->smem_read_offset_ ^= 1 * BYTES_PER_LDS * 2;
}
}
// Reset the read offset.
inline __device__ void reset_read_offset() {
// The number of MMAs in the K dimension.
enum { MMAS_K = Mma_tile::MMAS_K };
// The number of MMAs in the K dimension when we include padding.
enum { MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K };
// Assemble the mask.
enum { MASK = Compute_reset_mask<MMAS_K, MMAS_K_WITH_PADDING>::VALUE };
// Reset the read offset.
this->smem_read_offset_ ^= MASK * BYTES_PER_LDS * 2;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE
>
struct Smem_tile_b< Cta_tile, Col, BYTES_PER_STS, BUFFERS_PER_TILE >
: public Smem_tile_col_b<Cta_tile,
BYTES_PER_STS,
BUFFERS_PER_TILE> {
// The base class.
using Base = Smem_tile_col_b< Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
// Ctor.
inline __device__ Smem_tile_b(void *smem, int tidx) : Base(smem, tidx) {
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
struct Rows_per_xor_pattern_row_b : public Rows_per_xor_pattern_b< N> {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE,
// How many rows to use for the XOR pattern to avoid bank conflicts?
int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_row_b<Cta_tile::N>::VALUE,
// How many cols to use for the XOR pattern to avoid bank conflicts?
int COLS_PER_XOR_PATTERN_ = 1
>
struct Smem_tile_row_b : public Smem_tile_without_skews<Cta_tile,
Cta_tile::K,
Cta_tile::N,
fmha::BITS_PER_ELEMENT_B,
BYTES_PER_STS,
BUFFERS_PER_TILE,
0,
ROWS_PER_XOR_PATTERN_,
COLS_PER_XOR_PATTERN_> {
// The MMA tile.
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
// The base class.
using Base = Smem_tile_without_skews<Cta_tile,
Cta_tile::K,
Cta_tile::N,
fmha::BITS_PER_ELEMENT_B,
BYTES_PER_STS,
BUFFERS_PER_TILE,
0,
ROWS_PER_XOR_PATTERN_,
COLS_PER_XOR_PATTERN_>;
// The fragment.
using Fragment = Fragment_b<Row>;
// Can we use LDSM? Not if the data type is 32 bits wide.
enum { USE_LDSMT = fmha::BITS_PER_ELEMENT_B == 16 };
// The size of a single LDS in bytes.
enum { BYTES_PER_LDS = USE_LDSMT ? 16 : 4 };
// The number of elements per LDS.
enum { ELEMENTS_PER_LDS = BYTES_PER_LDS * 8 / fmha::BITS_PER_ELEMENT_B };
// The number of STS per thread
enum { STS_PER_THREAD_ = Base::ROWS * Base::THREADS_PER_ROW / Cta_tile::THREADS_PER_CTA };
// The number of STS per thread must be at least 1.
enum { STS_PER_THREAD = Max<1, STS_PER_THREAD_>::VALUE };
// Ctor.
inline __device__ Smem_tile_row_b(void *smem, int tidx) : Base(smem, tidx) {
// The number of warps.
const int WARPS_M = Cta_tile::WARPS_M;
const int WARPS_N = Cta_tile::WARPS_N;
const int WARPS_K = Cta_tile::WARPS_K;
static_assert(WARPS_K == 1);
static_assert(WARPS_M == 4 || WARPS_M == 8);
static_assert(WARPS_N == 1);
// The masks to select the warps.
const int WARP_MASK_N = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::N;
const int WARP_MASK_K = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::K;
// The divisor for the warps.
const int WARP_DIV_N = WARPS_M * 1 * Cta_tile::THREADS_PER_WARP;
const int WARP_DIV_K = WARPS_M * WARPS_N * Cta_tile::THREADS_PER_WARP;
static_assert(USE_LDSMT);
static_assert(Base::ROWS_PER_XOR_PATTERN == 2 || Base::ROWS_PER_XOR_PATTERN == 4 || Base::ROWS_PER_XOR_PATTERN == 8);
// The row/col read by the thread.
int smem_read_row = (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile::MMAS_K * 16 +
(tidx & 0x07) + (tidx & 0x08);
constexpr int ROWS_PER_PACKING = Base::BYTES_PER_ROW / Base::BYTES_PER_ROW_BEFORE_PACKING;
int smem_read_col = ((smem_read_row / ROWS_PER_PACKING) % Base::ROWS_PER_XOR_PATTERN) * Base::COLS_PER_XOR_PATTERN;
smem_read_col ^= (tidx & WARP_MASK_N) / WARP_DIV_N * 2 + (tidx & 0x10) / 16;
// The shared memory offset.
this->smem_read_offset_ = smem_read_row*Base::BYTES_PER_ROW_BEFORE_PACKING + smem_read_col*BYTES_PER_LDS;
// Fill zeroes for group conv
}
// Rewind smem_read_offset for last LDS phase in main loop.
inline __device__ void reverse_smem_read_offset(int ki = 0) {
// The size of each element in bits.
const int BITS_PER_ELT = fmha::BITS_PER_ELEMENT_B;
// The size in bytes of the data needed to compute an MMA per CTA.
const int BYTES_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA * BITS_PER_ELT / 8;
#pragma unroll
for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {
// Undo the pointer increment for the next ni.
// Should match the load function below for ki = 0.
if( BYTES_PER_MMA_PER_CTA >= 128 ) {
// Nothing to do!
} else if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 ) {
this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;
} else if( BYTES_PER_MMA_PER_CTA == 64 ) {
// Nothing to do!
} else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 4 ) {
this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 2 : 6);
} else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 2 ) {
this->smem_read_offset_ ^= BYTES_PER_LDS * 2;
}
}
// Reset smem_read_offset for odd MMAS_N > 1 (npo2 kernels)
if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 &&
Mma_tile::MMAS_N % 2 == 1 ) {
this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;
}
}
// Load from shared memory.
inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) {
// The size of each element in bits.
const int BITS_PER_ELT = fmha::BITS_PER_ELEMENT_B;
// The size in bytes of the data needed to compute an MMA per CTA.
const int BYTES_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA * BITS_PER_ELT / 8;
// uint32_t smem_read_og = this->smem_ + this->smem_read_offset_;
#pragma unroll
for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {
// Prepare the offset.
int offset = ki * Base::ROWS_PER_XOR_PATTERN * 2 * Base::BYTES_PER_ROW_BEFORE_PACKING;
if ( BYTES_PER_MMA_PER_CTA == 32 ) {
offset += this->smem_read_offset_;
} else if ( BYTES_PER_MMA_PER_CTA == 64 ) {
offset += this->smem_read_offset_ + (ni/2) * BYTES_PER_MMA_PER_CTA * 2;
} else {
offset += this->smem_read_offset_ + (ni ) * BYTES_PER_MMA_PER_CTA;
}
// Load the data using LDSM.MT88.2.
// uint32_t ptr = this->smem_ + this->smem_read_buffer_ + offset;
uint32_t ptr = this->smem_ + offset;
uint4 tmp;
if( USE_LDSMT ) {
ldsmt(tmp, ptr);
} else {
lds(tmp.x, (ptr ) + 0*Base::BYTES_PER_ROW_BEFORE_PACKING);
lds(tmp.y, (ptr ) + 4*Base::BYTES_PER_ROW_BEFORE_PACKING);
lds(tmp.z, (ptr ^ 32) + 0*Base::BYTES_PER_ROW_BEFORE_PACKING);
lds(tmp.w, (ptr ^ 32) + 4*Base::BYTES_PER_ROW_BEFORE_PACKING);
}
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("BYTES_PER_MMA_PER_CTA=%d, ni = %d, smem_read diff = %d\n", BYTES_PER_MMA_PER_CTA, ni, ptr - smem_read_og);
// }
// Store those values in the fragment.
b[ni].reg(0) = tmp.x;
b[ni].reg(1) = tmp.y;
b[ni].reg(2) = tmp.z;
b[ni].reg(3) = tmp.w;
// Move the pointer for the next ni. I expect the compiler to not recompute those.
if( BYTES_PER_MMA_PER_CTA >= 128 ) {
// Nothing to do!
} else if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 ) {
this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;
} else if( BYTES_PER_MMA_PER_CTA == 64 ) {
// Nothing to do!
} else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 8 ) {
this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 4 == 3 ? 14 : (ni % 2 == 1 ? 6 : 2));
} else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 4 ) {
this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 2 : 6);
} else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 2 ) {
this->smem_read_offset_ ^= BYTES_PER_LDS * 2;
}
}
// Reset smem_read_offset for odd MMAS_N > 1 (npo2 kernels)
if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 &&
Mma_tile::MMAS_N % 2 == 1 ) {
this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;
}
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE
>
struct Smem_tile_b<Cta_tile, Row, BYTES_PER_STS, BUFFERS_PER_TILE>
: public Smem_tile_row_b<Cta_tile,
BYTES_PER_STS,
BUFFERS_PER_TILE> {
// The base class.
using Base = Smem_tile_row_b<Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
// Ctor.
inline __device__ Smem_tile_b(void *smem, int tidx) : Base(smem, tidx) {
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Cta_tile>
struct Smem_tile_v : public fmha::Smem_tile_without_skews<Cta_tile, Cta_tile::K, Cta_tile::N, 16, 16, 1, 0, Rows_per_xor_pattern_col_b<Cta_tile::N>::VALUE, 1> {
// The base class.
using Base = Smem_tile_without_skews<Cta_tile, Cta_tile::K, Cta_tile::N, 16, 16, 1, 0, Rows_per_xor_pattern_col_b<Cta_tile::N>::VALUE, 1>;
// The MMA tile.
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
// The fragment.
using Fragment = Fragment_b< fmha::Col>;
// The size of a single LDS in bytes.
enum { BYTES_PER_LDS = 16 };
// Ctor.
inline __device__ Smem_tile_v(void *smem, int tidx) : Base(smem, tidx) {
// The row/col read by the thread.
int read_row, read_col;
static_assert(Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 1 && (Cta_tile::WARPS_K == 4 || Cta_tile::WARPS_K == 8));
read_row = (tidx & 0xe0) / 2 + (tidx & 0x0f);
constexpr int ROWS_PER_PACKING = Base::BYTES_PER_ROW / Base::BYTES_PER_ROW_BEFORE_PACKING;
read_col = ((read_row / ROWS_PER_PACKING) % Base::ROWS_PER_XOR_PATTERN) * Base::COLS_PER_XOR_PATTERN;
read_col ^= (tidx & 0x10) / 16;
// The shared memory offset.
this->smem_read_offset_ = read_row * Base::BYTES_PER_ROW_BEFORE_PACKING + read_col * BYTES_PER_LDS;
}
// Load from shared memory.
inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) {
#pragma unroll
for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {
// Jump by 16 * #warps rows.
int row = ki * 16 * Cta_tile::WARPS_K;
// Load the data using LDSM.MT88.2.
uint4 tmp;
fmha::ldsmt(tmp, this->smem_ + this->smem_read_offset_ + row * Base::BYTES_PER_ROW_BEFORE_PACKING);
b[ni].reg(0) = tmp.x;
b[ni].reg(1) = tmp.y;
b[ni].reg(2) = tmp.z;
b[ni].reg(3) = tmp.w;
// Move the pointer for the next ni. I expect the compiler to not recompute those.
if( Mma_tile::MMAS_N == 1 ) {
// noop
} else if( Mma_tile::MMAS_N == 2 ) {
this->smem_read_offset_ ^= BYTES_PER_LDS * 2;
} else if( Mma_tile::MMAS_N == 4 ) {
this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 2 : 6);
} else if (Mma_tile::MMAS_N == 8) {
this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 4 == 3 ? 14 : (ni % 2 == 1 ? 6 : 2));
} else {
assert(false); // Not implemented!
}
}
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Cta_tile>
struct Smem_tile_o {
// The MMA tile.
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
// The accumulators.
using Accumulator = fmha::Fragment_accumulator;
// The data type of the accumulators.
using Data_type = typename Accumulator::Data_type;
// The size of each element.
static constexpr int BYTES_PER_ELEMENT = sizeof(Data_type);
// The size of each STS.
static constexpr int BYTES_PER_STS = 8;
// The size of each row in shared memory.
static constexpr int BYTES_PER_ROW = Cta_tile::N * Cta_tile::WARPS_K * BYTES_PER_ELEMENT;
// The size of each LDS.
static constexpr int BYTES_PER_LDS = 16;
static constexpr int THREADS_PER_ROW = Cta_tile::N * BYTES_PER_ELEMENT / BYTES_PER_LDS;
// The number of rows.
static constexpr int ROWS = Cta_tile::M;
// The number of "rows" to process per loop iteration (in the "epilogue").
static constexpr int ROWS_PER_LOOP = ROWS <= 64 ? ROWS : (int)Mma_tile::M_PER_MMA_PER_CTA;
// The number of outer loops.
static constexpr int LOOPS = ROWS / ROWS_PER_LOOP;
// Make sure it matches our expectations.
static_assert(LOOPS == 1 || LOOPS == (int)Mma_tile::MMAS_M, "");
// The number of rows loaded per LDS.
static constexpr int ROWS_PER_LDS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW;
// Do we have to guard against partial writes/reads.
static constexpr bool HAS_INCOMPLETE_LDS = ROWS_PER_LOOP % ROWS_PER_LDS != 0;
// The total number of LDS per loop.
static constexpr int LDS_PER_LOOP = fmha::DivUpConstexpr(ROWS_PER_LOOP, ROWS_PER_LDS);
// The amount of shared memory.
static constexpr int BYTES_PER_TILE = ROWS_PER_LOOP * BYTES_PER_ROW;
// The write pointer.
uint32_t smem_write_, smem_read_;
// Is the thread active for the last LDS of the series?
int is_active_for_last_lds_;
// static_assert(BYTES_PER_ROW == 64 * 4 * Cta_tile::WARPS_K);
static_assert(LOOPS == 1 || LOOPS == (int)Mma_tile::MMAS_M, "");
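// Illustrative instantiation (not asserted here): M = 128, N = 64, WARPS_K = 4 (128 threads)
// gives BYTES_PER_ROW = 64 * 4 * 4 = 1024 B, THREADS_PER_ROW = 16, ROWS_PER_LOOP = 16,
// LOOPS = 8 (= MMAS_M), ROWS_PER_LDS = 8, LDS_PER_LOOP = 2 and BYTES_PER_TILE = 16 KB.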
// Ctor.
inline __device__ Smem_tile_o(void *smem, int tidx) {
// Get a 32-bit value for the shared memory address.
uint32_t smem_ = __nvvm_get_smem_pointer(smem);
static_assert(Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 1 && (Cta_tile::WARPS_K == 4 || Cta_tile::WARPS_K == 8));
static_assert(Cta_tile::N == 16 || Cta_tile::N == 32 || Cta_tile::N == 64 || Cta_tile::N == 128);
int write_row = (tidx & 0x1c) / 4;
const int lane = tidx % 32;
const int warp = tidx / 32;
constexpr int ELEMENTS_PER_STS = BYTES_PER_STS / BYTES_PER_ELEMENT;
constexpr int STS_PER_WARP = 16 * Mma_tile::MMAS_N / ELEMENTS_PER_STS;
int write_col = warp * STS_PER_WARP + lane % STS_PER_WARP;
// if ((threadIdx.x == 16) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("write_row = %d, write_col = %d\n", write_row, write_col);
// }
// if ((blockIdx.x == 0) && (blockIdx.y == 0) && (write_row == 0) && (write_col == 0)) {
// printf("threadIdx.x = %d\n", threadIdx.x);
// }
// Assemble the write pointer.
smem_write_ = smem_ + write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS;
// The element read by each thread.
int read_row = tidx / THREADS_PER_ROW;
int read_col = tidx % THREADS_PER_ROW;
// Take the XOR pattern into account for the column.
read_col ^= 2 * (read_row % (Cta_tile::N == 16 ? 2 : (Cta_tile::N == 32 ? 4 : 8)));
// read_col ^= 2 * (read_row % (Cta_tile::N == 16 ? 2 : (Cta_tile::N == 32 ? 4 : (Cta_tile::N == 128 ? 16 : 8))));
// if ((threadIdx.x == 8) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("read_row = %d, read_col = %d\n", read_row, read_col);
// }
// if ((blockIdx.x == 0) && (blockIdx.y == 0) && (read_row == 0) && (read_col == 0)) {
// printf("threadIdx.x = %d\n", threadIdx.x);
// }
// Assemble the read pointer.
this->smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
// Is that thread active on the last LDS?
if( HAS_INCOMPLETE_LDS ) {
this->is_active_for_last_lds_ = read_row + (LDS_PER_LOOP - 1) * ROWS_PER_LDS < Cta_tile::M;
}
}
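// Split-K layout used by store()/load() below: each of the WARPS_K warp groups writes its
// partial O accumulators into its own N * BYTES_PER_ELEMENT-wide segment of a row (hence
// BYTES_PER_ROW = N * WARPS_K * BYTES_PER_ELEMENT), and load() sums the WARPS_K segments of
// each row with fadd4 to produce the reduced output.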
// Load the output fragments.
template <bool zero_init=true>
inline __device__ void load(uint4 (&out)[LDS_PER_LOOP]) const {
#pragma unroll
for( int ii = 0; ii < LDS_PER_LOOP; ++ii ) {
// Load the elements before the reduction (split-K).
uint4 tmp[Cta_tile::WARPS_K];
#pragma unroll
for( int jj = 0; jj < Cta_tile::WARPS_K; ++jj ) {
int imm = ii * ROWS_PER_LDS * BYTES_PER_ROW + jj * Cta_tile::N * BYTES_PER_ELEMENT;
uint32_t smem_read = this->smem_read_ + imm;
// TD [2022-06-05] Ugly fix for d=128 in the forward pass, maybe there's a better way.
if ((Cta_tile::N == 128) && (ROWS_PER_LDS == 4) && (ii % 2 == 1)) {
smem_read ^= 8 * BYTES_PER_LDS;
}
// if ((threadIdx.x == 8) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("imm diff = %d\n", smem_read - this->smem_read_);
// }
if( !HAS_INCOMPLETE_LDS || (ii < LDS_PER_LOOP - 1 || this->is_active_for_last_lds_) ) {
// fmha::lds(tmp[jj], this->smem_read_ + imm);
fmha::lds(tmp[jj], smem_read);
}
}
// Perform the reduction.
out[ii] = zero_init ? tmp[0] : fmha::fadd4(out[ii], tmp[0]);
// if ((threadIdx.x == 8) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("out reduction: out = %.6f\n", reinterpret_cast<float (&)[4]>(out[ii])[0]);
// }
#pragma unroll
for( int jj = 1; jj < Cta_tile::WARPS_K; ++jj ) {
out[ii] = fmha::fadd4(out[ii], tmp[jj]);
// if ((threadIdx.x == 8) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("out reduction tmp = %.6f, out = %.6f\n", reinterpret_cast<float (&)[4]>(tmp[jj])[0], reinterpret_cast<float (&)[4]>(out[ii])[0]);
// }
}
}
}
// Store the accumulators.
template <int M, int N>
inline __device__ void store(const Accumulator (&acc)[M][N], int mi) {
// uint32_t smem_write_og = this->smem_write_;
static constexpr int M_PER_MMA = Mma_tile::M_PER_MMA_PER_CTA;
#pragma unroll
for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {
// The number of MMAs that are stored per loop iteration.
static constexpr int MMAS_M_PER_LOOP = Mma_tile::MMAS_M / LOOPS;
// Store 1st column of the different MMAs.
#pragma unroll
for( int mj = 0; mj < MMAS_M_PER_LOOP; ++mj ) {
// Precompute the immediates to jump between rows.
int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW;
int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW;
uint2 tmp0, tmp1;
tmp0.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(0);
tmp0.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(1);
tmp1.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(2);
tmp1.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(3);
// Store.
fmha::sts(this->smem_write_ + row_0, tmp0);
fmha::sts(this->smem_write_ + row_1, tmp1);
}
// if ((threadIdx.x == 16) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("smem_write diff = %d\n", this->smem_write_ - smem_write_og);
// }
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// uint4 read_tmp;
// fmha::lds(read_tmp, this->smem_read_);
// printf("smem_o = %.6f\n", reinterpret_cast<float (&)[4]>(read_tmp)[0]);
// }
// Swizzle the write pointer using a XOR of 32B.
this->smem_write_ ^= 32;
// Store 2nd column of the different MMAs.
#pragma unroll
for( int mj = 0; mj < MMAS_M_PER_LOOP; ++mj ) {
// Precompute the immediates to jump between rows.
int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW;
int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW;
uint2 tmp0, tmp1;
tmp0.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(4);
tmp0.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(5);
tmp1.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(6);
tmp1.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(7);
// Store.
fmha::sts(this->smem_write_ + row_0, tmp0);
fmha::sts(this->smem_write_ + row_1, tmp1);
}
// if ((threadIdx.x == 16) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("smem_write diff = %d\n", this->smem_write_ - smem_write_og);
// }
// Cancel the previous XOR of 1 + swizzle the write pointer using a XOR of 32B or 64B.
static_assert(Mma_tile::MMAS_N <= 8, "Not implemented");
if( Mma_tile::MMAS_N >= 8 && ni % 4 == 3 ) {
this->smem_write_ ^= 15 * 32;
} else if( Mma_tile::MMAS_N >= 4 && ni % 2 == 1 ) {
this->smem_write_ ^= 7 * 32;
} else if( Mma_tile::MMAS_N >= 2 ) {
this->smem_write_ ^= 3 * 32;
} else {
this->smem_write_ ^= 3 * 32;
}
// this->smem_write_ ^= (ni & 1) ? 7 * 32 : 3 * 32;
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// uint4 read_tmp;
// fmha::lds(read_tmp, this->smem_read_);
// printf("smem_o = %.6f\n", reinterpret_cast<float (&)[4]>(read_tmp)[0]);
// }
}
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Cta_tile>
struct Smem_tile_mma {
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
using Fragment = fmha::Fragment_a<fmha::Col>;
enum { COLS = Cta_tile::N };
enum { BYTES_PER_ELT = 2 };
enum { BYTES_PER_STS = 4 };
enum { BYTES_PER_ROW = COLS * BYTES_PER_ELT }; // TODO
enum { BYTES_PER_TILE = Cta_tile::M * BYTES_PER_ROW };
enum { WARPS_M = Cta_tile::WARPS_M };
enum { WARPS_N = Cta_tile::WARPS_N };
enum { WARPS_K = Cta_tile::WARPS_K };
static_assert(WARPS_K == 1);
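// Example sizes (illustrative): with fp16 elements and Cta_tile::N = 128, BYTES_PER_ROW = 256 B
// and BYTES_PER_TILE = Cta_tile::M * 256 B; each 4-byte STS writes one packed pair of halves.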
inline __device__ Smem_tile_mma(char *smem, int tidx) {
uint32_t smem_ = __nvvm_get_smem_pointer(smem);
int write_col, write_row;
static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8) || (WARPS_M == 4 || WARPS_M == 8) || WARPS_N == 1);
if( WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8) ) {
write_row = (tidx & 0x1c) / 4;
write_col = (tidx & 0xe0) / 4 + (tidx & 0x03);
write_col ^= (write_row & 0x07) * 4;
} else {
write_row = (tidx & 0xe0) / 2 + (tidx & 0x1c) / 4;
write_col = (tidx & 0x03);
// write_col ^= (write_row & (BYTES_PER_ROW == 32 ? 0x01 : (BYTES_PER_ROW == 64 ? 0x03 : (BYTES_PER_ROW == 128 ? 0x07 : 0x0f)))) * 4;
write_col ^= (write_row & (BYTES_PER_ROW == 32 ? 0x01 : (BYTES_PER_ROW == 64 ? 0x03 : (BYTES_PER_ROW == 128 ? 0x07 : 0x07)))) * 4;
}
// write_offset_ = write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS;
smem_write_ = smem_ + write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS;
}
template<int M, int N>
inline __device__ void store(const uint4 (®s)[M][N]) {
static_assert(COLS == Cta_tile::N);
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
// size_t offset = write_offset_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT;
// fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].x);
// fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].z);
// offset ^= 4 * BYTES_PER_STS;
// fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].y);
// fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].w);
// size_t offset = smem_write_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint32_t offset = smem_write_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT;
fmha::sts(offset + 0 * BYTES_PER_ROW, regs[mi][ni].x);
fmha::sts(offset + 8 * BYTES_PER_ROW, regs[mi][ni].z);
offset ^= 4 * BYTES_PER_STS;
fmha::sts(offset + 0 * BYTES_PER_ROW, regs[mi][ni].y);
fmha::sts(offset + 8 * BYTES_PER_ROW, regs[mi][ni].w);
}
}
}
template<typename Fragment, int M, int N>
inline __device__ void store(const Fragment (&frag)[N][M]) {
static_assert(COLS == Cta_tile::N);
uint4 regs[M][N];
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
// Need to transpose reg(1) and reg(2) here since when we load it we transpose again.
regs[mi][ni] = make_uint4(frag[ni][mi].reg(0), frag[ni][mi].reg(2),
frag[ni][mi].reg(1), frag[ni][mi].reg(3));
}
}
this->store(regs);
}
// uint32_t smem_;
// uint32_t write_offset_;
uint32_t smem_write_;
};
template< typename Cta_tile, typename Base = Smem_tile_mma< Cta_tile>>
struct Smem_tile_mma_transposed : public Base {
enum { BYTES_PER_LDS = 16 };
enum { BYTES_PER_ROW = Base::BYTES_PER_ROW };
enum { BYTES_PER_ELT = Base::BYTES_PER_ELT };
enum { WARPS_M = Base::WARPS_M };
enum { WARPS_N = Base::WARPS_N };
static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8));
using Fragment = typename Base::Fragment;
inline __device__ Smem_tile_mma_transposed(char *smem, int tidx) : Base(smem, tidx) {
uint32_t smem_ = __nvvm_get_smem_pointer(smem);
static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8));
int read_row, read_col;
read_row = (tidx & 0x0f);
read_col = (tidx & 0xe0) / 16 + (tidx & 0x1c) / 16;
// read_col ^= (read_row & (Base::BYTES_PER_ROW == 32 ? 0x01 : (Base::BYTES_PER_ROW == 64 ? 0x03 : (Base::BYTES_PER_ROW == 128 ? 0x07 : 0x0f))));
read_col ^= (read_row & 0x07);
// read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
}
template<int M, int N>
inline __device__ void load(Fragment (&frag)[M][N]) {
static_assert(Base::COLS == Cta_tile::N);
for( int mi = 0; mi < M; mi++ ) {
for( int ni = 0; ni < N; ni++ ) {
// size_t offset = read_offset_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint4 dst;
// fmha::ldsmt(dst, this->smem_ + offset);
// size_t offset = smem_read_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint32_t offset = smem_read_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT;
fmha::ldsmt(dst, offset);
frag[mi][ni].reg(0) = dst.x;
frag[mi][ni].reg(1) = dst.z; // Fragment A regs col major!
frag[mi][ni].reg(2) = dst.y;
frag[mi][ni].reg(3) = dst.w;
}
}
}
// uint32_t read_offset_;
uint32_t smem_read_;
};
template< typename Cta_tile, typename Base = Smem_tile_mma< Cta_tile>>
struct Smem_tile_mma_epilogue : public Base {
enum { BYTES_PER_LDS = 16 };
enum { BYTES_PER_ROW = Base::BYTES_PER_ROW };
enum { BYTES_PER_ELT = Base::BYTES_PER_ELT };
enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDS };
static_assert(THREADS_PER_ROW * BYTES_PER_LDS == BYTES_PER_ROW);
enum { ROWS_PER_LDS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW };
enum { NUM_LDS = Cta_tile::M / ROWS_PER_LDS };
static_assert(NUM_LDS * ROWS_PER_LDS == Cta_tile::M);
enum { WARPS_M = Base::WARPS_M };
enum { WARPS_N = Base::WARPS_N };
static_assert((WARPS_M == 4 || WARPS_M == 8) || WARPS_N == 1);
using Acc = fmha::Fragment_accumulator;
inline __device__ Smem_tile_mma_epilogue(char *smem, int tidx) : Base(smem, tidx) {
uint32_t smem_ = __nvvm_get_smem_pointer(smem);
const int read_row = tidx / THREADS_PER_ROW;
int read_col = tidx % THREADS_PER_ROW;
// read_col ^= (read_row & (Base::BYTES_PER_ROW == 32 ? 0x01 : (Base::BYTES_PER_ROW == 64 ? 0x03 : 0x07)));
static_assert(Base::BYTES_PER_ROW == 32 || Base::BYTES_PER_ROW == 64 || Base::BYTES_PER_ROW == 128 || Base::BYTES_PER_ROW == 256);
read_col ^= (read_row & (Base::BYTES_PER_ROW == 32 ? 0x01 : (Base::BYTES_PER_ROW == 64 ? 0x03 : (Base::BYTES_PER_ROW == 128 ? 0x07 : 0x07))));
// read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
}
inline __device__ void load(uint4 (&data)[NUM_LDS]) {
for( int ii = 0; ii < NUM_LDS; ii++ ) {
// size_t offset = read_offset_ + ii * ROWS_PER_LDS * BYTES_PER_ROW;
// fmha::lds(data[ii], this->smem_ + offset);
// size_t offset = smem_read_ + ii * ROWS_PER_LDS * BYTES_PER_ROW;
uint32_t offset = smem_read_ + ii * ROWS_PER_LDS * BYTES_PER_ROW;
fmha::lds(data[ii], offset);
}
}
template<typename elem_type=__half, int M, int N>
inline __device__ void store(const Acc (&acc)[M][N]){
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
// 1st row - 4 elements per row.
float tmp00 = acc[mi][ni].elt(0);
float tmp01 = acc[mi][ni].elt(1);
float tmp02 = acc[mi][ni].elt(4);
float tmp03 = acc[mi][ni].elt(5);
// 2nd row - 4 elements per row.
float tmp10 = acc[mi][ni].elt(2);
float tmp11 = acc[mi][ni].elt(3);
float tmp12 = acc[mi][ni].elt(6);
float tmp13 = acc[mi][ni].elt(7);
uint32_t x = fmha::float2_pack<elem_type>(tmp00, tmp01);
uint32_t y = fmha::float2_pack<elem_type>(tmp02, tmp03);
uint32_t z = fmha::float2_pack<elem_type>(tmp10, tmp11);
uint32_t w = fmha::float2_pack<elem_type>(tmp12, tmp13);
// size_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;
// fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, x);
// fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, z);
// offset ^= 4 * Base::BYTES_PER_STS;
// fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, y);
// fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, w);
// size_t offset = (this->smem_write_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;
uint32_t offset = (this->smem_write_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("mi = %d, ni = %d, offset - smem_write_ = %d\n", mi, ni, offset - this->smem_write_);
// }
fmha::sts(offset + 0 * BYTES_PER_ROW, x);
fmha::sts(offset + 8 * BYTES_PER_ROW, z);
offset ^= 4 * Base::BYTES_PER_STS;
fmha::sts(offset + 0 * BYTES_PER_ROW, y);
fmha::sts(offset + 8 * BYTES_PER_ROW, w);
}
}
}
template<int M, int N>
inline __device__ void store(const uint4 (®s)[M][N]) {
for( int mi = 0; mi < M; mi++ ) {
for( int ni = 0; ni < N; ni++ ) {
// size_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;
// The base class only keeps the pre-combined smem_write_ pointer (smem_ and write_offset_ are
// commented out above), so use it directly, like the accumulator overload of store().
uint32_t offset = (this->smem_write_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;
fmha::sts(offset + 0 * BYTES_PER_ROW, regs[mi][ni].x);
fmha::sts(offset + 8 * BYTES_PER_ROW, regs[mi][ni].z);
offset ^= 4 * Base::BYTES_PER_STS;
fmha::sts(offset + 0 * BYTES_PER_ROW, regs[mi][ni].y);
fmha::sts(offset + 8 * BYTES_PER_ROW, regs[mi][ni].w);
}
}
}
// uint32_t read_offset_;
uint32_t smem_read_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Cta_tile>
struct Smem_tile_transpose {
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
using Fragment_write = fmha::Fragment_b<fmha::Col>;
using Fragment_read = fmha::Fragment_b<fmha::Col>;
enum { COLS = Cta_tile::N };
enum { BYTES_PER_ELT = 2 };
enum { BYTES_PER_STS = 4 };
enum { BYTES_PER_ROW = COLS * BYTES_PER_ELT }; // TODO
enum { BYTES_PER_TILE = Cta_tile::M * BYTES_PER_ROW };
enum { BYTES_PER_LDS = 16 };
enum { WARPS_M = Cta_tile::WARPS_M };
enum { WARPS_N = Cta_tile::WARPS_N };
enum { WARPS_K = Cta_tile::WARPS_K };
static_assert(WARPS_K == 1);
static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8));
inline __device__ Smem_tile_transpose(char *smem, int tidx) {
smem_ = __nvvm_get_smem_pointer(smem);
// uint32_t smem_ = __nvvm_get_smem_pointer(smem);
int write_col, write_row;
static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8) || (WARPS_M == 4 || WARPS_N == 8) || WARPS_N == 1);
if( WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8) ) {
write_row = (tidx & 0x1c) / 4;
write_col = (tidx & 0xe0) / 4 + (tidx & 0x03);
} else {
write_row = (tidx & 0xe0) / 2 + (tidx & 0x1c) / 4;
write_col = (tidx & 0x03);
}
write_col ^= (write_row & 0x07) * 4;
write_offset_ = write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS;
// smem_write_ = smem_ + write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS;
int read_row, read_col;
read_row = (tidx & 0x0f);
read_col = (tidx & 0xe0) / 16 + (tidx & 0x1c) / 16;
read_col ^= (read_row & 0x07);
read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
// smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
}
template<int M, int N>
inline __device__ void store(const Fragment_write (&frag_w)[M][N], int mi) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
// size_t offset = write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint32_t offset = write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, frag_w[ni][mi].reg(0));
fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, frag_w[ni][mi].reg(2));
offset ^= 4 * BYTES_PER_STS;
fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, frag_w[ni][mi].reg(1));
fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, frag_w[ni][mi].reg(3));
}
}
template<int N>
inline __device__ void load(Fragment_read (&frag_r)[N]) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
// size_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint32_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint4 dst;
fmha::ldsmt(dst, this->smem_ + offset);
frag_r[ni].reg(0) = dst.x;
frag_r[ni].reg(1) = dst.y; // Fragment B regs col major!
frag_r[ni].reg(2) = dst.z;
frag_r[ni].reg(3) = dst.w;
}
}
template<int M, int N>
inline __device__ void transpose(const Fragment_write (&frag_w)[M][N], Fragment_read (&frag_r)[M], int mi) {
static_assert(COLS == Cta_tile::N);
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
// size_t offset = write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint32_t offset = write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, frag_w[ni][mi].reg(0));
fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, frag_w[ni][mi].reg(2));
offset ^= 4 * BYTES_PER_STS;
fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, frag_w[ni][mi].reg(1));
fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, frag_w[ni][mi].reg(3));
}
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
// size_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
// size_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint32_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint4 dst;
fmha::ldsmt(dst, this->smem_ + offset);
frag_r[ni].reg(0) = dst.x;
frag_r[ni].reg(1) = dst.y; // Fragment B regs col major!
frag_r[ni].reg(2) = dst.z;
frag_r[ni].reg(3) = dst.w;
}
}
uint32_t smem_;
uint32_t write_offset_;
uint32_t read_offset_;
// uint32_t smem_write_;
// uint32_t smem_read_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
typename Gmem_tile,
// The number of buffers. (Used in multistage and double buffer cases.)
int BUFFERS_PER_TILE_ = 1
>
struct Smem_tile_dp_sum {
using Cta_tile = typename Gmem_tile::Cta_tile;
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
// The size of each element.
static constexpr int BYTES_PER_ELEMENT = 4;
static constexpr int ROWS = Gmem_tile::ROWS;
static constexpr int THREADS_PER_ROW = Gmem_tile::THREADS_PER_ROW;
static constexpr int MMAS_M = Mma_tile::MMAS_M;
static constexpr int ROWS_PER_LDG = Gmem_tile::ROWS_PER_LDG;
static constexpr int LDGS = Gmem_tile::LDGS;
static constexpr int ROWS_PER_MMA = Mma_tile::M_PER_MMA;
// The size of one buffer in bytes in shared memory.
static constexpr int BYTES_PER_BUFFER = ROWS * BYTES_PER_ELEMENT;
// The number of buffers.
static constexpr int BUFFERS_PER_TILE = BUFFERS_PER_TILE_;
// The size in bytes of total buffers.
static constexpr int BYTES_PER_TILE = BYTES_PER_BUFFER * BUFFERS_PER_TILE;
// The boundary for smem_read_offset and smem_write_offset increment.
static constexpr int ROWS_PER_TILE_INC_BOUNDARY = ROWS * BUFFERS_PER_TILE - ROWS;
inline __device__ Smem_tile_dp_sum(float *smem, const int tidx)
: smem_(smem), smem_read_buffer_(smem), smem_write_buffer_(smem), tidx_(tidx) {
}
// Move the read offset to next buffer.
inline __device__ void move_to_next_read_buffer() {
if( BUFFERS_PER_TILE > 1 && (smem_read_buffer_ - smem_) >= ROWS_PER_TILE_INC_BOUNDARY ) {
this->smem_read_buffer_ -= ROWS_PER_TILE_INC_BOUNDARY;
} else if( BUFFERS_PER_TILE > 1 ) {
this->smem_read_buffer_ += ROWS;
}
}
// Move the write offset to next buffer.
inline __device__ void move_to_next_write_buffer() {
if( BUFFERS_PER_TILE > 1 && (smem_write_buffer_ - smem_) >= ROWS_PER_TILE_INC_BOUNDARY ) {
this->smem_write_buffer_ -= ROWS_PER_TILE_INC_BOUNDARY;
} else if( BUFFERS_PER_TILE > 1 ) {
this->smem_write_buffer_ += ROWS;
}
}
inline __device__ void store(const float (&sum)[LDGS]) {
if (tidx_ % THREADS_PER_ROW == 0) {
int row = tidx_ / THREADS_PER_ROW;
#pragma unroll
for (int i = 0; i < LDGS; ++i) {
if (row + i * ROWS_PER_LDG < ROWS) {
smem_write_buffer_[row + i * ROWS_PER_LDG] = sum[i];
}
}
}
}
inline __device__ void store(const float sum, const int buffer_idx) {
float *smem_write = smem_ + buffer_idx * ROWS;
int row = tidx_ / THREADS_PER_ROW;
if ((row < ROWS) && (tidx_ % THREADS_PER_ROW == 0)) {
smem_write[row] = sum;
}
}
inline __device__ void store(const float (&sum)[LDGS], const int buffer_idx) {
float *smem_write = smem_ + buffer_idx * ROWS;
if (tidx_ % THREADS_PER_ROW == 0) {
int row = tidx_ / THREADS_PER_ROW;
#pragma unroll
for (int i = 0; i < LDGS; ++i) {
if (row + i * ROWS_PER_LDG < ROWS) {
smem_write[row + i * ROWS_PER_LDG] = sum[i];
}
}
}
}
inline __device__ void store_pair(const float (&sum)[MMAS_M * 2]) {
float *smem_write = smem_;
// Extract the position in the warp.
int warp = tidx_ / Cta_tile::THREADS_PER_WARP;
int lane = tidx_ % Cta_tile::THREADS_PER_WARP;
int row = lane / 4;
#pragma unroll
for (int mi = 0; mi < MMAS_M; ++mi) {
smem_write[mi * ROWS_PER_MMA + row + 0] = sum[mi * 2 + 0];
smem_write[mi * ROWS_PER_MMA + row + 8] = sum[mi * 2 + 1];
}
}
inline __device__ void store_pair(const float (&sum)[MMAS_M * 2], const int buffer_idx) {
float *smem_write = smem_ + buffer_idx * ROWS;
// Extract the position in the warp.
int warp = tidx_ / Cta_tile::THREADS_PER_WARP;
int lane = tidx_ % Cta_tile::THREADS_PER_WARP;
int row = lane / 4;
#pragma unroll
for (int mi = 0; mi < MMAS_M; ++mi) {
smem_write[mi * ROWS_PER_MMA + row + 0] = sum[mi * 2 + 0];
smem_write[mi * ROWS_PER_MMA + row + 8] = sum[mi * 2 + 1];
}
}
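// In both store_pair variants the row index comes from the MMA accumulator layout: each quad of
// four lanes (row = lane / 4) owns rows (row) and (row + 8) of a 16-row MMA, so the two writes
// per MMA land 8 rows apart and consecutive MMAs are ROWS_PER_MMA (= 16) rows apart. Lanes
// within a quad map to the same row and are expected to hold (and redundantly write) the same
// reduced value.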
template<int N>
inline __device__ void load(float (&sum)[N], const int (&row)[N]) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
sum[ni] = smem_read_buffer_[row[ni]];
}
}
template<int N>
inline __device__ void load(float (&sum)[N], const int (&row)[N], const int buffer_idx) {
float *smem_read = smem_ + buffer_idx * ROWS;
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
sum[ni] = smem_read[row[ni]];
}
}
static inline __device__ float reduce_warp(float sum) {
fmha::SumOp<float> sum_op;
return fmha::Allreduce<THREADS_PER_ROW>::run(sum, sum_op);
}
const int tidx_;
float * const smem_;
float *smem_read_buffer_;
float *smem_write_buffer_;
};
} // namespace fmha