inline __device__ void load()

in candle-flash-attn-v1/kernels/fmha/smem_tile.h [844:906]
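This method loads the B-operand fragments for one k-step (ki) from shared memory. For each of the MMAS_N column tiles it computes an XOR-swizzled read offset, issues a transposed LDSM (falling back to plain LDS accesses when LDSM.T is unavailable), stores the four returned registers into the fragment, and advances smem_read_offset_ for the next iteration.

For reference, ldsmt is assumed here to wrap the transposed ldmatrix PTX instruction. A minimal sketch of such a wrapper (the actual definition lives in fmha/utils.h and may differ, e.g. in architecture guards):

    // Sketch of a transposed-LDSM wrapper, assuming it maps to the x4 .trans
    // form of ldmatrix; loads four 8x8 b16 tiles into four 32-bit registers.
    static inline __device__ void ldsmt(uint4 &dst, uint32_t ptr) {
        asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
                     "{%0, %1, %2, %3}, [%4];\n"
                     : "=r"(dst.x), "=r"(dst.y), "=r"(dst.z), "=r"(dst.w)
                     : "r"(ptr));
    }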


    inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) {
        // The size of each element in bits.
        const int BITS_PER_ELT = fmha::BITS_PER_ELEMENT_B;
        // The size in bytes of the data needed to compute an MMA per CTA.
        const int BYTES_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA * BITS_PER_ELT / 8;
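        // E.g. with 16-bit elements this comes out to 2 * Mma_tile::N_PER_MMA_PER_CTA bytes.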

        #pragma unroll
        for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {
            // Prepare the offset.
            int offset = ki * Base::ROWS_PER_XOR_PATTERN * 2 * Base::BYTES_PER_ROW_BEFORE_PACKING;
            if( BYTES_PER_MMA_PER_CTA == 32 ) {
                offset += this->smem_read_offset_;
            } else if( BYTES_PER_MMA_PER_CTA == 64 ) {
                offset += this->smem_read_offset_ + (ni/2) * BYTES_PER_MMA_PER_CTA * 2;
            } else {
                offset += this->smem_read_offset_ + (ni  ) * BYTES_PER_MMA_PER_CTA;
            }
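            // With 128B or more per MMA, each ni strides linearly and no XOR is
            // needed. The 64B case pairs up ni values: (ni/2) selects the pair's
            // 128B block and the XOR below toggles within the pair. The 32B case
            // relies on the XOR pattern below alone.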

            // Load the data using a transposed LDSM (LDSM.T) when available.
            uint32_t ptr = this->smem_ + offset;
            uint4 tmp;
            if( USE_LDSMT ) {
                ldsmt(tmp, ptr);
            } else {
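                // Fallback when LDSM.T is unavailable: four 32-bit LDS loads.
                // ptr ^ 32 flips to the partner 32-byte chunk of the swizzle and
                // 4 * BYTES_PER_ROW_BEFORE_PACKING steps four rows down.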
                lds(tmp.x, (ptr     ) + 0*Base::BYTES_PER_ROW_BEFORE_PACKING);
                lds(tmp.y, (ptr     ) + 4*Base::BYTES_PER_ROW_BEFORE_PACKING);
                lds(tmp.z, (ptr ^ 32) + 0*Base::BYTES_PER_ROW_BEFORE_PACKING);
                lds(tmp.w, (ptr ^ 32) + 4*Base::BYTES_PER_ROW_BEFORE_PACKING);
            }

            // Store those values in the fragment.
            b[ni].reg(0) = tmp.x;
            b[ni].reg(1) = tmp.y;
            b[ni].reg(2) = tmp.z;
            b[ni].reg(3) = tmp.w;
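            // With 16-bit elements, each of the four 32-bit registers holds a
            // pair of B values.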

            // Advance the read offset for the next ni. These conditions are all
            // compile-time constants, so the compiler should fold the branches away.
            if( BYTES_PER_MMA_PER_CTA >= 128 ) {
                // Nothing to do!
            } else if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 ) {
                this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;
            } else if( BYTES_PER_MMA_PER_CTA == 64 ) {
                // Nothing to do!
            } else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 8 ) {
                this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 4 == 3 ? 14 : (ni % 2 == 1 ? 6 : 2));
            } else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 4 ) {
                this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 2 : 6);
            } else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 2 ) {
                this->smem_read_offset_ ^= BYTES_PER_LDS * 2;
            }
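            // Over a full pass of the loop these XORs visit each swizzled column
            // once and return smem_read_offset_ to its starting value, except in
            // the 64-byte case with odd MMAS_N, which is patched up below.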
        }

        // Restore smem_read_offset_ for odd MMAS_N > 1 (non-power-of-two kernels):
        // the in-loop XOR was applied an odd number of times, leaving it flipped.
        if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 &&
                Mma_tile::MMAS_N % 2 == 1 ) {
            this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;
        }
    }
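The XOR updates are easy to sanity-check on the host. Below is a minimal standalone sketch (not part of the kernel sources) of the read-offset walk for the BYTES_PER_MMA_PER_CTA == 32, MMAS_N == 4 case, assuming BYTES_PER_LDS == 16, i.e. the 128-bit access size used by LDSM:

    // Hypothetical host-side repro of the swizzled read-offset walk for
    // BYTES_PER_MMA_PER_CTA == 32, MMAS_N == 4; assumes BYTES_PER_LDS == 16.
    #include <cassert>
    #include <cstdio>

    int main() {
        const int BYTES_PER_LDS = 16;
        const int MMAS_N = 4;
        int offset = 0;
        for( int ni = 0; ni < MMAS_N; ++ni ) {
            printf("ni=%d reads offset %d\n", ni, offset);  // 0, 32, 64, 96
            offset ^= BYTES_PER_LDS * (ni % 2 == 0 ? 2 : 6);
        }
        // The walk visits all four 32-byte columns and self-restores, so no
        // post-loop fixup is needed in this configuration.
        assert(offset == 0);
        return 0;
    }

The 64-byte case with odd MMAS_N is the one configuration whose walk does not self-restore, which is exactly what the final XOR above compensates for.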