inline __device__ void scale()

in kernels/fmha/softmax.h [385:401]


    // Normalize elt_ row by row: every element of row `mi` is multiplied by the
    // reciprocal of sum[mi]. A row whose sum is exactly zero or NaN uses a factor
    // of 1.f instead of a reciprocal.
    inline __device__ void scale(const float (&sum)[MMAS_M * 2]) {
        #pragma unroll
        for( int row = 0; row < MMAS_M * 2; ++row ) {
            const float row_sum = sum[row];

            // (x != x) is true only when x is NaN.
            const bool degenerate = row_sum == 0.f || row_sum != row_sum;

            // One reciprocal per row, so the inner loop only multiplies. Per the
            // original note, this matters a lot when building without -use_fast_math.
            float inv = 1.f;
            if( !degenerate ) {
                inv = 1.f / row_sum;
            }

            // Apply the row's factor to each of its elements.
            #pragma unroll
            for( int col = 0; col < MMAS_N * 4; ++col ) {
                elt_[row][col] *= inv;
            }
        }
    }