static inline void set_alpha()

in candle-flash-attn-v1/kernels/fmha_utils.h [58:77]


static inline void set_alpha( uint32_t &alpha, float norm, Data_type dtype ) {
    if( dtype == DATA_TYPE_FP16 ) {
        half x = __float2half_rn( norm );
        uint16_t h = reinterpret_cast<const uint16_t &>( x );
        ushort2 h2 = { h, h };
        alpha = reinterpret_cast<const uint32_t &>( h2 );
    } else if( dtype == DATA_TYPE_BF16 ) {
        __nv_bfloat16 x = __float2bfloat16( norm );
        uint16_t h = reinterpret_cast<const uint16_t &>( x );
        ushort2 h2 = { h, h };
        alpha = reinterpret_cast<const uint32_t &>( h2 );
    } else if( dtype == DATA_TYPE_FP32 ) {
        alpha = reinterpret_cast<const uint32_t &>( norm );
    } else if( dtype == DATA_TYPE_INT32 ) {
        int32_t inorm = static_cast<int32_t>( norm );
        alpha = reinterpret_cast<const uint32_t &>( inorm );
    } else {
        assert( false );
    }
}