inline __device__ void load()

in kernels/fmha/smem_tile.h [1099:1135]


    // Loads LDS_PER_LOOP entries through fmha::lds and, for each entry, reduces the
    // Cta_tile::WARPS_K partial results (split-K) into `out`.
    //
    // out: one uint4 per entry. When `zero_init` (declared outside this function) is
    //      true, the first partial replaces out[i]; otherwise it is added onto the
    //      value already in out[i]. The remaining partials are then accumulated with
    //      fmha::fadd4 in increasing slice order.
    //
    // NOTE(review): when HAS_INCOMPLETE_LDS is set and this thread is not active for
    // the last LDS, the staging values of the last entry are never written but are
    // still reduced into out[LDS_PER_LOOP - 1]. Presumably the caller discards that
    // entry for such threads -- confirm.
    inline __device__ void load(uint4 (&out)[LDS_PER_LOOP]) const {
        // Tile configuration that needs the d=128 address workaround applied below.
        const bool needs_d128_fix = (Cta_tile::N == 128) && (ROWS_PER_LDS == 4);

        #pragma unroll
        for( int lds_idx = 0; lds_idx < LDS_PER_LOOP; ++lds_idx ) {

            // Byte offset of this entry; identical for every split-K slice.
            const int row_offset = lds_idx * ROWS_PER_LDS * BYTES_PER_ROW;

            // Workaround for d=128 in the forward pass (TD, 2022-06-05; flagged by the
            // author as an ugly fix that may have a better solution): odd entries read
            // from the address with 8 * BYTES_PER_LDS XOR-ed in.
            const bool flip_address = needs_d128_fix && (lds_idx % 2 == 1);

            // Only the last entry can be incomplete; for that one the load is issued
            // solely by threads flagged as active for the last LDS.
            const bool do_load = !HAS_INCOMPLETE_LDS
                              || (lds_idx < LDS_PER_LOOP - 1)
                              || this->is_active_for_last_lds_;

            // Stage the per-slice values before reducing them.
            uint4 partial[Cta_tile::WARPS_K];
            #pragma unroll
            for( int warp_k = 0; warp_k < Cta_tile::WARPS_K; ++warp_k ) {
                const int offset = row_offset + warp_k * Cta_tile::N * BYTES_PER_ELEMENT;
                uint32_t address = this->smem_read_ + offset;
                if( flip_address ) {
                    address ^= 8 * BYTES_PER_LDS;
                }
                if( do_load ) {
                    fmha::lds(partial[warp_k], address);
                }
            }

            // Seed the accumulator: either start fresh from slice 0 or continue from
            // the value the caller passed in.
            uint4 acc;
            if( zero_init ) {
                acc = partial[0];
            } else {
                acc = fmha::fadd4(out[lds_idx], partial[0]);
            }

            // Fold in the remaining slices, keeping the original summation order.
            #pragma unroll
            for( int warp_k = 1; warp_k < Cta_tile::WARPS_K; ++warp_k ) {
                acc = fmha::fadd4(acc, partial[warp_k]);
            }

            out[lds_idx] = acc;
        }
    }