in kernels/fmha/smem_tile.h [1328:1344]
// Fills an M x N grid of fragments from this shared-memory tile using
// fmha::ldsmt (a 4-word load into a uint4; presumably a transposing
// ldmatrix.x4 given the "t" suffix -- NOTE(review): confirm).
//
// Addressing: fragment (mi, ni) is read from
//   smem_read_ + mi * WARPS_M * 16 * BYTES_PER_ROW
//              + ni * WARPS_N * 16 * BYTES_PER_ELT
// so successive mi advance by WARPS_M * 16 units of BYTES_PER_ROW, and
// successive ni advance by WARPS_N * 16 units of BYTES_PER_ELT.
// The address is kept as a 32-bit value because that is what ldsmt takes.
//
// frag: output; all four registers of every frag[mi][ni] are overwritten.
// Precondition (compile-time): the tile width Base::COLS equals Cta_tile::N.
inline __device__ void load(Fragment (&frag)[M][N]) {
    static_assert(Base::COLS == Cta_tile::N, "Smem tile width (Base::COLS) must equal Cta_tile::N");
    // Full unrolling turns mi/ni into compile-time constants, so frag can be
    // kept in registers rather than dynamically indexed.
    #pragma unroll
    for( int mi = 0; mi < M; mi++ ) {
        // Part of the address that depends only on mi; invariant across ni.
        const uint32_t row_base = smem_read_ + mi * WARPS_M * 16 * BYTES_PER_ROW;
        #pragma unroll
        for( int ni = 0; ni < N; ni++ ) {
            const uint32_t offset = row_base + ni * WARPS_N * 16 * BYTES_PER_ELT;
            uint4 dst;
            fmha::ldsmt(dst, offset);
            // Fragment A registers are column-major, so the middle two words
            // returned by ldsmt are stored swapped (y <-> z) relative to the
            // fragment's register order.
            frag[mi][ni].reg(0) = dst.x;
            frag[mi][ni].reg(1) = dst.z;
            frag[mi][ni].reg(2) = dst.y;
            frag[mi][ni].reg(3) = dst.w;
        }
    }
}