in candle-flash-attn-v1/kernels/fmha/smem_tile.h [1472:1497]
// Initializes the per-thread byte offsets used to write a tile into shared memory and read it
// back in a different (transposed) thread mapping.
//
//   smem : generic pointer to the start of this tile's shared-memory buffer.
//   tidx : thread index within the CTA. Only bits 0..7 are inspected below, so at most
//          8 warps (256 threads) are distinguished.
//
// Bit fields of tidx used below:
//   tidx & 0x03          -> position within a group of 4 lanes (0..3)
//   (tidx & 0x1c) >> 2   -> lane group within the warp (0..7)
//   (tidx & 0xe0) >> 5   -> warp index (0..7)
//
// write_offset_ / read_offset_ are byte offsets; judging from the previously commented-out
// `smem_write_ = smem_ + ...` lines, they are presumably added to smem_ by the store/load
// methods (not visible here) — confirm.
inline __device__ Smem_tile_transpose(char *smem, int tidx) {
    // Convert the generic pointer into a shared-memory address for later shared accesses.
    smem_ = __nvvm_get_smem_pointer(smem);

    // Supported warp layouts. The parentheses spell out the precedence the original
    // expression already had: (A && B) || C || D.
    // NOTE(review): the middle term `(WARPS_M == 4 || WARPS_N == 8)` looks like it may have been
    // intended as `(WARPS_M == 4 || WARPS_M == 8)`. It is kept as-is so that no existing
    // instantiation is newly rejected — confirm against the kernel traits that use this tile.
    static_assert((WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8))
                      || (WARPS_M == 4 || WARPS_N == 8)
                      || WARPS_N == 1,
                  "Smem_tile_transpose: unsupported WARPS_M x WARPS_N configuration");

    int write_col, write_row;
    if( WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8) ) {
        // Warps laid out along N: every warp writes rows 0..7; warp w writes (pre-swizzle)
        // columns 8*w + 0..3.
        write_row = (tidx & 0x1c) / 4;
        write_col = (tidx & 0xe0) / 4 + (tidx & 0x03);
    } else {
        // Otherwise warps are stacked along rows: warp w writes rows 16*w + 0..7 and
        // (pre-swizzle) columns 0..3.
        write_row = (tidx & 0xe0) / 2 + (tidx & 0x1c) / 4;
        write_col = (tidx & 0x03);
    }
    // XOR-swizzle the column (in steps of 4) by the low 3 bits of the row, presumably to spread
    // the stores of different rows across shared-memory banks.
    write_col ^= (write_row & 0x07) * 4;
    write_offset_ = write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS;

    // Read mapping: each thread reads row (tidx & 0x0f); its column is 2 * warp + bit 4 of tidx.
    int read_row, read_col;
    read_row = (tidx & 0x0f);
    read_col = (tidx & 0xe0) / 16 + (tidx & 0x1c) / 16;
    // Same swizzle idea as the write side, here in units of one LDS column.
    read_col ^= (read_row & 0x07);
    read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
}