in kernels/fmha/smem_tile.h [1388:1427]
// Pack every accumulator fragment of the M x N array into four 32-bit words
// and write them out with fmha::sts.
//
// For each fragment acc[mi][ni], eight scalar elements are read and combined
// pairwise by fmha::float2_pack<elem_type>:
//   elements (0,1) and (4,5) go to the row at +0 * BYTES_PER_ROW,
//   elements (2,3) and (6,7) go to the row at +8 * BYTES_PER_ROW.
// Destination addresses are derived from this->smem_write_.
// NOTE(review): the sts / smem_write_ naming suggests the destination is
// shared memory - confirm against fmha::sts.
inline __device__ void store(const Acc (&acc)[M][N]){
    #pragma unroll
    for( int mi = 0; mi < M; ++mi ) {
        #pragma unroll
        for( int ni = 0; ni < N; ++ni ) {
            const Acc &frag = acc[mi][ni];

            // Elements destined for the first row (4 elements per row).
            const float top_a = frag.elt(0);
            const float top_b = frag.elt(1);
            const float top_c = frag.elt(4);
            const float top_d = frag.elt(5);

            // Elements destined for the second row (4 elements per row).
            const float bot_a = frag.elt(2);
            const float bot_b = frag.elt(3);
            const float bot_c = frag.elt(6);
            const float bot_d = frag.elt(7);

            // Each pair of floats becomes one packed 32-bit word.
            const uint32_t top_first  = fmha::float2_pack<elem_type>(top_a, top_b);
            const uint32_t top_second = fmha::float2_pack<elem_type>(top_c, top_d);
            const uint32_t bot_first  = fmha::float2_pack<elem_type>(bot_a, bot_b);
            const uint32_t bot_second = fmha::float2_pack<elem_type>(bot_c, bot_d);

            // Address of the first pair: the write pointer XOR-ed with ni * 32,
            // then advanced by mi * WARPS_M * 16 rows.
            const uint32_t first_addr =
                (this->smem_write_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;
            // Address of the second pair: same location with the
            // 4 * BYTES_PER_STS bits flipped.
            const uint32_t second_addr = first_addr ^ (4 * Base::BYTES_PER_STS);

            // Same store order as before: first pair (rows +0, +8), then second pair.
            fmha::sts(first_addr + 0 * BYTES_PER_ROW, top_first);
            fmha::sts(first_addr + 8 * BYTES_PER_ROW, bot_first);
            fmha::sts(second_addr + 0 * BYTES_PER_ROW, top_second);
            fmha::sts(second_addr + 8 * BYTES_PER_ROW, bot_second);
        }
    }
}