inline __device__ void apply_dropout_16bits()

in kernels/fmha/softmax.h [344:382]


    inline __device__ void apply_dropout_16bits(Philox &ph0, Philox &ph1, uint16_t p_dropout_in_uint16_t) {
        // Applies dropout in place to elt_, using 16-bit random values.
        //
        // Each call of a Philox generator is unpacked into eight uint16_t
        // values. Those eight values cover a 2 (rows) x 4 (columns) group of
        // elements: rows mi*2 and mi*2+1, columns 4*n .. 4*n+3 for column
        // block n.
        //
        // Column blocks are processed in pairs. The even block of a pair uses
        // ph0 and the odd block uses ph1; ph0 is always drawn before ph1
        // within a step.
        //
        // An element is kept when its random value is <= p_dropout_in_uint16_t.
        // A dropped element becomes -val when encode_dropout_in_sign_bit is set,
        // and 0 otherwise. Encoding in the sign bit is meant to tell dropped
        // entries of a non-negative softmax apart from pre-existing zeros.
        //
        // NOTE(review): encode_dropout_in_sign_bit is declared outside this
        // view (presumably a template parameter just above) — confirm.
        static_assert(MMAS_N % 2 == 0, "apply_dropout_16bits consumes column blocks in pairs");

        #pragma unroll
        for( int mi = 0; mi < MMAS_M; mi++ ) {
            #pragma unroll
            for( int pair = 0; pair < MMAS_N; pair += 2 ) {
                #pragma unroll
                for( int which = 0; which < 2; ++which ) {
                    // Block `pair` draws from ph0, block `pair + 1` from ph1.
                    Philox &rng = (which == 0) ? ph0 : ph1;
                    uint16_t rnd[8];
                    fmha::uint4_to_ushort8(rng(), rnd);

                    const int col_base = 4 * (pair + which);
                    #pragma unroll
                    for( int k = 0; k < 8; ++k ) {
                        // rnd[k] maps to row offset k / 4 and column offset k % 4.
                        auto &slot = elt_[mi * 2 + k / 4][col_base + k % 4];
                        const float v = slot;
                        const bool keep = rnd[k] <= p_dropout_in_uint16_t;
                        slot = keep ? v : (encode_dropout_in_sign_bit ? -v : float(0));
                    }
                }
            }
        }
    }