inline __device__ void reverse_smem_read_offset()

in candle-flash-attn-v1/kernels/fmha/smem_tile.h [813:841]


    // Replays, without loading anything, the XOR toggles applied to
    // smem_read_offset_ while stepping over the MMAS_N tiles. Per the original
    // author's note, this is meant to mirror the per-ni pointer updates of the
    // load function for ki = 0 (that function is not visible here -- keep the
    // two in sync when editing either one).
    //
    // ki: not read by this body; it only exists in the signature (default 0).
    inline __device__ void reverse_smem_read_offset(int ki = 0) {
        // Width of one element, in bits.
        const int ELT_BITS = fmha::BITS_PER_ELEMENT_B;
        // Number of bytes per CTA required to compute one MMA.
        const int MMA_BYTES = Mma_tile::N_PER_MMA_PER_CTA * ELT_BITS / 8;

        // Every quantity tested below is a compile-time constant, so only one
        // of these paths is expected to survive in the generated code. Any byte
        // count other than 64 or 32 (including >= 128) leaves the offset alone.
        if( MMA_BYTES == 64 ) {
            // A single MMA along N needs no toggling at all.
            if( Mma_tile::MMAS_N > 1 ) {
                // One toggle per ni step.
                #pragma unroll
                for( int n = 0; n < Mma_tile::MMAS_N; ++n ) {
                    this->smem_read_offset_ ^= MMA_BYTES;
                }
                // Extra toggle for odd MMAS_N > 1 (npo2 kernels), so the total
                // number of toggles is even.
                if( Mma_tile::MMAS_N % 2 == 1 ) {
                    this->smem_read_offset_ ^= MMA_BYTES;
                }
            }
        } else if( MMA_BYTES == 32 ) {
            if( Mma_tile::MMAS_N == 4 ) {
                // Even steps toggle by 2 LDS widths, odd steps by 6.
                #pragma unroll
                for( int n = 0; n < Mma_tile::MMAS_N; ++n ) {
                    const int lds_steps = (n % 2 == 0) ? 2 : 6;
                    this->smem_read_offset_ ^= BYTES_PER_LDS * lds_steps;
                }
            } else if( Mma_tile::MMAS_N == 2 ) {
                // Every step toggles by 2 LDS widths.
                #pragma unroll
                for( int n = 0; n < Mma_tile::MMAS_N; ++n ) {
                    this->smem_read_offset_ ^= BYTES_PER_LDS * 2;
                }
            }
            // NOTE(review): other MMAS_N values leave the offset untouched in
            // the 32-byte case -- confirm this matches the load function.
        }
    }