in candle-flash-attn-v1/kernels/fmha/softmax.h [311:341]
inline __device__ void apply_dropout_16bits(Philox &ph, uint16_t p_dropout_in_uint16_t,
                                            unsigned long long philox_subsequence) {
    // Applies a dropout mask to elt_ in place, driven by 16-bit values obtained from ph.
    //
    //   ph                    - generator object; called once per ni iteration.
    //   p_dropout_in_uint16_t - threshold: an element is kept when its 16-bit value is
    //                           <= this threshold and dropped otherwise.
    //   philox_subsequence    - base subsequence; iteration ni passes
    //                           philox_subsequence + ni * Cta_tile::WARPS_N to ph.
    //
    // A dropped element is replaced by -val when encode_dropout_in_sign_bit is set:
    // the dropout pattern is then encoded in the sign bit of the non-negative softmax,
    // which distinguishes it from pre-existing zeros. Otherwise it is replaced by 0.
    // NOTE(review): encode_dropout_in_sign_bit is declared outside this block
    // (presumably a template parameter of this method) - confirm.
    static_assert(MMAS_M == 1);  // We're assuming 16x16 blocks.

    #pragma unroll
    for( int m = 0; m < MMAS_M; ++m ) {
        #pragma unroll
        for( int n = 0; n < MMAS_N; ++n ) {
            // One generator call is unpacked into eight 16-bit values, which cover
            // a 2 x 4 patch of elt_ (2 along the first index, 4 along the second).
            uint16_t draws[8];
            fmha::uint4_to_ushort8(ph(philox_subsequence + n * Cta_tile::WARPS_N), draws);

            #pragma unroll
            for( int r = 0; r < 2; ++r ) {
                const int row = m * 2 + r;
                #pragma unroll
                for( int c = 0; c < 4; ++c ) {
                    const int col = 4 * n + c;
                    const float val = elt_[row][col];
                    const bool keep = draws[r * 4 + c] <= p_dropout_in_uint16_t;
                    elt_[row][col] = keep ? val : (encode_dropout_in_sign_bit ? -val : float(0));
                }
            }
        }
    }
}