in kernels/fmha/smem_tile.h [1260:1281]
// Store an M x N array of uint4 register fragments via fmha::sts.
//
// regs: fragments indexed as regs[mi][ni]; each uint4 contributes four
//       32-bit stores (.x, .y, .z, .w).
//
// For fragment (mi, ni) the base offset is
//   smem_write_ + mi * WARPS_M * 16 * BYTES_PER_ROW
//               + ni * WARPS_N * 16 * BYTES_PER_ELT.
// .x and .z are written at that offset in rows +0 and +8; the offset is then
// XOR-ed with 4 * BYTES_PER_STS and .y and .w are written in rows +0 and +8.
//
// NOTE(review): fmha::sts presumably issues a shared-memory store taking a
// 32-bit shared-space address, and smem_write_ is presumably this thread's
// base write address -- confirm against their definitions.
inline __device__ void store(const uint4 (&regs)[M][N]) {
    // This layout is only valid when the tile's column count matches the CTA tile's N.
    static_assert(COLS == Cta_tile::N);
    #pragma unroll
    for( int mi = 0; mi < M; mi++ ) {
        #pragma unroll
        for( int ni = 0; ni < N; ni++ ) {
            // The offset is kept as a 32-bit value and handed to fmha::sts directly.
            uint32_t offset = smem_write_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT;

            // First pair of stores: .x in row +0, .z in row +8.
            fmha::sts(offset + 0 * BYTES_PER_ROW, regs[mi][ni].x);
            fmha::sts(offset + 8 * BYTES_PER_ROW, regs[mi][ni].z);

            // Toggle the offset by 4 store-widths before writing the second pair.
            // NOTE(review): looks like the column half of a swizzled layout -- confirm.
            offset ^= 4 * BYTES_PER_STS;

            // Second pair of stores: .y in row +0, .w in row +8.
            fmha::sts(offset + 0 * BYTES_PER_ROW, regs[mi][ni].y);
            fmha::sts(offset + 8 * BYTES_PER_ROW, regs[mi][ni].w);
        }
    }
}