inline __device__ void apply_dropout_16bits()

in candle-flash-attn-v1/kernels/fmha/softmax.h [281:308]


    // Applies dropout in place to elt_, spending 16 random bits per element.
    //
    // For every (mi, ni) pair, a single call to ph() produces a uint4 that is split into
    // eight 16-bit values; each value decides the fate of one element of the 2x4 patch
    // elt_[mi * 2 + {0,1}][ni * 4 + {0..3}]. An element is kept when its random value is
    // <= p_dropout_in_uint16_t. A dropped element becomes -val when
    // encode_dropout_in_sign_bit is set and 0 otherwise. Negating the non-negative softmax
    // value stores the dropout pattern in the sign bit, which keeps dropped entries
    // distinguishable from entries that were already zero.
    //
    // NOTE(review): encode_dropout_in_sign_bit is declared outside this block — confirm
    // where it comes from before relying on it.
    //
    // ph:                    random generator, invoked exactly once per (mi, ni) pair.
    // p_dropout_in_uint16_t: threshold each 16-bit random value is compared against.
    inline __device__ void apply_dropout_16bits(Philox &ph, uint16_t p_dropout_in_uint16_t) {
        // Value an element takes once its keep/drop decision is known.
        const auto masked_value = [](const bool keep, const float val) -> float {
            if( keep ) {
                return val;
            }
            if( encode_dropout_in_sign_bit ) {
                return -val;
            }
            return float(0);
        };

        #pragma unroll
        for( int mi = 0; mi < MMAS_M; ++mi ) {
            #pragma unroll
            for( int ni = 0; ni < MMAS_N; ++ni ) {
                // One generator call yields 4 x 32 bits, i.e. eight 16-bit lanes.
                uint4 rnd_words = ph();
                uint16_t rnd_lanes[8];
                fmha::uint4_to_ushort8(rnd_words, rnd_lanes);

                // Lanes 0..3 cover the first row of the patch, lanes 4..7 the second.
                #pragma unroll
                for( int lane = 0; lane < 8; ++lane ) {
                    const int row = mi * 2 + lane / 4;
                    const int col = ni * 4 + lane % 4;
                    const bool keep = rnd_lanes[lane] <= p_dropout_in_uint16_t;
                    elt_[row][col] = masked_value(keep, elt_[row][col]);
                }
            }
        }
    }