// include/fbgemm/QuantUtils.h
#pragma once
#include "./FbgemmBuild.h"
#include "./QuantUtilsAvx2.h"
#include "./Types.h"
#include "./Utils.h"
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>
namespace fbgemm {
FBGEMM_API TensorQuantizationParams ChooseQuantizationParams(
float min,
float max,
std::int32_t qmin,
std::int32_t qmax,
bool preserve_sparsity = false,
bool force_scale_power_of_two = false);
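// A minimal usage sketch (hypothetical values): pick qparams that map the
// observed range [-1, 1] onto the full uint8 range [0, 255].
//   TensorQuantizationParams qparams =
//       ChooseQuantizationParams(/*min=*/-1.0f, /*max=*/1.0f,
//                                /*qmin=*/0, /*qmax=*/255);
//   // qparams.scale is roughly 2.0f / 255 and qparams.zero_point is the
//   // integer that represents the real value 0.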
FBGEMM_API void ChooseRequantizationMultiplier(
float real_multiplier,
std::int32_t* quantized_multiplier,
int* right_shift,
int requantization_multiplier_precision = 32);
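// A sketch of the intended decomposition (see SaturatingRoundingMulWithShift
// below): the pair (quantized_multiplier, right_shift) is chosen so that
//   real_multiplier ~= quantized_multiplier / 2^right_shift
// which lets requantization replace a float multiply with an integer
// multiply plus a rounding right shift.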
////////////////////////////////////////////////////////////////////////////////
// Utility functions
// Clamp src in T1 to the desired precision and convert it to T2
// TODO: T26263653 fix signed-integer-overflow undefined behavior
template <typename T1, typename T2 = std::uint8_t>
NO_SANITIZE("signed-integer-overflow")
T2 clamp(T1 src, int precision, bool is_signed = false) {
std::int32_t min = is_signed ? -(1LL << (precision - 1)) : 0;
std::int32_t max =
is_signed ? ((1LL << (precision - 1)) - 1) : (1LL << precision) - 1;
// Make sure T1 and T2 can represent the precision
assert(min >= std::numeric_limits<T1>::lowest());
assert(min >= std::numeric_limits<T2>::lowest());
assert(max <= std::numeric_limits<T1>::max());
assert(max <= std::numeric_limits<T2>::max());
return std::min<T1>(std::max<T1>(src, min), max);
}
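// Examples:
//   clamp<int, std::uint8_t>(300, 8);        // 255: clamped to [0, 255]
//   clamp<int, std::int8_t>(-200, 8, true);  // -128: clamped to [-128, 127]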
/// Quantize src using zero_point and scale, clamp to the specified precision,
/// and convert it to type T
template <typename T, bool LEGACY = true>
T Quantize(
float src,
std::int32_t zero_point,
float scale,
int result_precision,
bool result_is_signed = std::is_signed<T>::value) {
// Note: We want to multiply src by inv_scale instead of dividing
// src by scale. The same is done in the vector code and elsewhere.
//
// Example:
// With scale = 0.00214854861f, zero_point = 0 and src = 0.273939937f
// transformed_val is 127.5 for src * inv_scale while
// transformed_val is 127.499992 for src / scale.
// Eventually 127.5 gets rounded to 128 while 127.499992 gets rounded to 127.
float inv_scale = 1.0f / scale;
float transformed_val = src * inv_scale;
// Under the default rounding mode, nearbyint rounds to the nearest
// integer with ties going to even.
// For example, nearbyint(1.4) is 1.0, nearbyint(1.5) is 2.0
// and nearbyint(2.5) is 2.0
// Adding zero_point before or after rounding can make a difference
// in exactly halfway cases.
if (LEGACY) {
transformed_val = std::nearbyint(zero_point + transformed_val);
} else {
transformed_val = zero_point + std::nearbyint(transformed_val);
}
// Please note the use of double. Unlike float, a double can represent
// all int32 values exactly. Using a float can produce a value greater than
// INT32_MAX, and converting that to int32 in the clamp function is
// undefined behavior (flagged by UBSAN).
return clamp<double, T>(transformed_val, result_precision, result_is_signed);
}
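// Worked example (hypothetical values) showing how LEGACY can change exactly
// halfway cases: with scale = 0.5f (inv_scale = 2.0f), zero_point = 1 and
// src = 1.25f, transformed_val = 2.5f before rounding.
//   LEGACY:     nearbyint(1 + 2.5f) = nearbyint(3.5f) = 4 (ties to even)
//   non-LEGACY: 1 + nearbyint(2.5f) = 1 + 2           = 3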
template <typename T, bool LEGACY = true>
T Quantize(float src, const TensorQuantizationParams& qparams) {
return Quantize<T, LEGACY>(
src, qparams.zero_point, qparams.scale, qparams.precision);
}
template <typename T, bool LEGACY = true>
FBGEMM_API void Quantize(
const float* src,
T* dst,
std::int64_t len,
const TensorQuantizationParams& qparams,
int thread_id = 0,
int num_threads = 1);
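// Usage sketch (hypothetical src/dst buffers of length len): the thread_id /
// num_threads parameters let callers split the work, e.g. with std::thread:
//   std::vector<std::thread> workers;
//   for (int t = 0; t < num_threads; ++t) {
//     workers.emplace_back([&, t] {
//       Quantize<std::uint8_t>(src, dst, len, qparams, t, num_threads);
//     });
//   }
//   for (auto& w : workers) { w.join(); }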
/**
* @brief Quantize floating point data in src to type T.
*
* @tparam T output quantized data type (int8_t, uint8_t, and int32_t are
* supported)
*
* @tparam LAYOUT layout of the input tensor in src (KCX and KXC are
* supported)
* KCX corresponds to KCRS or KCTRS (for weight tensors with a
* time dimension)
* KXC corresponds to KRSC or KTRSC (for weight tensors with a
* time dimension)
*
* @param K Output channels for weight tensors
* @param C Number of channels
* @param X R*S or T*R*S
* @param G Groups (if G == C the function performs channelwise quantization;
* if 1 < G < C the function performs groupwise quantization;
* if G == 1 the function performs per-tensor quantization)
* @param scales floating point scales; size should be equal to G
* @param zero_points zero points (should be representable in type T);
* size should be equal to G
*/
template <typename T, layout_t LAYOUT = layout_t::KCX>
FBGEMM_API void QuantizeGroupwise(
const float* src,
int K,
int C,
int X,
int G,
const float* scales,
const std::int32_t* zero_points,
T* dst);
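// Example (hypothetical shapes and values; src/dst are caller-provided
// buffers): channelwise quantization of a 3x3 conv weight in KCX layout with
// K = 8 output channels and C = 4 channels, so X = R*S = 9 and G == C.
//   float scales[4] = {0.01f, 0.02f, 0.015f, 0.03f}; // one per group
//   std::int32_t zero_points[4] = {0, 0, 0, 0};      // one per group
//   QuantizeGroupwise<std::uint8_t, layout_t::KCX>(
//       src, /*K=*/8, /*C=*/4, /*X=*/9, /*G=*/4, scales, zero_points, dst);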
template <typename T>
float Dequantize(T src, const TensorQuantizationParams& qparams) {
return qparams.scale * (src - qparams.zero_point);
}
template <typename T>
void Dequantize(
const T* src,
float* dst,
std::int64_t len,
const TensorQuantizationParams& qparams,
int thread_id = 0,
int num_threads = 1) {
std::int64_t i_begin, i_end;
fbgemmPartition1D(thread_id, num_threads, len, i_begin, i_end);
for (std::int64_t i = i_begin; i < i_end; i++) {
dst[i] = Dequantize(src[i], qparams);
}
}
template <typename T>
float FusedQuantizeDequantize(
float src,
const TensorQuantizationParams& qparams) {
T q = Quantize<T, false>(
src, qparams.zero_point, qparams.scale, qparams.precision);
return Dequantize<T>(q, qparams);
}
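// Worked example (hypothetical values): with scale = 0.5f, zero_point = 0
// and 8-bit precision, src = 1.3f quantizes to nearbyint(2.6f) = 3 and
// dequantizes back to 0.5f * 3 = 1.5f; the 0.2f difference is the
// quantization error that quantization-aware training sees.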
/*
Fused integer quantization/dequantization kernel to accelerate
quantization-aware training. Quantizes fp32 values in src to (u)int8 using
the provided qparams, then dequantizes the quantized integers back to fp32.
*/
template <typename T>
FBGEMM_API void FusedQuantizeDequantize(
const float* src,
float* dst,
std::int64_t len,
const TensorQuantizationParams& qparams,
int thread_id = 0,
int num_threads = 1,
float noise_ratio = 0.0f);
////////////////////////////////////////////////////////////////////////////////
// Requantization (pure fixed-point)
FBGEMM_API std::int64_t
SaturatingRoundingMulWithShift(std::int32_t a, std::int32_t b, int right_shift);
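// Conceptually (a sketch; saturation details live in the implementation):
//   SaturatingRoundingMulWithShift(a, b, s) ~= round(a * b / 2^s)
// computed in 64-bit, so multiplying by b and shifting right by s
// approximates multiplication by the real value b / 2^s.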
template <typename T>
T Requantize(
std::int32_t src, // int32 input before requantization
std::int32_t zero_point,
std::int32_t multiplier,
int right_shift,
int result_precision,
bool result_is_signed = false) {
std::int64_t quantized_down =
zero_point + SaturatingRoundingMulWithShift(src, multiplier, right_shift);
return clamp<std::int64_t, T>(
quantized_down, result_precision, result_is_signed);
}
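// Worked example (hypothetical values): suppose ChooseRequantizationMultiplier
// approximated real_multiplier = 0.0015f as multiplier ~= 0.0015 * 2^31 with
// right_shift = 31. With zero_point = 0 and 8-bit unsigned output,
// src = 10000 then requantizes to roughly round(10000 * 0.0015) = 15,
// clamped to [0, 255].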
template <typename T>
T RequantizeFixedPoint(
std::int32_t src, // int32 input before requantization
const RequantizationParams& params) {
return Requantize<T>(
src,
params.target_qparams.zero_point,
params.multiplier,
params.right_shift,
params.target_qparams.precision);
}
template <typename T>
FBGEMM_API void RequantizeFixedPoint(
const std::int32_t* src,
T* dst,
std::int64_t len,
const RequantizationParams& params,
int thread_id = 0,
int num_threads = 1);
////////////////////////////////////////////////////////////////////////////////
// Requantization (with floats)
template <typename T>
T Requantize(
std::int32_t src, // int32 input before requantization
std::int32_t zero_point,
float multiplier,
int result_precision,
bool result_is_signed = false) {
long quantized_down = zero_point + std::lrintf(src * multiplier);
return clamp<long, T>(quantized_down, result_precision, result_is_signed);
}
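// Worked example (hypothetical values): with multiplier = 0.0015f and
// zero_point = 2, src = 10000 gives 2 + lrintf(15.0f) = 17, which is then
// clamped to the output precision.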
template <typename T>
T Requantize(
std::int32_t src, // int32 input before requantization
const RequantizationParams& params) {
return Requantize<T>(
src,
params.target_qparams.zero_point,
params.real_multiplier,
params.target_qparams.precision);
}
template <typename T>
FBGEMM_API void Requantize(
const std::int32_t* src,
T* dst,
std::int64_t len,
const RequantizationParams& params,
int thread_id = 0,
int num_threads = 1);
/**
* Convert float (fp32 or fp16) inputs to rowwise quantized outputs.
* bit_rate specifies the number of bits in the quantized output.
* Scale and Bias are in fp16. Each row's Scale and Bias are stored in
* the row itself (fused) at the end.
*
* @param bit_rate can be 2, 4, or 8
*/
template <typename InputType>
FBGEMM_API void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf(
int bit_rate,
const InputType* input,
size_t input_rows,
int input_columns,
std::uint8_t* output);
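// Example of the fused row layout described above (hypothetical sizes): with
// bit_rate = 4 and input_columns = 64, each output row packs
// ceil(64 * 4 / 8) = 32 bytes of quantized data followed by an fp16 scale
// and an fp16 bias (2 bytes each), so the output buffer needs
// input_rows * (32 + 2 + 2) bytes.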
/**
* Convert fused rowwise quantized inputs to float (fp32 or fp16).
* bit_rate specifies the number of bits in the quantized input.
* Scale and Bias are in fp16. Each row's Scale and Bias are stored in
* the row itself (fused) at the end.
*
* @param bit_rate can be 2, 4, or 8
*/
template <typename OutputType>
FBGEMM_API void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf(
int bit_rate,
const std::uint8_t* input,
size_t input_rows,
int input_columns,
OutputType* output);
/**
* Convert float or half inputs to rowwise quantized (8-bit) outputs.
* Scale and Bias are in float. Each row's Scale and Bias are stored in
* the row itself (fused) at the end.
*
* This version intentionally supports only 8-bit because we want to
* discourage using a float scale and bias with the 2- and 4-bit cases,
* as that diminishes the overall memory savings.
*/
template <typename InputType>
FBGEMM_API void FloatOrHalfToFused8BitRowwiseQuantizedSBFloat(
const InputType* input,
size_t input_rows,
int input_columns,
std::uint8_t* output);
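// Example of the fused row layout described above (hypothetical sizes): with
// input_columns = 64, each output row holds 64 quantized bytes followed by a
// float scale and a float bias (4 bytes each), i.e.
// input_rows * (64 + 4 + 4) bytes in total.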
/**
* Convert fused rowwise quantized (8-bit) inputs to float or half outputs.
* Scale and Bias are in float. Each row's Scale and Bias are stored in
* the row itself (fused) at the end.
*
* This version intentionally supports only 8-bit because
* the corresponding quantize version only supports 8-bit.
*/
template <typename OutputType>
FBGEMM_API void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf(
const std::uint8_t* input,
size_t input_rows,
int input_columns,
OutputType* output);
/**
* Same as FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf but unoptimized.
* This should not be called directly except in testing.
*/
template <typename InputType>
FBGEMM_API void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfRef(
int bit_rate,
const InputType* input,
size_t input_rows,
int input_columns,
std::uint8_t* output);
/**
* Same as FloatOrHalfToFused8BitRowwiseQuantizedSBFloat but unoptimized.
* This should not be called directly except in testing.
*/
template <typename InputType>
FBGEMM_API void FloatOrHalfToFused8BitRowwiseQuantizedSBFloatRef(
const InputType* input,
size_t input_rows,
int input_columns,
std::uint8_t* output);
/**
* Same as FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf but unoptimized.
* This should not be called directly except in testing.
*/
template <typename OutputType>
FBGEMM_API void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef(
int bit_rate,
const std::uint8_t* input,
size_t input_rows,
int input_columns,
OutputType* output);
/**
* Same as Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf but unoptimized.
* This should not be called directly except in testing.
*/
template <typename OutputType>
FBGEMM_API void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfRef(
const std::uint8_t* input,
size_t input_rows,
int input_columns,
OutputType* output);
} // namespace fbgemm