in kernels/fmha/smem_tile.h [1636:1647]
// Writes the LDGS values held in `sum` into the buffer selected by `buffer_idx`.
//
// Consecutive buffers are ROWS floats apart, starting at smem_ (presumably the
// shared-memory base pointer of this tile -- confirm against the class definition).
//
// Only threads whose index is a multiple of THREADS_PER_ROW perform the write;
// every other thread returns without touching memory. A writing thread starts at
// row tidx_ / THREADS_PER_ROW and advances by ROWS_PER_LDG rows per element of
// `sum`, skipping any row that falls at or beyond ROWS.
//
// No barrier is issued here.
// NOTE(review): callers presumably have to synchronize before other threads read
// the stored values -- confirm at the call sites.
inline __device__ void store(const float (&sum)[LDGS], const int buffer_idx) {
    // Guard clause: a single designated thread per row does the store.
    if (tidx_ % THREADS_PER_ROW != 0) {
        return;
    }

    // Start of the destination buffer and the first row owned by this thread.
    float *dst = smem_ + buffer_idx * ROWS;
    const int first_row = tidx_ / THREADS_PER_ROW;

    #pragma unroll
    for (int ldg = 0; ldg < LDGS; ++ldg) {
        const int dst_row = first_row + ldg * ROWS_PER_LDG;
        // Rows past the end of the tile are dropped rather than written.
        if (dst_row < ROWS) {
            dst[dst_row] = sum[ldg];
        }
    }
}