in candle-flash-attn-v1/kernels/fmha/smem_tile.h [145:171]
inline __device__ void compute_store_pointers(uint32_t (&ptrs)[N]) {
#pragma unroll
for( int ii = 0; ii < N; ++ii ) {
// Decompose the STS into row/col.
int row = ii / STS_PER_ROW;
int col = ii % STS_PER_ROW;
// Assemble the offset.
int offset = smem_write_offset_ + row*ROWS_PER_STS*BYTES_PER_ROW;
// Take the column into account.
if( STS_PER_ROW > 1 ) {
offset += col*THREADS_PER_ROW*BYTES_PER_STS;
}
// Apply the XOR pattern if needed.
if( ROWS_PER_STS < ROWS_PER_XOR_PATTERN ) {
const int m = row * ROWS_PER_STS % ROWS_PER_XOR_PATTERN;
offset ^= m * COLS_PER_XOR_PATTERN * BYTES_PER_STS;
}
// Assemble the final pointer :)
// ptrs[ii] = smem_ + offset + smem_write_buffer_;
// smem_write_buffer_ is already merged with smem_write_offset_
ptrs[ii] = smem_ + offset;
}
}