inline __device__ void transpose()

in candle-flash-attn-v1/kernels/fmha/smem_tile.h [1529:1553]


    // Transposes the fragments frag_w[0..N)[mi] into frag_r[0..N) by way of
    // this shared-memory tile.
    //
    // Phase 1 (store): for every ni in [0, N), the four 32-bit registers of
    //   frag_w[ni][mi] are written with fmha::sts. reg(0)/reg(2) go to rows
    //   +0/+8 (in BYTES_PER_ROW units) at write_offset_. reg(1)/reg(3) go to
    //   the same rows after the offset is XOR-ed with 4 * BYTES_PER_STS.
    // Phase 2 (load): for every ni, one fmha::ldsmt at read_offset_ yields a
    //   uint4 whose four words become frag_r[ni].reg(0..3).
    // Both phases advance by WARPS_N * 16 * BYTES_PER_ELT bytes per ni.
    //
    // Params:
    //   frag_w - source fragments; only entry [ni][mi] for ni in [0, N) is read.
    //   frag_r - destination fragments; entries [0, N) are overwritten.
    //   mi     - second-dimension index into frag_w; must be in [0, N)
    //            (not checked at runtime).
    //
    // Preconditions (NOTE(review): inferred, confirm against callers):
    //   * All 32 lanes of the warp call this together. fmha::ldsmt presumably
    //     wraps the warp-collective ldmatrix ... .trans instruction, and the
    //     full-mask __syncwarp() below requires every lane as well.
    //   * If any region read in phase 2 was written by a different warp,
    //     the caller must provide block-level synchronization
    //     (e.g. __syncthreads()). __syncwarp() only orders lanes of one warp.
    inline __device__ void transpose(const Fragment_write (&frag_w)[M][N], Fragment_read (&frag_r)[M], int mi) {
        static_assert(COLS == Cta_tile::N);
        // Both loops index the *first* dimension of frag_w / frag_r with
        // ni in [0, N). That is only in bounds when N <= M, so reject any
        // other shape at compile time instead of reading out of bounds.
        static_assert(N <= M, "transpose: frag_w[ni] / frag_r[ni] require N <= M");

        // Phase 1: scatter the selected fragments into shared memory.
        #pragma unroll
        for( int ni = 0; ni < N; ni++ ) {
            uint32_t offset = write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
            fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, frag_w[ni][mi].reg(0));
            fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, frag_w[ni][mi].reg(2));
            // reg(1)/reg(3) are stored at the XOR-ed partner offset.
            offset ^= 4 * BYTES_PER_STS;
            fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, frag_w[ni][mi].reg(1));
            fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, frag_w[ni][mi].reg(3));
        }

        // Each lane stores 4-byte words but loads 16-byte matrix rows below,
        // so phase 2 reads values that other lanes stored in phase 1.
        // Under independent thread scheduling (Volta+), that cross-lane
        // shared-memory exchange needs a warp barrier to be guaranteed visible.
        __syncwarp();

        // Phase 2: gather the transposed fragments back into registers.
        #pragma unroll
        for( int ni = 0; ni < N; ni++ ) {
            uint32_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
            uint4 dst;
            fmha::ldsmt(dst, smem_ + offset);
            frag_r[ni].reg(0) = dst.x;
            frag_r[ni].reg(1) = dst.y;  // Fragment B regs col major!
            frag_r[ni].reg(2) = dst.z;
            frag_r[ni].reg(3) = dst.w;
        }
    }