in candle-flash-attn-v1/kernels/fmha/smem_tile.h [777:810]
inline __device__ Smem_tile_row_b(void *smem, int tidx) : Base(smem, tidx) {
    // Computes this thread's read offset into the tile and stores it in
    // this->smem_read_offset_. `tidx` is presumably the thread index within
    // the CTA (TODO confirm against callers); only its bit pattern is used.

    // Compile-time preconditions on the CTA warp layout and tile setup.
    static_assert(Cta_tile::WARPS_K == 1);
    static_assert(Cta_tile::WARPS_M == 4 || Cta_tile::WARPS_M == 8);
    static_assert(Cta_tile::WARPS_N == 1);
    static_assert(USE_LDSMT);
    static_assert(Base::ROWS_PER_XOR_PATTERN == 2 || Base::ROWS_PER_XOR_PATTERN == 4 ||
                  Base::ROWS_PER_XOR_PATTERN == 8);

    // Bit masks that isolate the N- and K-warp bits of the thread index.
    using Masks = Warp_masks<Cta_tile::WARPS_M, Cta_tile::WARPS_N, Cta_tile::WARPS_K>;

    // Thread-index stride between consecutive warps along N and along K.
    constexpr int threads_per_n_step = Cta_tile::WARPS_M * 1 * Cta_tile::THREADS_PER_WARP;
    constexpr int threads_per_k_step =
        Cta_tile::WARPS_M * Cta_tile::WARPS_N * Cta_tile::THREADS_PER_WARP;

    // Warp coordinates along K and N (both are 0 when WARPS_K == WARPS_N == 1).
    const int warp_k = (tidx & Masks::K) / threads_per_k_step;
    const int warp_n = (tidx & Masks::N) / threads_per_n_step;

    // Row: each K-warp owns MMAS_K * 16 rows; within it, the low four bits of
    // the thread index select the row (equivalent to (tidx & 7) + (tidx & 8)).
    const int read_row = warp_k * Mma_tile::MMAS_K * 16 + (tidx & 0x0f);

    // Column: start from the XOR pattern selected by the packed row, then XOR in
    // the N-warp contribution and bit 4 of the thread index.
    constexpr int rows_per_packing = Base::BYTES_PER_ROW / Base::BYTES_PER_ROW_BEFORE_PACKING;
    const int pattern_row = (read_row / rows_per_packing) % Base::ROWS_PER_XOR_PATTERN;
    const int read_col =
        (pattern_row * Base::COLS_PER_XOR_PATTERN) ^ (warp_n * 2 + (tidx & 0x10) / 16);

    // Byte offset of this thread's first read.
    this->smem_read_offset_ =
        read_row * Base::BYTES_PER_ROW_BEFORE_PACKING + read_col * BYTES_PER_LDS;
}