in candle-flash-attn-v1/kernels/fmha/smem_tile.h [1139:1217]
// Stores one batch of accumulator fragments through the swizzled write pointer.
//
// acc: M x N array of accumulator fragments; each fragment exposes eight registers via reg(0..7).
// mi:  index of the batch to store. A batch is MMAS_M_PER_LOOP consecutive MMAs along M, so the
//      fragments touched are acc[mi * MMAS_M_PER_LOOP + mj][ni].
//
// Side effect: this->smem_write_ is XOR-updated twice per MMA along N. All offsets are in bytes
// (the row offsets added to the same pointer are scaled by BYTES_PER_ROW).
// NOTE(review): summing the XORs below, the pointer ends up unchanged when MMAS_N is 2, 4 or 8,
// and is left XORed with 64 when MMAS_N == 1 -- confirm that is what the callers expect.
inline __device__ void store(const Accumulator (&acc)[M][N], int mi) {
    // The swizzle selection at the bottom of the loop only handles up to 8 MMAs along N.
    static_assert(Mma_tile::MMAS_N <= 8, "Not implemented");

    // Row stride between consecutive MMAs along M.
    static constexpr int M_PER_MMA = Mma_tile::M_PER_MMA_PER_CTA;
    // The number of MMAs along M that are stored per call.
    static constexpr int MMAS_M_PER_LOOP = Mma_tile::MMAS_M / LOOPS;

    // First MMA (along M) handled by this call.
    const int first_mma = mi * MMAS_M_PER_LOOP;

    #pragma unroll
    for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {

        // 1st column of each MMA: registers 0..3.
        #pragma unroll
        for( int mj = 0; mj < MMAS_M_PER_LOOP; ++mj ) {
            const Accumulator &frag = acc[first_mma + mj][ni];

            // Byte offsets of the two rows written by this thread: row +0 and row +8 of the MMA.
            const int upper = (mj * M_PER_MMA + 0) * BYTES_PER_ROW;
            const int lower = (mj * M_PER_MMA + 8) * BYTES_PER_ROW;

            uint2 upper_pair, lower_pair;
            upper_pair.x = frag.reg(0);
            upper_pair.y = frag.reg(1);
            lower_pair.x = frag.reg(2);
            lower_pair.y = frag.reg(3);

            fmha::sts(this->smem_write_ + upper, upper_pair);
            fmha::sts(this->smem_write_ + lower, lower_pair);
        }

        // Flip the 32-byte bit of the write pointer to reach the 2nd column.
        this->smem_write_ ^= 32;

        // 2nd column of each MMA: registers 4..7.
        #pragma unroll
        for( int mj = 0; mj < MMAS_M_PER_LOOP; ++mj ) {
            const Accumulator &frag = acc[first_mma + mj][ni];

            const int upper = (mj * M_PER_MMA + 0) * BYTES_PER_ROW;
            const int lower = (mj * M_PER_MMA + 8) * BYTES_PER_ROW;

            uint2 upper_pair, lower_pair;
            upper_pair.x = frag.reg(4);
            upper_pair.y = frag.reg(5);
            lower_pair.x = frag.reg(6);
            lower_pair.y = frag.reg(7);

            fmha::sts(this->smem_write_ + upper, upper_pair);
            fmha::sts(this->smem_write_ + lower, lower_pair);
        }

        // Undo the 32-byte flip above and move on to the next MMA along N in a single XOR.
        // Bit 32 is always part of the mask (that is the undo); the higher bits that get
        // toggled depend on the position of ni within groups of 2 and 4 MMAs:
        //   3 * 32 -> also toggles 64
        //   7 * 32 -> also toggles 64 and 128
        //  15 * 32 -> also toggles 64, 128 and 256
        int swizzle = 3 * 32;
        if( Mma_tile::MMAS_N >= 8 && ni % 4 == 3 ) {
            swizzle = 15 * 32;
        } else if( Mma_tile::MMAS_N >= 4 && ni % 2 == 1 ) {
            swizzle = 7 * 32;
        }
        this->smem_write_ ^= swizzle;
    }
}