in candle-flash-attn-v1/kernels/fmha/smem_tile.h [1099:1135]
// Loads the partial results written by each of the Cta_tile::WARPS_K slices and sums them
// (split-K reduction), producing one uint4 per LDS step.
//
// out: one entry per LDS step. When zero_init (declared outside this function) is true the
//      entry is overwritten with the sum of the loaded slices; otherwise the slices are added
//      on top of the value already stored in out.
inline __device__ void load(uint4 (&out)[LDS_PER_LOOP]) const {
    #pragma unroll
    for( int step = 0; step < LDS_PER_LOOP; ++step ) {
        // Both predicates below depend only on the step, not on the slice, so they are
        // evaluated once per step.

        // TD [2022-06-05] Ugly fix for d=128 in the forward pass, maybe there's a better way:
        // on odd steps the read address is XOR-ed with 8 * BYTES_PER_LDS.
        const bool flip_offset = (Cta_tile::N == 128) && (ROWS_PER_LDS == 4) && (step % 2 == 1);

        // With an incomplete final LDS, only threads flagged as active issue the last load.
        // NOTE(review): when the load is skipped, partial[] is left unloaded yet still feeds
        // the reduction below - presumably the caller discards out[] for those threads; confirm.
        const bool do_load =
            !HAS_INCOMPLETE_LDS || (step < LDS_PER_LOOP - 1) || this->is_active_for_last_lds_;

        // Fetch one uint4 per slice before reducing.
        uint4 partial[Cta_tile::WARPS_K];
        #pragma unroll
        for( int slice = 0; slice < Cta_tile::WARPS_K; ++slice ) {
            const int byte_offset = step * ROWS_PER_LDS * BYTES_PER_ROW
                                  + slice * Cta_tile::N * BYTES_PER_ELEMENT;
            uint32_t addr = this->smem_read_ + byte_offset;
            if( flip_offset ) {
                addr ^= 8 * BYTES_PER_LDS;
            }
            if( do_load ) {
                fmha::lds(partial[slice], addr);
            }
        }

        // Seed the accumulator with slice 0, either replacing or extending the incoming value.
        if( zero_init ) {
            out[step] = partial[0];
        } else {
            out[step] = fmha::fadd4(out[step], partial[0]);
        }

        // Fold in the remaining slices in increasing order (same summation order as before).
        #pragma unroll
        for( int slice = 1; slice < Cta_tile::WARPS_K; ++slice ) {
            out[step] = fmha::fadd4(out[step], partial[slice]);
        }
    }
}