inline __device__ void load()

in kernels/fmha/smem_tile.h [965:992]


    // Load one fragment per MMA tile along N for the k-step `ki`.
    //
    // b  : output fragments, one per ni in [0, Mma_tile::MMAS_N); registers 0..3 of each are
    //      overwritten with the x/y/z/w components returned by fmha::ldsmt.
    // ki : k-step index; selects the row at which the read starts.
    //
    // The read offset (this->smem_read_offset_) is XOR-ed after every load to point at the next
    // ni. For the supported counts (2, 4, 8) the XOR masks applied over the whole loop cancel
    // out, so the offset holds its initial value again when the function returns. MMAS_N == 1
    // never touches the offset; any other count hits the assert.
    inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) {
        // Jump by 16 * #warps row. Depends only on ki, so it is computed once for all ni.
        const int row = ki * 16 * Cta_tile::WARPS_K;

#pragma unroll
        for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {
            // Fetch four 32-bit values with fmha::ldsmt (noted as LDSM.MT88.2 in the original code).
            uint4 tmp;
            fmha::ldsmt(tmp, this->smem_ + this->smem_read_offset_ + row * Base::BYTES_PER_ROW_BEFORE_PACKING);

            // Scatter the four components into the fragment registers.
            b[ni].reg(0) = tmp.x;
            b[ni].reg(1) = tmp.y;
            b[ni].reg(2) = tmp.z;
            b[ni].reg(3) = tmp.w;

            // Advance the read offset for the next ni. MMAS_N is a compile-time constant and the
            // loop is unrolled, so the compiler is expected to resolve these branches statically.
            if( Mma_tile::MMAS_N == 8 ) {
                // Mask pattern over ni = 0..7: 2, 6, 2, 14, 2, 6, 2, 14.
                const int flip = (ni & 3) == 3 ? 14 : ((ni & 1) != 0 ? 6 : 2);
                this->smem_read_offset_ ^= BYTES_PER_LDS * flip;
            } else if( Mma_tile::MMAS_N == 4 ) {
                // Mask pattern over ni = 0..3: 2, 6, 2, 6.
                const int flip = (ni & 1) != 0 ? 6 : 2;
                this->smem_read_offset_ ^= BYTES_PER_LDS * flip;
            } else if( Mma_tile::MMAS_N == 2 ) {
                // Toggle between the two positions.
                this->smem_read_offset_ ^= BYTES_PER_LDS * 2;
            } else if( Mma_tile::MMAS_N != 1 ) {
                // A single tile needs no pointer update; every other count is unsupported.
                assert(false);  // Not implemented!
            }
        }
    }