in candle-flash-attn-v1/kernels/fmha/smem_tile.h [1529:1553]
inline __device__ void transpose(const Fragment_write (&frag_w)[M][N], Fragment_read (&frag_r)[M], int mi) {
    // Stages the `mi` slice of frag_w in the tile via fmha::sts, then refills
    // frag_r from the tile via fmha::ldsmt.
    //
    //   frag_w : source fragments; for every ni in [0, N) the four registers of
    //            frag_w[ni][mi] are stored.
    //   frag_r : destination fragments; frag_r[ni] receives four registers for
    //            every ni in [0, N).
    //   mi     : second index used to select the source fragment in frag_w.
    //
    // NOTE(review): frag_w is declared [M][N] but indexed as [ni][mi] with
    // ni < N, and frag_r is declared [M] but indexed with ni < N -- confirm
    // that every caller's M/N keep these accesses in bounds.
    // NOTE(review): no barrier is issued here between the stores and the loads;
    // confirm the required ordering/visibility is guaranteed elsewhere.
    static_assert(COLS == Cta_tile::N);

    // Byte step added to both the write and the read offset per ni iteration.
    const uint32_t ni_stride = WARPS_N * 16 * BYTES_PER_ELT;

    // Phase 1: store registers 0..3 of each selected source fragment.
    #pragma unroll
    for( int ni = 0; ni < N; ni++ ) {
        const Fragment_write &src = frag_w[ni][mi];

        // Registers 0 and 2 go to the plain offset, at rows +0 and +8.
        const uint32_t plain_off = write_offset_ + ni * ni_stride;
        fmha::sts(smem_ + plain_off + 0 * BYTES_PER_ROW, src.reg(0));
        fmha::sts(smem_ + plain_off + 8 * BYTES_PER_ROW, src.reg(2));

        // Registers 1 and 3 go to the same rows, with the offset XOR-ed by
        // four store widths.
        const uint32_t xored_off = plain_off ^ (4 * BYTES_PER_STS);
        fmha::sts(smem_ + xored_off + 0 * BYTES_PER_ROW, src.reg(1));
        fmha::sts(smem_ + xored_off + 8 * BYTES_PER_ROW, src.reg(3));
    }

    // Phase 2: load one uint4 per ni and scatter its lanes into frag_r[ni].
    #pragma unroll
    for( int ni = 0; ni < N; ni++ ) {
        const uint32_t load_off = read_offset_ + ni * ni_stride;

        uint4 loaded;
        fmha::ldsmt(loaded, this->smem_ + load_off);

        // Fragment B regs col major!
        Fragment_read &out = frag_r[ni];
        out.reg(0) = loaded.x;
        out.reg(1) = loaded.y;
        out.reg(2) = loaded.z;
        out.reg(3) = loaded.w;
    }
}