in candle-flash-attn-v1/kernels/fmha/smem_tile.h [813:841]
// Undo the per-ni adjustments that the load function applies to
// smem_read_offset_ for ki = 0.
// Every adjustment is an XOR with a compile-time value, so the order in
// which they are applied does not change the result.
// NOTE(review): in every branch the XORs cancel in pairs:
//   - 2/6/2/6 when MMAS_N == 4;
//   - 2/2 when MMAS_N == 2;
//   - 64 applied MMAS_N times, plus one extra XOR when MMAS_N is odd.
// The net effect on smem_read_offset_ therefore appears to be the
// identity. Confirm this is intended.
// `ki` is accepted for interface compatibility but is not read here.
inline __device__ void reverse_smem_read_offset(int ki = 0) {
    // Width of one B element, in bits.
    const int elt_bits = fmha::BITS_PER_ELEMENT_B;
    // Bytes of B consumed by a single MMA across the whole CTA.
    const int mma_bytes = Mma_tile::N_PER_MMA_PER_CTA * elt_bits / 8;
    // Number of MMAs along the N dimension.
    const int mmas_n = Mma_tile::MMAS_N;

    #pragma unroll
    for( int n = 0; n < mmas_n; ++n ) {
        // XOR mask to undo for this n. A mask of 0 means "no adjustment",
        // which covers:
        //   - mma_bytes >= 128;
        //   - mma_bytes == 64 with a single MMA;
        //   - mma_bytes == 32 with an MMAS_N other than 2 or 4;
        //   - any other width.
        int undo_mask = 0;
        if( mma_bytes == 64 ) {
            if( mmas_n > 1 ) {
                undo_mask = mma_bytes;
            }
        } else if( mma_bytes == 32 ) {
            if( mmas_n == 4 ) {
                // Even n toggles by 2 LDS widths, odd n by 6.
                undo_mask = BYTES_PER_LDS * ((n & 1) ? 6 : 2);
            } else if( mmas_n == 2 ) {
                undo_mask = BYTES_PER_LDS * 2;
            }
        }
        this->smem_read_offset_ ^= undo_mask;
    }

    // With an odd MMAS_N > 1 (npo2 kernels), the loop above leaves one
    // unpaired 64-byte XOR. Apply it once more to reset the offset.
    const bool odd_multi_mma = mmas_n > 1 && (mmas_n % 2) == 1;
    if( mma_bytes == 64 && odd_multi_mma ) {
        this->smem_read_offset_ ^= mma_bytes;
    }
}