src/QuantUtils.cc (706 lines of code) (raw):

#define FBGEMM_EXPORTS #include <algorithm> #include <iterator> #include <numeric> #include <type_traits> #include "fbgemm/QuantUtils.h" #include <cpuinfo.h> #include "fbgemm/Fbgemm.h" #include "fbgemm/Types.h" namespace fbgemm { using namespace std; // Use fp16_min as the small scale cutoff because we don't want to use scales in // fp16 subnormal range. This is to be consistent with Glow and FakeLowP // implementation for NNPI. constexpr float SMALL_SCALE_THRESHOLD = 6.1e-5f; float TensorQuantizationParams::Min() const { return Dequantize(0, *this); } float TensorQuantizationParams::Max() const { return Dequantize((1 << precision) - 1, *this); } TensorQuantizationParams ChooseQuantizationParams( float min, float max, int32_t qmin, int32_t qmax, bool preserve_sparsity, bool force_scale_power_of_two) { if (min < 0 && max > 0 && preserve_sparsity) { int symmetric_qmin = -((qmax - qmin) / 2 + 1); int symmetric_qmax = (qmax - qmin) / 2; double max_scale = std::max(fabs(min / symmetric_qmin), fabs(max / symmetric_qmax)); min = max_scale * symmetric_qmin; max = max_scale * symmetric_qmax; } // We extend the [min, max] interval to ensure that it contains 0. // Otherwise, we would not meet the requirement that 0 be an exactly // representable value. min = std::min(min, 0.f); max = std::max(max, 0.f); // Use double precision for intermediate computation but use single precision // in final number to reflect the actual number used during quantization. float scale = (static_cast<double>(max) - min) / (qmax - qmin); // If scale is 0 or too small so its reciprocal is infinity, we arbitrary // adjust the scale to 0.1 . We want to avoid scale's reciprocal being // infinity because some of fbgemm code pre-computes scale's reciprocal to do // multiplication instead of division in the time critical part of code. if (scale == 0.0f || isinf(1.0f / scale)) { scale = 0.1; } assert(scale > 0); if (force_scale_power_of_two) { if (scale < 1) { scale = 1.0 / (1 << static_cast<int>(floor(log2(1.0 / scale)))); } else { scale = 1 << static_cast<int>(ceil(log2(scale))); } } // Cut off small scale if (scale < SMALL_SCALE_THRESHOLD) { float org_scale = scale; scale = SMALL_SCALE_THRESHOLD; // Adjust the min and max based on the new scale if (min == 0.0f) { max = SMALL_SCALE_THRESHOLD * (qmax - qmin); } else if (max == 0.0f) { min = -SMALL_SCALE_THRESHOLD * (qmax - qmin); } else { float amplifier = SMALL_SCALE_THRESHOLD / org_scale; min *= amplifier; max *= amplifier; } } // Zero-point computation. // First the initial floating-point computation. The zero-point can be // determined from solving an affine equation for any known pair // (real value, corresponding quantized value). // We know two such pairs: (rmin, qmin) and (rmax, qmax). // The arithmetic error on the zero point computed from either pair // will be roughly machine_epsilon * (sum of absolute values of terms) // so we want to use the variant that adds the smaller terms. double zero_point_from_min = qmin - min / static_cast<double>(scale); double zero_point_from_max = qmax - max / static_cast<double>(scale); double zero_point_from_min_error = std::abs(qmin) + std::abs(min / static_cast<double>(scale)); double zero_point_from_max_error = std::abs(qmax) + std::abs(max / static_cast<double>(scale)); double initial_zero_point = zero_point_from_min_error < zero_point_from_max_error ? zero_point_from_min : zero_point_from_max; // Note: preserve_sparsity here means symmetric quantization. // for symmetric quantization, we force zero_point // to be a middle value between qmin and qmax. // If either min or max is 0, then we just use 0 as zero_point. if (min < 0 && max > 0 && preserve_sparsity) { initial_zero_point = static_cast<double>(qmin + qmax) / 2; } // Now we need to nudge the zero point to be an integer // (our zero points are integer, and this is motivated by the requirement // to be able to represent the real value "0" exactly as a quantized value, // which is required in multiple places, for example in Im2col with zero // padding). int32_t nudged_zero_point = 0; if (initial_zero_point < qmin) { nudged_zero_point = qmin; } else if (initial_zero_point > qmax) { nudged_zero_point = qmax; } else { nudged_zero_point = nearbyint(initial_zero_point); } TensorQuantizationParams result; result.scale = scale; result.zero_point = nudged_zero_point; return result; } void ChooseRequantizationMultiplier( float real_multiplier, int32_t* quantized_multiplier, int* right_shift, int requantization_multiplier_precision) { assert(real_multiplier != 0.f); // Assuming requantization_multiplier_precision_ = 31, // the default right shift is 31 when the real multiplier is already // in interval [1/2, 1). // Multiplying a 32-bit signed integer with all 31 bits except the sign bit // is used followed by 31-bit right shift implements multiplying with a real // number in [1/2, 1). // We want to utilize all 31 bits except the sign bit in the 32-bit signed // integer to get the best accuracy. int s = 31; // We want to bring the real multiplier into the interval [1/2, 1). // We can do so by multiplying it by two, and recording how many times // we multiplied by two so that we can compensate that by a right // shift by the same amount. if (real_multiplier > 0.f) { while (real_multiplier < 0.5f) { real_multiplier *= 2.f; s++; } while (real_multiplier > 1.f) { real_multiplier /= 2.f; s--; } } // Now that the real multiplier is in [1/2, 1), we convert it // into a fixed-point number. int64_t q = nearbyint( real_multiplier * (1ll << (requantization_multiplier_precision - 1))); assert(q <= (1ll << (requantization_multiplier_precision - 1))); // Handle the special case when the real multiplier was so close to 1 // that its fixed-point approximation was undistinguishable from 1. // We handle this by dividing it by two, and remembering to decrement // the right shift amount. if (q == (1ll << (requantization_multiplier_precision - 1))) { q /= 2; s--; } assert(s >= 0); assert(q >= 0); assert(q <= numeric_limits<int32_t>::max()); *quantized_multiplier = static_cast<int32_t>(q); *right_shift = s; assert(s < 64); } //////////////////////////////////////////////////////////////////////////////// // Utility functions #define FBGEMM_SPECIALIZED_QUANTIZE(T, LEGACY) \ template <> \ FBGEMM_API void Quantize<T, LEGACY>( \ const float* src, \ T* dst, \ const int64_t len, \ const TensorQuantizationParams& qparams, \ int thread_id, \ int num_threads) { \ int64_t i_begin, i_end; \ fbgemmPartition1D(thread_id, num_threads, len, i_begin, i_end); \ for (int64_t i = i_begin; i < i_end; ++i) { \ dst[i] = Quantize<T, LEGACY>(src[i], qparams); \ } \ } FBGEMM_SPECIALIZED_QUANTIZE(uint16_t, true) FBGEMM_SPECIALIZED_QUANTIZE(int16_t, true) FBGEMM_SPECIALIZED_QUANTIZE(int32_t, true) FBGEMM_SPECIALIZED_QUANTIZE(uint16_t, false) FBGEMM_SPECIALIZED_QUANTIZE(int16_t, false) FBGEMM_SPECIALIZED_QUANTIZE(int32_t, false) #undef FBGEMM_SPECIALIZED_QUANTIZE #define FBGEMM_SPECIALIZED_QUANTIZE_AVX2(T, LEGACY) \ template <> \ FBGEMM_API void Quantize<T, LEGACY>( \ const float* src, \ T* dst, \ int64_t len, \ const TensorQuantizationParams& qparams, \ int thread_id, \ int num_threads) { \ bool avx2_support = cpuinfo_initialize() && fbgemmHasAvx2Support(); \ bool fma_support = cpuinfo_has_x86_fma3(); \ int64_t i_begin, i_end; \ fbgemmPartition1D(thread_id, num_threads, len, i_begin, i_end); \ if (avx2_support && fma_support && qparams.precision == 8) { \ /* fast path */ \ QuantizeAvx2<T, LEGACY>( \ &src[i_begin], &dst[i_begin], i_end - i_begin, qparams); \ } else { \ for (int64_t i = i_begin; i < i_end; ++i) { \ dst[i] = Quantize<T, LEGACY>(src[i], qparams); \ } \ } \ } FBGEMM_SPECIALIZED_QUANTIZE_AVX2(int8_t, true) FBGEMM_SPECIALIZED_QUANTIZE_AVX2(uint8_t, true) FBGEMM_SPECIALIZED_QUANTIZE_AVX2(int8_t, false) FBGEMM_SPECIALIZED_QUANTIZE_AVX2(uint8_t, false) #undef FBGEMM_SPECIALIZED_QUANTIZE_AVX2 #define FBGEMM_SPECIALIZED_FUSED_QUANTIZE_DEQUANTIZE_AVX2(T) \ template <> \ FBGEMM_API void FusedQuantizeDequantize<T>( \ const float* src, \ float* dst, \ int64_t len, \ const TensorQuantizationParams& qparams, \ int thread_id, \ int num_threads, \ float noise_ratio) { \ bool avx2_support = cpuinfo_initialize() && fbgemmHasAvx2Support(); \ bool fma_support = cpuinfo_has_x86_fma3(); \ int64_t i_begin, i_end; \ fbgemmPartition1D(thread_id, num_threads, len, i_begin, i_end); \ if (avx2_support && fma_support && qparams.precision == 8) { \ /* fast path */ \ FusedQuantizeDequantizeAvx2<T>( \ &src[i_begin], &dst[i_begin], i_end - i_begin, qparams); \ } else if (noise_ratio <= 0.0f) { \ for (int64_t i = i_begin; i < i_end; ++i) { \ dst[i] = FusedQuantizeDequantize<T>(src[i], qparams); \ } \ } else { \ throw std::runtime_error("Failed to initialize cpuinfo!"); \ } \ } FBGEMM_SPECIALIZED_FUSED_QUANTIZE_DEQUANTIZE_AVX2(int8_t) FBGEMM_SPECIALIZED_FUSED_QUANTIZE_DEQUANTIZE_AVX2(uint8_t) #undef FBGEMM_SPECIALIZED_FUSED_QUANTIZE_DEQUANTIZE_AVX2 #define FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKCX(T) \ template <> \ FBGEMM_API void QuantizeGroupwise<T, layout_t::KCX>( \ const float* src, \ int N, \ int C, \ int X, \ int G, \ const float* scales, \ const std::int32_t* zero_points, \ T* dst) { \ assert(C % G == 0); \ int C_per_G = C / G; \ for (int i = 0; i < N; ++i) { \ for (int g = 0; g < G; ++g) { \ float scale = scales[g]; \ int32_t zero_point = zero_points[g]; \ for (int c = 0; c < C / G; ++c) { \ for (int x = 0; x < X; ++x) { \ dst[(i * C + g * C_per_G + c) * X + x] = Quantize<T>( \ src[(i * C + g * C_per_G + c) * X + x], \ zero_point, \ scale, \ 8 * sizeof(T)); \ } \ } \ } \ } \ } FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKCX(int8_t) FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKCX(int32_t) #undef FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKCX template <> FBGEMM_API void QuantizeGroupwise<uint8_t, layout_t::KCX>( const float* src, int K, int C, int X, int G, const float* scales, const std::int32_t* zero_points, uint8_t* dst) { assert(C % G == 0); int C_per_G = C / G; fbgemm::TensorQuantizationParams qparams; qparams.precision = 8 * sizeof(uint8_t); bool takeFastPath = cpuinfo_initialize() && fbgemmHasAvx2Support() && cpuinfo_has_x86_fma3(); for (int i = 0; i < K; ++i) { for (int g = 0; g < G; ++g) { qparams.scale = scales[g]; qparams.zero_point = zero_points[g]; if (takeFastPath) { QuantizeAvx2( src + (i * C + g * C_per_G) * X, dst + (i * C + g * C_per_G) * X, C_per_G * X, qparams); } else { for (int c = 0; c < C / G; ++c) { for (int x = 0; x < X; ++x) { dst[(i * C + g * C_per_G + c) * X + x] = Quantize<uint8_t>( src[(i * C + g * C_per_G + c) * X + x], qparams.zero_point, qparams.scale, qparams.precision); } } } } } } #define FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKXC(T) \ template <> \ FBGEMM_API void QuantizeGroupwise<T, layout_t::KXC>( \ const float* src, \ int K, \ int C, \ int X, \ int G, \ const float* scales, \ const std::int32_t* zero_points, \ T* dst) { \ assert(C % G == 0); \ int C_per_G = C / G; \ for (int i = 0; i < K; ++i) { \ for (int x = 0; x < X; ++x) { \ for (int g = 0; g < G; ++g) { \ float scale = scales[g]; \ int32_t zero_point = zero_points[g]; \ for (int c = 0; c < C / G; ++c) { \ dst[(i * X + x) * C + g * C_per_G + c] = Quantize<T>( \ src[(i * X + x) * C + g * C_per_G + c], \ zero_point, \ scale, \ 8 * sizeof(T)); \ } \ } \ } \ } \ } FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKXC(int8_t) FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKXC(uint8_t) FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKXC(int32_t) #undef FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKXC //////////////////////////////////////////////////////////////////////////////// // Requantization (pure fixed-point) int64_t SaturatingRoundingMulWithShift(int32_t a, int32_t b, int right_shift) { int64_t a_64(a); int64_t b_64(b); int64_t ab_64 = a_64 * b_64; int64_t nudge = 1ll << (right_shift - 1); return (ab_64 + nudge) >> right_shift; } #define FBGEMM_SPECIALIZED_REQUANTIZE(T) \ template <> \ FBGEMM_API void Requantize<T>( \ const int32_t* src, \ T* dst, \ const int64_t len, \ const RequantizationParams& params, \ int thread_id, \ int num_threads) { \ int64_t i_begin, i_end; \ fbgemmPartition1D(thread_id, num_threads, len, i_begin, i_end); \ for (int64_t i = i_begin; i < i_end; ++i) { \ dst[i] = Requantize<T>(src[i], params); \ } \ } FBGEMM_SPECIALIZED_REQUANTIZE(uint16_t) FBGEMM_SPECIALIZED_REQUANTIZE(int32_t) #undef FBGEMM_SPECIALIZED_REQUANTIZE template <> FBGEMM_API void Requantize<uint8_t>( const int32_t* src, uint8_t* dst, const int64_t len, const RequantizationParams& params, int thread_id, int num_threads) { int64_t i_begin, i_end; fbgemmPartition1D(thread_id, num_threads, len, i_begin, i_end); if (params.target_qparams.precision == 8 && cpuinfo_initialize() && fbgemmHasAvx2Support()) { RequantizeAvx2(&src[i_begin], &dst[i_begin], i_end - i_begin, params); } else { for (int64_t i = i_begin; i < i_end; ++i) { dst[i] = Requantize<uint8_t>(src[i], params); } } } template <typename T> FBGEMM_API void RequantizeFixedPoint( const std::int32_t* src, T* dst, int64_t len, const RequantizationParams& params, int thread_id, int num_threads) { int64_t i_begin, i_end; fbgemmPartition1D(thread_id, num_threads, len, i_begin, i_end); if (std::is_same<T, uint8_t>::value && params.target_qparams.precision == 8 && cpuinfo_initialize() && fbgemmHasAvx2Support()) { RequantizeFixedPointAvx2( &src[i_begin], &dst[i_begin], i_end - i_begin, params); } else { for (int64_t i = i_begin; i < i_end; ++i) { dst[i] = RequantizeFixedPoint<T>(src[i], params); } } } #define FBGEMM_SPECIALIZED_REQUANTIZE(T) \ template <> \ FBGEMM_API void RequantizeFixedPoint<T>( \ const int32_t* src, \ T* dst, \ const int64_t len, \ const RequantizationParams& params, \ int thread_id, \ int num_threads) { \ int64_t i_begin, i_end; \ fbgemmPartition1D(thread_id, num_threads, len, i_begin, i_end); \ for (int64_t i = i_begin; i < i_end; ++i) { \ dst[i] = RequantizeFixedPoint<T>(src[i], params); \ } \ } FBGEMM_SPECIALIZED_REQUANTIZE(uint16_t) FBGEMM_SPECIALIZED_REQUANTIZE(int32_t) #undef FBGEMM_SPECIALIZED_REQUANTIZE template <> FBGEMM_API void RequantizeFixedPoint<uint8_t>( const int32_t* src, uint8_t* dst, const int64_t len, const RequantizationParams& params, int thread_id, int num_threads) { int64_t i_begin, i_end; fbgemmPartition1D(thread_id, num_threads, len, i_begin, i_end); if (params.target_qparams.precision == 8 && cpuinfo_initialize() && fbgemmHasAvx2Support()) { RequantizeFixedPointAvx2( &src[i_begin], &dst[i_begin], i_end - i_begin, params); } else { for (int64_t i = i_begin; i < i_end; ++i) { dst[i] = RequantizeFixedPoint<uint8_t>(src[i], params); } } } template <typename InputType> void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfRef( int bit_rate, const InputType* input, size_t input_rows, int input_columns, std::uint8_t* output) { static_assert( std::is_same<InputType, float>() || std::is_same<InputType, float16>(), "Only float and float16 types are allowed."); int num_elem_per_byte = 8 / bit_rate; int output_columns = (input_columns + num_elem_per_byte - 1) / num_elem_per_byte + 2 * sizeof(float16); std::vector<float> input_row_float(input_columns); for (size_t row = 0; row < input_rows; ++row) { const InputType* input_row = input + row * input_columns; std::uint8_t* output_row = output + row * output_columns; float16* output_row_scale_bias = reinterpret_cast<float16*>( output_row + (input_columns + num_elem_per_byte - 1) / num_elem_per_byte); // NOTE: this can be optimized, however we don't care much about performance // for reference implementation. for (int col = 0; col < input_columns; ++col) { if (std::is_same<InputType, float>()) { input_row_float[col] = input_row[col]; } else { input_row_float[col] = cpu_half2float(input_row[col]); } } float minimum_element = *std::min_element(input_row_float.begin(), input_row_float.end()); float maximum_element = *std::max_element(input_row_float.begin(), input_row_float.end()); // Truncate since bias will be represented by fp16. Keep higher precision // max untouched. float16 minimum_element_fp16 = cpu_float2half_rn(minimum_element); minimum_element = cpu_half2float(minimum_element_fp16); const float range = maximum_element - minimum_element; float scale = range == 0 ? 1.0f : range / ((1 << bit_rate) - 1); float16 scale_fp16 = cpu_float2half_rn(scale); scale = cpu_half2float(scale_fp16); if (scale == 0) { // Corner case handling when maximum_element == minimum_element // Any scale would work because X - minimum_element will be 0 for all X scale = 1.0f; } float inverse_scale = 1.0f / scale; if (std::isinf(inverse_scale)) { scale = 1.0f; inverse_scale = 1.0f; } output_row_scale_bias[0] = cpu_float2half_rn(scale); output_row_scale_bias[1] = minimum_element_fp16; for (int col = 0; col < input_columns; ++col) { float X = input_row_float[col]; std::uint8_t quantized = std::max( 0, std::min<int>( std::lrintf((X - minimum_element) * inverse_scale), (1 << bit_rate) - 1)); if (col % num_elem_per_byte == 0) { output_row[col / num_elem_per_byte] = quantized; } else { output_row[col / num_elem_per_byte] |= (quantized << ((col % num_elem_per_byte) * bit_rate)); } } } // for each row } template <typename InputType> void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf( int bit_rate, const InputType* input, size_t input_rows, int input_columns, std::uint8_t* output) { // Currenlty we can only dequantize if the number of input columns // is a multiple of number of elements_per_byte int num_elem_per_byte = 8 / bit_rate; if (input_columns % num_elem_per_byte != 0) { throw std::runtime_error("Unsupported number of columns"); } if (cpuinfo_initialize() && fbgemmHasAvx2Support()) { switch (bit_rate) { case 2: FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfAvx2<InputType, 2>( input, input_rows, input_columns, output); break; case 4: FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfAvx2<InputType, 4>( input, input_rows, input_columns, output); break; case 8: FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfAvx2<InputType, 8>( input, input_rows, input_columns, output); break; default: FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfRef<InputType>( bit_rate, input, input_rows, input_columns, output); } } else { FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfRef<InputType>( bit_rate, input, input_rows, input_columns, output); } } template <typename InputType> void FloatOrHalfToFused8BitRowwiseQuantizedSBFloatRef( const InputType* input, size_t input_rows, int input_columns, std::uint8_t* output) { constexpr float kEpsilon = 1e-8f; int output_columns = input_columns + 2 * sizeof(float); std::vector<float> input_row_float(input_columns); for (size_t row = 0; row < input_rows; ++row) { const InputType* input_row = input + row * input_columns; std::uint8_t* output_row = output + row * output_columns; float* output_row_scale_bias = reinterpret_cast<float*>(output_row + input_columns); for (int col = 0; col < input_columns; ++col) { if (std::is_same<InputType, float>()) { input_row_float[col] = input_row[col]; } else { input_row_float[col] = cpu_half2float(input_row[col]); } } float minimum_element = *std::min_element(input_row_float.begin(), input_row_float.end()); float maximum_element = *std::max_element(input_row_float.begin(), input_row_float.end()); float range = maximum_element - minimum_element; output_row_scale_bias[0] = range / 255.0f; output_row_scale_bias[1] = minimum_element; const auto inverse_scale = 255.0f / (range + kEpsilon); for (int col = 0; col < input_columns; ++col) { output_row[col] = std::lrintf((input_row_float[col] - minimum_element) * inverse_scale); } } // for each row } template <typename InputType> void FloatOrHalfToFused8BitRowwiseQuantizedSBFloat( const InputType* input, size_t input_rows, int input_columns, std::uint8_t* output) { if (cpuinfo_initialize() && fbgemmHasAvx2Support()) { FloatOrHalfToFused8BitRowwiseQuantizedSBFloatAvx2<InputType>( input, input_rows, input_columns, output); } else { FloatOrHalfToFused8BitRowwiseQuantizedSBFloatRef<InputType>( input, input_rows, input_columns, output); } } template <typename OutputType> void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef( int bit_rate, const uint8_t* input, size_t input_rows, int input_columns, OutputType* output) { static_assert( std::is_same<OutputType, float>() || std::is_same<OutputType, float16>(), "Only float and float16 types are allowed."); int num_elem_per_byte = 8 / bit_rate; int output_columns = (input_columns - 2 * sizeof(float16)) * num_elem_per_byte; for (size_t row = 0; row < input_rows; ++row) { const std::uint8_t* input_row = input + row * input_columns; const float16* input_row_scale_bias = reinterpret_cast<const float16*>( input_row + (output_columns + num_elem_per_byte - 1) / num_elem_per_byte); float scale = cpu_half2float(input_row_scale_bias[0]); float bias = cpu_half2float(input_row_scale_bias[1]); OutputType* output_row = output + row * output_columns; for (int col = 0; col < output_columns; ++col) { std::uint8_t quantized = input_row[col / num_elem_per_byte]; quantized >>= (col % num_elem_per_byte) * bit_rate; quantized &= (1 << bit_rate) - 1; float output_value = scale * quantized + bias; if (std::is_same<OutputType, float>()) { output_row[col] = output_value; } else { output_row[col] = cpu_float2half_rn(output_value); } } } } template <typename OutputType> void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf( int bit_rate, const uint8_t* input, size_t input_rows, int input_columns, OutputType* output) { if (cpuinfo_initialize() && fbgemmHasAvx2Support()) { switch (bit_rate) { case 2: FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfAvx2<OutputType, 2>( input, input_rows, input_columns, output); break; case 4: FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfAvx2<OutputType, 4>( input, input_rows, input_columns, output); break; case 8: FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfAvx2<OutputType, 8>( input, input_rows, input_columns, output); break; default: FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef<OutputType>( bit_rate, input, input_rows, input_columns, output); } } else { FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef<OutputType>( bit_rate, input, input_rows, input_columns, output); } } template <typename OutputType> void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfRef( const std::uint8_t* input, size_t input_rows, int input_columns, OutputType* output) { int output_columns = input_columns - 2 * sizeof(float); for (size_t row = 0; row < input_rows; ++row) { const std::uint8_t* input_row = input + row * input_columns; const float* input_row_scale_bias = reinterpret_cast<const float*>(input_row + output_columns); OutputType* output_row = output + row * output_columns; for (int col = 0; col < output_columns; ++col) { float output_value = input_row[col] * input_row_scale_bias[0] + input_row_scale_bias[1]; if (std::is_same<OutputType, float>()) { output_row[col] = output_value; } else { output_row[col] = cpu_float2half_rn(output_value); } } } } template <typename OutputType> void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf( const std::uint8_t* input, size_t input_rows, int input_columns, OutputType* output) { if (cpuinfo_initialize() && fbgemmHasAvx2Support()) { Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2<OutputType>( input, input_rows, input_columns, output); } else { Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfRef<OutputType>( input, input_rows, input_columns, output); } } #define INSTANTIATE_QuantizationFunctions(type) \ template FBGEMM_API void \ FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfRef<type>( \ int bit_rate, \ const type* input, \ size_t input_rows, \ int input_columns, \ std::uint8_t* output); \ template FBGEMM_API void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf<type>( \ int bit_rate, \ const type* input, \ size_t input_rows, \ int input_columns, \ std::uint8_t* output); \ template FBGEMM_API void \ FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef<type>( \ int bit_rate, \ const uint8_t* input, \ size_t input_rows, \ int input_columns, \ type* output); \ template FBGEMM_API void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf<type>( \ int bit_rate, \ const uint8_t* input, \ size_t input_rows, \ int input_columns, \ type* output); \ template FBGEMM_API void \ FloatOrHalfToFused8BitRowwiseQuantizedSBFloatRef<type>( \ const type* input, \ size_t input_rows, \ int input_columns, \ std::uint8_t* output); \ template FBGEMM_API void \ FloatOrHalfToFused8BitRowwiseQuantizedSBFloat<type>( \ const type* input, \ size_t input_rows, \ int input_columns, \ std::uint8_t* output); \ template FBGEMM_API void \ Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfRef<type>( \ const uint8_t* input, \ size_t input_rows, \ int input_columns, \ type* output); \ template FBGEMM_API void \ Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf<type>( \ const uint8_t* input, \ size_t input_rows, \ int input_columns, \ type* output); // clang-format off INSTANTIATE_QuantizationFunctions(float) INSTANTIATE_QuantizationFunctions(float16) // clang-format on #undef INSTANTIATE_QuantizationFunctions } // namespace fbgemm