inline __device__ Smem_tile_o()

in kernels/fmha/smem_tile.h [1046:1095]


    // Build the per-thread shared-memory write and read addresses for this tile.
    //
    // smem: pointer that is converted to a 32-bit shared-memory address.
    // tidx: thread index used to derive this thread's row/column in the tile
    //       (presumably the linear thread index within the CTA -- confirm against callers).
    inline __device__ Smem_tile_o(void *smem, int tidx) {

        // The index math below is only valid for these tile configurations.
        static_assert(Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 1 && (Cta_tile::WARPS_K == 4 || Cta_tile::WARPS_K == 8));
        static_assert(Cta_tile::N == 16 || Cta_tile::N == 32 || Cta_tile::N == 64 || Cta_tile::N == 128);

        // How many elements one STS covers, and how many STS columns one warp owns.
        constexpr int ELTS_PER_STS = BYTES_PER_STS / BYTES_PER_ELEMENT;
        constexpr int STS_COLS_PER_WARP = 16 * Mma_tile::MMAS_N / ELTS_PER_STS;

        // Row period of the XOR applied to the read column: 2 for N == 16,
        // 4 for N == 32, and 8 for the remaining supported widths (64 and 128).
        constexpr int XOR_PERIOD = Cta_tile::N == 16 ? 2 : (Cta_tile::N == 32 ? 4 : 8);

        // 32-bit shared-memory address of the start of the buffer.
        const uint32_t smem_base = __nvvm_get_smem_pointer(smem);

        // Split the thread index into its warp and its lane inside that warp.
        const int warp_id = tidx / 32;
        const int lane_id = tidx % 32;

        // Write position: bits 2..4 of tidx pick the row (0..7), so every group of
        // four consecutive lanes shares a row; the column is the warp's first STS
        // column plus the lane's offset inside the warp's column range.
        const int sts_row = (tidx & 0x1c) >> 2;
        const int sts_col = warp_id * STS_COLS_PER_WARP + lane_id % STS_COLS_PER_WARP;

        // Assemble the write pointer.
        this->smem_write_ = smem_base + sts_row * BYTES_PER_ROW + sts_col * BYTES_PER_STS;

        // Read position: threads are laid out THREADS_PER_ROW to a row, and the
        // column is XOR-ed with twice the row index taken modulo XOR_PERIOD.
        const int lds_row = tidx / THREADS_PER_ROW;
        const int lds_col = (tidx % THREADS_PER_ROW) ^ (2 * (lds_row % XOR_PERIOD));

        // Assemble the read pointer.
        this->smem_read_ = smem_base + lds_row * BYTES_PER_ROW + lds_col * BYTES_PER_LDS;

        // When the last LDS is incomplete, remember whether the row this thread
        // would touch on that final LDS is still below Cta_tile::M.
        if( HAS_INCOMPLETE_LDS ) {
            this->is_active_for_last_lds_ = lds_row + (LDS_PER_LOOP - 1) * ROWS_PER_LDS < Cta_tile::M;
        }
    }