inline __device__ void compute_store_pointers()

in candle-flash-attn-v1/kernels/fmha/smem_tile.h [145:171]


    // Fill ptrs[0..N) with one store address per STS issued by this thread.
    //
    // Each store index is split into a row (index / STS_PER_ROW) and a column
    // (index % STS_PER_ROW). The address is smem_ plus a byte offset built from
    // smem_write_offset_, the row term, the column term (only when a row holds
    // more than one STS) and, when ROWS_PER_STS < ROWS_PER_XOR_PATTERN, an XOR
    // adjustment derived from the row.
    //
    // ptrs: output array, fully overwritten.
    //
    // NOTE(review): the names suggest these are shared-memory addresses used by
    // store-to-shared instructions -- confirm against the enclosing class.
    inline __device__ void compute_store_pointers(uint32_t (&ptrs)[N]) {
        #pragma unroll
        for( int sts = 0; sts < N; ++sts ) {
            // Position of this store in the row/column grid of stores.
            const int sts_row = sts / STS_PER_ROW;
            const int sts_col = sts % STS_PER_ROW;

            // Row term shared by the base offset and the XOR pattern below.
            const int row_term = sts_row * ROWS_PER_STS;

            // Start from the thread's write offset and step down by rows.
            int byte_offset = smem_write_offset_ + row_term * BYTES_PER_ROW;

            // Step across by columns; contributes nothing when a row holds a single STS.
            byte_offset += ( STS_PER_ROW > 1 ) ? sts_col * THREADS_PER_ROW * BYTES_PER_STS : 0;

            // One STS covers fewer rows than the XOR pattern spans, so the pattern
            // has to be re-applied per store.
            if( ROWS_PER_XOR_PATTERN > ROWS_PER_STS ) {
                const int xor_row = row_term % ROWS_PER_XOR_PATTERN;
                byte_offset ^= xor_row * COLS_PER_XOR_PATTERN * BYTES_PER_STS;
            }

            // smem_write_buffer_ is deliberately not added here: it is already
            // merged into smem_write_offset_.
            ptrs[sts] = smem_ + byte_offset;
        }
    }