inline __device__ void apply_dropout_16bits()

in candle-flash-attn-v1/kernels/fmha/softmax.h [311:341]


    // Applies dropout in place to the values held in elt_, using one 16-bit
    // random number per element.
    //
    //   ph                     Philox generator; called once per ni with an explicit
    //                          subsequence offset.
    //   p_dropout_in_uint16_t  Threshold: an element is kept when its 16-bit random
    //                          draw is <= this value.
    //   philox_subsequence     Base subsequence; ni * Cta_tile::WARPS_N is added to it
    //                          for each ni.
    //
    // A dropped element is either negated (when encode_dropout_in_sign_bit is set)
    // or replaced by zero. Encoding the drop in the sign bit of the non-negative
    // softmax lets dropped values be told apart from pre-existing zeros.
    inline __device__ void apply_dropout_16bits(Philox &ph, uint16_t p_dropout_in_uint16_t,
                                                unsigned long long philox_subsequence) {
        // We're assuming 16x16 blocks; note the Philox offset below does not depend on
        // the M index, so more than one M tile is not supported here.
        static_assert(MMAS_M == 1);
        #pragma unroll
        for( int m_idx = 0; m_idx < MMAS_M; ++m_idx ) {
            #pragma unroll
            for( int n_idx = 0; n_idx < MMAS_N; ++n_idx ) {
                // One generator call is unpacked into eight 16-bit random values,
                // enough for the 2 x 4 elements handled per (m_idx, n_idx).
                uint16_t rnd[8];
                fmha::uint4_to_ushort8(ph(philox_subsequence + n_idx * Cta_tile::WARPS_N), rnd);
                #pragma unroll
                for( int k = 0; k < 8; ++k ) {
                    // rnd[0..3] go to the first row of the pair, rnd[4..7] to the second.
                    const int row = m_idx * 2 + k / 4;
                    const int col = 4 * n_idx + k % 4;
                    const float val = elt_[row][col];
                    const bool keep = rnd[k] <= p_dropout_in_uint16_t;
                    float result;
                    if( keep ) {
                        result = val;
                    } else if( encode_dropout_in_sign_bit ) {
                        result = -val;
                    } else {
                        result = float(0);
                    }
                    elt_[row][col] = result;
                }
            }
        }
    }