inline __device__ void load()

in kernels/fmha/smem_tile.h [1328:1344]


    // Fill every fragment of `frag` with one fmha::ldsmt call per (mi, ni) entry.
    //
    // frag: M x N array of fragments; registers 0..3 of each entry are overwritten.
    //
    // The address handed to fmha::ldsmt is smem_read_ advanced by
    // WARPS_M * 16 * BYTES_PER_ROW for each step along the first index and by
    // WARPS_N * 16 * BYTES_PER_ELT for each step along the second index.
    // NOTE(review): presumably smem_read_ is this thread's shared-memory read address and
    // ldsmt is a transposing matrix load -- confirm against their definitions (not visible here).
    inline __device__ void load(Fragment (&frag)[M][N]) {
        static_assert(Base::COLS == Cta_tile::N);
        for( int row = 0; row < M; ++row ) {
            // Part of the address that only depends on the outer index.
            const uint32_t row_addr = smem_read_ + row * WARPS_M * 16 * BYTES_PER_ROW;
            for( int col = 0; col < N; ++col ) {
                uint32_t addr = row_addr + col * WARPS_N * 16 * BYTES_PER_ELT;
                uint4 loaded;
                fmha::ldsmt(loaded, addr);

                // Fragment A regs col major! The two middle words (y, z) are swapped
                // on their way into registers 1 and 2.
                Fragment &out = frag[row][col];
                out.reg(0) = loaded.x;
                out.reg(1) = loaded.z;
                out.reg(2) = loaded.y;
                out.reg(3) = loaded.w;
            }
        }
    }