inline __device__ void store()

in kernels/fmha/smem_tile.h [1388:1427]


    // Convert every accumulator fragment acc[mi][ni] into four packed 32-bit words and
    // write them out through fmha::sts at offsets derived from this->smem_write_.
    //
    // Per fragment:
    //   * elements {0, 1, 4, 5} make up the first row and {2, 3, 6, 7} the second row
    //     (4 elements per row);
    //   * each pair of neighbouring values is packed with fmha::float2_pack<elem_type>;
    //   * the first pair of each row is stored at the base offset (rows 0 and 8), the
    //     second pair at the base offset XOR-ed with 4 * Base::BYTES_PER_STS.
    //
    // NOTE(review): the destination is presumably shared memory (smem_write_ / sts);
    // confirm against the enclosing tile class.
    inline __device__ void store(const Acc (&acc)[M][N]){
        #pragma unroll
        for( int mi = 0; mi < M; mi++ ) {
            // Offset contribution of the current mi step; independent of ni.
            const auto mi_offset = mi * WARPS_M * 16 * BYTES_PER_ROW;

            #pragma unroll
            for( int ni = 0; ni < N; ni++ ) {
                const Acc &frag = acc[mi][ni];

                // First row of the fragment: elements 0, 1, 4, 5.
                const float top_a = frag.elt(0);
                const float top_b = frag.elt(1);
                const float top_c = frag.elt(4);
                const float top_d = frag.elt(5);
                // Second row of the fragment: elements 2, 3, 6, 7.
                const float bot_a = frag.elt(2);
                const float bot_b = frag.elt(3);
                const float bot_c = frag.elt(6);
                const float bot_d = frag.elt(7);

                // Pack neighbouring values of each row into one 32-bit word.
                const uint32_t top_lo = fmha::float2_pack<elem_type>(top_a, top_b);
                const uint32_t top_hi = fmha::float2_pack<elem_type>(top_c, top_d);
                const uint32_t bot_lo = fmha::float2_pack<elem_type>(bot_a, bot_b);
                const uint32_t bot_hi = fmha::float2_pack<elem_type>(bot_c, bot_d);

                // Base offset for this (mi, ni): the write pointer is XOR-ed with ni * 32
                // before the mi offset is added.
                const uint32_t dst_lo = (this->smem_write_ ^ (ni * 32)) + mi_offset;
                fmha::sts(dst_lo + 0 * BYTES_PER_ROW, top_lo);
                fmha::sts(dst_lo + 8 * BYTES_PER_ROW, bot_lo);

                // The second pair of each row lands at the XOR-toggled offset.
                const uint32_t dst_hi = dst_lo ^ (4 * Base::BYTES_PER_STS);
                fmha::sts(dst_hi + 0 * BYTES_PER_ROW, top_hi);
                fmha::sts(dst_hi + 8 * BYTES_PER_ROW, bot_hi);
            }
        }
    }