in candle-flash-attn/kernels/utils.h [297:359]
__forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) {
// Predicated tile copy: copies S into D one (m, k) slice at a time with
// cute::copy(tiled_copy, ...), skipping (and optionally zero-filling) slices
// that are masked out along the MN mode (mode 1) or the K mode (mode 2).
//
// Parameters:
//   tiled_copy   - copy object forwarded unchanged to cute::copy.
//   S            - source tensor; must be rank 3, indexed as (_, m, k).
//   D            - destination tensor; must be rank 3 with the same mode sizes as S.
//   identity_MN  - tensor indexed as (0, m, 0); the first component of that
//                  element is compared against max_MN to decide whether slice m
//                  is in bounds. Presumably a coordinate/identity tensor over
//                  the tile -- TODO confirm against callers.
//   predicate_K  - indexed by k; a true value means slice k is copied.
//   max_MN       - exclusive upper bound for the MN check. Only read when
//                  Is_even_MN is false. With the default of 0, a slice passes the
//                  check only if its coordinate is negative, so callers that
//                  disable Is_even_MN need to pass a real bound.
//
// Compile-time flags (Is_even_MN, Is_even_K, Clear_OOB_MN, Clear_OOB_K) are not
// declared in this block; presumably they are template parameters of this
// function declared just above the signature -- verify there.
//   Is_even_MN   - when true, the MN bound check is skipped (every m is copied).
//   Is_even_K    - when true, predicate_K is ignored (every k is copied).
//   Clear_OOB_MN - when true, slices failing the MN check are cleared in D;
//                  when false they are left untouched.
//   Clear_OOB_K  - when true, slices failing the K predicate are cleared in D;
//                  when false they are left untouched.

// Both tensors must be rank 3 and agree on every mode size.
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
// There's no case where !Clear_OOB_K && Clear_OOB_MN
// (i.e. clearing along MN is only allowed together with clearing along K).
static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
// Outer loop over mode 1 (MN), inner loop over mode 2 (K).
#pragma unroll
for (int m = 0; m < size<1>(S); ++m) {
// Slice m is in bounds if bound checking is disabled, or if the first
// component of identity_MN(0, m, 0) is below max_MN.
if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
#pragma unroll
for (int k = 0; k < size<2>(S); ++k) {
// Copy slice (m, k) if K predication is disabled or the predicate holds.
if (Is_even_K || predicate_K(k)) {
cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
} else if (Clear_OOB_K) {
// Masked out along K: clear the destination slice instead of copying.
cute::clear(D(_, m, k));
}
// Otherwise (masked out and !Clear_OOB_K) D(_, m, k) is left as it was.
}
} else if (Clear_OOB_MN) {
// Out of bounds along MN: clear the whole m slice across all k.
cute::clear(D(_, m, _));
}
// Otherwise (out of bounds and !Clear_OOB_MN) D(_, m, _) is left as it was.
}
// TD [2023-04-13]: Strange that the code below can cause a race condition.
// I think it's because the copies are under an if statement.
// NOTE: the alternative implementation below is disabled and kept for reference
// only; do not re-enable it without addressing the race mentioned above.
// if (Is_even_K) {
// #pragma unroll
// for (int m = 0; m < size<1>(S); ++m) {
// if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
// copy(tiled_copy, S(_, m, _), D(_, m, _));
// } else if (Clear_OOB_MN) {
// clear(D(_, m, _));
// }
// }
// } else { // It's slightly faster in this case if iterate over K first
// #pragma unroll
// for (int k = 0; k < size<2>(S); ++k) {
// if (predicate_K(k)) {
// #pragma unroll
// for (int m = 0; m < size<1>(S); ++m) {
// if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
// copy(tiled_copy, S(_, m, k), D(_, m, k));
// } else if (Clear_OOB_MN) {
// clear(D(_, m, k));
// }
// }
// } else if (Clear_OOB_K) { // There's no case where !Clear_OOB_K && Clear_OOB_MN
// if (Clear_OOB_MN || Is_even_MN) {
// clear(D(_, _, k));
// } else {
// #pragma unroll
// for (int m = 0; m < size<1>(S); ++m) {
// if (!(Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN)) {
// clear(D(_, m, k));
// }
// }
// }
// }
// }
// }
}