inline __device__ void store()

in kernels/fmha/smem_tile.h [1260:1281]


    // Writes an M x N grid of 128-bit register fragments into the tile's
    // storage, starting from this thread's write base `smem_write_`.
    //
    // Each uint4 fragment is split into four 32-bit stores via fmha::sts
    // (presumably a shared-memory store taking a 32-bit address; defined elsewhere):
    //   .x -> base,           .z -> base + 8 rows
    //   .y -> base ^ swizzle, .w -> (base ^ swizzle) + 8 rows
    // where swizzle = 4 * BYTES_PER_STS is XOR-ed into the full address.
    //
    // Fragment (mi, ni) is placed mi * WARPS_M * 16 rows down and
    // ni * WARPS_N * 16 elements across from the base.
    // NOTE(review): M and N are declared outside this view (enclosing template
    // or struct) — confirm they match the caller's fragment layout.
    inline __device__ void store(const uint4 (&regs)[M][N]) {
        // This path is only valid when the tile spans the full CTA width.
        static_assert(COLS == Cta_tile::N);

        #pragma unroll
        for( int row = 0; row < M; ++row ) {
            // Row-dependent part of the address; invariant across the inner loop.
            const uint32_t row_base = smem_write_ + row * WARPS_M * 16 * BYTES_PER_ROW;

            #pragma unroll
            for( int col = 0; col < N; ++col ) {
                const uint4 &frag = regs[row][col];

                // Unswizzled address used by the .x/.z halves ...
                const uint32_t dst_plain = row_base + col * WARPS_N * 16 * BYTES_PER_ELT;
                // ... and the XOR-swizzled partner address used by the .y/.w halves.
                const uint32_t dst_swz = dst_plain ^ (4 * BYTES_PER_STS);

                fmha::sts(dst_plain + 0 * BYTES_PER_ROW, frag.x);
                fmha::sts(dst_plain + 8 * BYTES_PER_ROW, frag.z);
                fmha::sts(dst_swz + 0 * BYTES_PER_ROW, frag.y);
                fmha::sts(dst_swz + 8 * BYTES_PER_ROW, frag.w);
            }
        }
    }