in aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp [1402:3601]
template <typename T>
void do_avg_pool_nhwc_on_AVX_n(
const typename T::underlying* i_p,
typename T::underlying* o_p,
int& c_start,
int input_zero_point_m_size,
int output_zero_point,
float multiplier,
int dstart,
int dend,
int hstart,
int hend,
int wstart,
int wend,
int dsize,
int hsize,
int wsize,
int csize) {
#if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && !defined(_MSC_VER)
  // Buffer for the channel accumulators, used to move the channel loop
  // innermost so that accesses to the input tensor data are contiguous.
#ifdef CPU_CAPABILITY_AVX2
constexpr int cb_size = 16;
#else
constexpr int cb_size = 8;
#endif
constexpr int vec_width = Vectorized<T>::size() / 4;
constexpr int cb_step = cb_size * vec_width;
Vectorized<int32_t> acc_buffer[cb_size];
Vectorized<float> acc_buffer_fp[cb_size];
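  // Each iteration of the channel loop below covers a block of
  // cb_step = cb_size * vec_width channels: the pooling-window sums are
  // accumulated in int32 lanes, converted to fp32 in one shot, and then
  // requantized with `multiplier` and `output_zero_point`.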
#ifdef CPU_CAPABILITY_AVX2
if (vec_width == 8) {
#else
if (vec_width == 16) {
#endif
for (int c = c_start; c < csize; c += cb_step) {
int cend = std::min(cb_size, (csize - c) / vec_width);
// initialize loop
for (const auto ic : c10::irange(cend)) {
acc_buffer[ic] = Vectorized<int32_t>(input_zero_point_m_size);
}
// compute loop
for (const auto id : c10::irange(dstart, dend)) {
for (const auto ih : c10::irange(hstart, hend)) {
for (const auto iw : c10::irange(wstart, wend)) {
const int i_idx =
(id * wsize * hsize + ih * wsize + iw) *
csize +
c;
for (const auto ic : c10::irange(cend)) {
auto vals = vec::convert_to_int32<typename T::underlying>(
i_p + i_idx + ic * vec_width);
acc_buffer[ic] = acc_buffer[ic] + vals;
}
}
}
}
      // convert the int32 accumulators to fp32
vec::convert((int*)acc_buffer, (float*)acc_buffer_fp, cend * vec_width);
      // QuantizeAvx2/QuantizeAvx512 quantize 32 lanes at a time, then 8, and
      // finally fall back to a single element
#ifdef CPU_CAPABILITY_AVX2
QuantizeAvx2<T>(
(float*)acc_buffer_fp,
o_p + c,
cend * vec_width,
multiplier,
output_zero_point);
#else
QuantizeAvx512<T>(
(float*)acc_buffer_fp,
o_p + c,
cend * vec_width,
multiplier,
output_zero_point);
#endif
}
c_start = csize / vec_width * vec_width;
}
#endif
}
template <typename T>
void do_avg_pool_on_AVX_n(
typename T::underlying* i_p,
typename T::underlying* o_p,
int64_t& c,
int64_t channel_size,
int64_t channel_multiplier,
int32_t input_zero_point_m_size,
int32_t output_zero_point,
float multiplier,
int64_t dstart,
int64_t dend,
int64_t hstart,
int64_t hend,
int64_t wstart,
int64_t wend,
int64_t stride_C,
int64_t stride_D,
int64_t stride_H,
int64_t stride_W) {
#if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && !defined(_MSC_VER)
constexpr int vec_width = Vectorized<T>::size() / 4;
#ifdef CPU_CAPABILITY_AVX2
if (vec_width == 8) {
#else
if (vec_width == 16) {
#endif
for (; c + vec_width <= channel_size; c += vec_width) {
int64_t tcntr = 0;
Vectorized<int32_t> acc(input_zero_point_m_size);
for (const auto id : c10::irange(dstart, dend)) {
for (const auto ih : c10::irange(hstart, hend)) {
for (const auto iw : c10::irange(wstart, wend)) {
tcntr = id * stride_D + ih * stride_H + iw * stride_W;
auto vals = vec::convert_to_int32<typename T::underlying>(
i_p + tcntr * channel_multiplier + c * stride_C);
acc = acc + vals;
}
}
}
int32_t acc_int[vec_width];
float acc_fp[vec_width];
acc.store(acc_int);
vec::convert(acc_int, acc_fp, vec_width);
at::native::quantize_vec<T>(
1.0f / multiplier,
output_zero_point,
acc_fp,
reinterpret_cast<T*>(o_p + c),
vec_width);
}
}
#endif
}
template <typename T>
void _qadaptive_avg_pool_kernel(
const Tensor& qx,
Tensor& qy,
int64_t nBatch,
int64_t sizeC,
int64_t isizeD, // Set to 1 for 2d
int64_t isizeH,
int64_t isizeW,
int64_t osizeD, // Set to 1 for 2d
int64_t osizeH,
int64_t osizeW,
int64_t istrideB,
int64_t istrideC,
int64_t istrideD, // Set to 1 for 2d
int64_t istrideH,
int64_t istrideW) {
T* idata = static_cast<T*>(qx.data_ptr());
T* odata = static_cast<T*>(qy.data_ptr());
const float input_scale = qx.q_scale();
const float output_scale = qy.q_scale();
const int input_zero_point = qx.q_zero_point();
const int output_zero_point = qy.q_zero_point();
at::parallel_for(0, nBatch, 0, [&](int64_t batch_start, int64_t batch_end) {
for (const auto b : c10::irange(batch_start, batch_end)) {
auto* i_p = reinterpret_cast<typename T::underlying*>(
idata + b * istrideB);
for (const auto od : c10::irange(osizeD)) {
int istartD = (int)std::floor((float)(od * isizeD) / osizeD);
int iendD = (int)std::ceil((float)((od + 1) * isizeD) / osizeD);
int kD = iendD - istartD;
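        // Adaptive pooling windows: istart = floor(o * isize / osize) and
        // iend = ceil((o + 1) * isize / osize). For example, isize = 5 and
        // osize = 3 give the windows [0, 2), [1, 4), [3, 5).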
for (const auto oh : c10::irange(osizeH)) {
int istartH = (int)std::floor((float)(oh * isizeH) / osizeH);
int iendH = (int)std::ceil((float)((oh + 1) * isizeH) / osizeH);
int kH = iendH - istartH;
for (const auto ow : c10::irange(osizeW)) {
auto* o_p = reinterpret_cast<typename T::underlying*>(
odata +
b * osizeD * osizeH * osizeW * sizeC +
od * osizeH * osizeW * sizeC +
oh * osizeW * sizeC +
ow * sizeC);
int istartW = (int)std::floor((float)(ow * isizeW) / osizeW);
int iendW = (int)std::ceil((float)((ow + 1) * isizeW) / osizeW);
int kW = iendW - istartW;
int size = kD * kH * kW;
float multiplier = input_scale / output_scale / size;
int input_zero_point_m_size = -input_zero_point * size;
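            // Requantization math: avg = input_scale * (sum(x_q) - size * input_zero_point) / size,
            // so out_q = output_zero_point + round(avg / output_scale)
            //          = output_zero_point + round((sum(x_q) + input_zero_point_m_size) * multiplier).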
int64_t c = 0;
            // For int8 or uint8 quantization, we implicitly use int32 as the
            // accumulator; otherwise we fall through to the slow path below.
            // TODO: support 16bit, 32bit, and etc.
auto* internal_i_p = i_p +
istartD * istrideD +
istartH * istrideH +
istartW * istrideW;
            // Note: If AVX is not available, `do_avg_pool_on_AVX_n` is a no-op.
            // In that case, the following loop takes over.
// TODO: more vectorization with loop interleaving
do_avg_pool_on_AVX_n<T>(
internal_i_p,
o_p,
c,
sizeC,
1,
input_zero_point_m_size,
output_zero_point,
multiplier,
0,
kD,
0,
kH,
0,
kW,
istrideC,
istrideD,
istrideH,
istrideW);
            // 1) The following loop handles the remaining channels
            // 2) It also handles the non-AVX path
for (; c < sizeC; ++c) {
int32_t acc_int32 = input_zero_point_m_size;
int64_t tcntr = 0;
for (const auto id : c10::irange(kD)) {
for (const auto ih : c10::irange(kH)) {
for (const auto iw : c10::irange(kW)) {
tcntr = id * istrideD +
ih * istrideH +
iw * istrideW;
auto val = *(internal_i_p + tcntr + c * istrideC);
acc_int32 += val;
}
}
}
// clamp
o_p[c] = at::native::quantize_val<T>(1.0f / multiplier,
output_zero_point,
acc_int32).val_;
} // c
          } // ow
        } // oh
} // od
}
});
}
void qadaptive_avg_pool2d_nhwc_kernel(
const Tensor& qx,
Tensor& qy,
int64_t nBatch,
int64_t sizeC,
int64_t isizeH,
int64_t isizeW,
int64_t osizeH,
int64_t osizeW,
int64_t istrideB,
int64_t istrideC,
int64_t istrideH,
int64_t istrideW) {
AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "adaptive_avg_pool2d_nhwc", [&]() {
_qadaptive_avg_pool_kernel<scalar_t>(
qx,
qy,
nBatch,
sizeC,
/*isizeD=*/1,
isizeH,
isizeW,
/*osizeD=*/1,
osizeH,
osizeW,
istrideB,
istrideC,
/*istrideD=*/1,
istrideH,
istrideW);
}
);
}
void qadaptive_avg_pool3d_ndhwc_kernel(
const Tensor& qx,
Tensor& qy,
int64_t nBatch,
int64_t sizeC,
int64_t isizeD,
int64_t isizeH,
int64_t isizeW,
int64_t osizeD,
int64_t osizeH,
int64_t osizeW,
int64_t istrideB,
int64_t istrideC,
int64_t istrideD,
int64_t istrideH,
int64_t istrideW) {
AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "adaptive_avg_pool3d_ndhwc", [&]() {
_qadaptive_avg_pool_kernel<scalar_t>(
qx,
qy,
nBatch,
sizeC,
isizeD,
isizeH,
isizeW,
osizeD,
osizeH,
osizeW,
istrideB,
istrideC,
istrideD,
istrideH,
istrideW);
}
);
}
template <typename T>
void _qavg_pool_nhwc_kernel(
const Tensor& qx,
Tensor& qy,
int64_t nBatch,
int64_t nInputPlane,
int64_t inputWidth,
int64_t inputHeight,
int64_t inputDepth,
int64_t outputWidth,
int64_t outputHeight,
int64_t outputDepth,
int kW,
int kH,
int kD,
int dW,
int dH,
int dD,
int padW,
int padH,
int padD,
bool count_include_pad,
c10::optional<int64_t> divisor_override) {
T* idata = static_cast<T*>(qx.data_ptr());
T* odata = static_cast<T*>(qy.data_ptr());
int strideC = 1;
int strideW = strideC * nInputPlane;
int istrideH = strideW * inputWidth;
int istrideD = istrideH * inputHeight;
int istrideB = istrideD * inputDepth;
int ostrideH = strideW * outputWidth;
int ostrideD = ostrideH * outputHeight;
int ostrideB = ostrideD * outputDepth;
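  // NDHWC layout: channels are innermost (strideC == 1). The 2d entry point
  // passes depth-related sizes of 1, so the D strides collapse and the same
  // kernel serves both avg_pool2d and avg_pool3d.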
// lift these operations outside the loop to reduce access overheads
float input_scale = qx.q_scale();
float output_scale = qy.q_scale();
int input_zero_point = qx.q_zero_point();
int output_zero_point = qy.q_zero_point();
int64_t divisor_override_factor =
divisor_override.has_value() ? divisor_override.value() : 0;
at::parallel_for(0, nBatch, 0, [&](int64_t batch_start, int64_t batch_end) {
for (int64_t b = batch_start; b < batch_end; ++b) {
auto* i_p =
reinterpret_cast<typename T::underlying*>(idata + b * istrideB);
for (int od = 0; od < outputDepth; od++) {
for (int oh = 0; oh < outputHeight; oh++) {
for (int ow = 0; ow < outputWidth; ow++) {
auto* o_p = reinterpret_cast<typename T::underlying*>(
odata + b * ostrideB + od * ostrideD + oh * ostrideH +
ow * strideW);
int dstart = od * dD - padD;
int hstart = oh * dH - padH;
int wstart = ow * dW - padW;
int dend = std::min(dstart + kD, (int)inputDepth + padD);
int hend = std::min(hstart + kH, (int)inputHeight + padH);
int wend = std::min(wstart + kW, (int)inputWidth + padW);
int pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
dstart = std::max(dstart, 0);
hstart = std::max(hstart, 0);
wstart = std::max(wstart, 0);
dend = std::min(dend, (int)inputDepth);
hend = std::min(hend, (int)inputHeight);
wend = std::min(wend, (int)inputWidth);
int size = (dend - dstart) * (hend - hstart) * (wend - wstart);
int divide_size = count_include_pad ? pool_size : size;
int divide_factor =
divisor_override_factor ? divisor_override_factor : divide_size;
float multiplier = input_scale / output_scale / divide_factor;
int input_zero_point_m_size = -input_zero_point * size;
int c_start = 0;
            // For int8 quantization, we implicitly use int32 as the
            // accumulator; otherwise we fall through to the slow path below.
            // TODO: support 16bit, 32bit, and etc.
do_avg_pool_nhwc_on_AVX_n<T>(
i_p,
o_p,
c_start,
input_zero_point_m_size,
output_zero_point,
multiplier,
dstart,
dend,
hstart,
hend,
wstart,
wend,
inputDepth,
inputHeight,
inputWidth,
nInputPlane);
            // 1) The following loop handles the remaining channels
            // 2) It also handles the non-AVX path
for (int c = c_start; c < nInputPlane; ++c) {
int32_t acc_int32 = input_zero_point_m_size;
for (int64_t id = dstart; id < dend; id++) {
for (int64_t ih = hstart; ih < hend; ih++) {
for (int64_t iw = wstart; iw < wend; iw++) {
auto val =
*(i_p + id * istrideD + ih * istrideH + iw * strideW +
c * strideC);
acc_int32 += val;
}
}
}
double acc_fp = acc_int32 * 1.0;
// clamp
o_p[c] = at::native::quantize_val<T>(
1.0f / multiplier, output_zero_point, acc_fp)
.val_;
} // c
} // ow
} // oh
} // od
}
});
}
void qavg_pool2d_nhwc_kernel(
const Tensor& qx,
Tensor& qy,
int64_t b,
int64_t nInputPlane,
int64_t inputWidth,
int64_t inputHeight,
int64_t outputWidth,
int64_t outputHeight,
int kW,
int kH,
int dW,
int dH,
int padW,
int padH,
bool count_include_pad,
c10::optional<int64_t> divisor_override) {
AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "avg_pool2d_nhwc", [&]() {
_qavg_pool_nhwc_kernel<scalar_t>(
qx,
qy,
b,
nInputPlane,
inputWidth,
inputHeight,
1,
outputWidth,
outputHeight,
1,
kW,
kH,
1,
dW,
dH,
1,
padW,
padH,
0,
count_include_pad,
divisor_override);
}
);
}
void qavg_pool3d_nhwc_kernel(
const Tensor& qx,
Tensor& qy,
int64_t b,
int64_t nInputPlane,
int64_t inputWidth,
int64_t inputHeight,
int64_t inputDepth,
int64_t outputWidth,
int64_t outputHeight,
int64_t outputDepth,
int kW,
int kH,
int kD,
int dW,
int dH,
int dD,
int padW,
int padH,
int padD,
bool count_include_pad,
c10::optional<int64_t> divisor_override) {
AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "avg_pool3d_nhwc", [&]() {
_qavg_pool_nhwc_kernel<scalar_t>(
qx,
qy,
b,
nInputPlane,
inputWidth,
inputHeight,
inputDepth,
outputWidth,
outputHeight,
outputDepth,
kW,
kH,
kD,
dW,
dH,
dD,
padW,
padH,
padD,
count_include_pad,
divisor_override);
}
);
}
template <typename T>
int64_t do_quantized_bilinear_on_AVX_n(
const typename T::underlying*& pos1,
typename T::underlying*& pos2,
int64_t input_height,
int64_t input_width,
int64_t output_height,
int64_t output_width,
int64_t channels,
int32_t output_zero_point,
int32_t input_zero_point,
float inverse_scale,
const float h0lambda,
const float h1lambda,
const float w0lambda,
const float w1lambda,
const int64_t h1p,
const int64_t w1p) {
int64_t c = 0;
#if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && !defined(_MSC_VER)
constexpr auto vec_width = Vectorized<T>::size() / 4;
#ifdef CPU_CAPABILITY_AVX2
if (vec_width == 8) {
#else
if (vec_width == 16) {
#endif
for (; c + vec_width <= channels; c += vec_width) {
Vectorized<float> pos1_fp_v[4];
Vectorized<int32_t> pos1_int_v[4];
pos1_int_v[0] = vec::convert_to_int32<typename T::underlying>(pos1);
pos1_int_v[1] = vec::convert_to_int32<typename T::underlying>(
pos1 + w1p * channels);
pos1_int_v[2] = vec::convert_to_int32<typename T::underlying>(
pos1 + h1p * input_width * channels);
pos1_int_v[3] = vec::convert_to_int32<typename T::underlying>(
pos1 + (h1p * input_width + w1p) * channels);
for (const auto i : c10::irange(4)) {
int32_t pos1_int[vec_width];
float pos1_fp[vec_width];
pos1_int_v[i].store(pos1_int);
vec::convert(pos1_int, pos1_fp, vec_width);
        pos1_fp_v[i] = Vectorized<float>::loadu(pos1_fp, vec_width);
}
Vectorized<float> h0lambda_v(h0lambda);
Vectorized<float> h1lambda_v(h1lambda);
Vectorized<float> w0lambda_v(w0lambda);
Vectorized<float> w1lambda_v(w1lambda);
Vectorized<float> input_zero_point_v(input_zero_point);
Vectorized<float> result =
h0lambda_v * (w0lambda_v * pos1_fp_v[0] + w1lambda_v * pos1_fp_v[1]) +
h1lambda_v * (w0lambda_v * pos1_fp_v[2] + w1lambda_v * pos1_fp_v[3]) -
input_zero_point_v;
float result_fp[vec_width];
result.store(result_fp);
at::native::quantize_vec<T>(
inverse_scale,
output_zero_point,
result_fp,
reinterpret_cast<T*>(pos2),
vec_width);
pos1 += vec_width;
pos2 += vec_width;
}
}
#endif
return c;
}
void qupsample_bilinear2d_nhwc_kernel(
Tensor& output,
const Tensor& input,
int64_t input_height,
int64_t input_width,
int64_t output_height,
int64_t output_width,
int64_t nbatch,
int64_t channels,
bool align_corners,
c10::optional<double> scales_h,
c10::optional<double> scales_w) {
AT_DISPATCH_QINT_TYPES(
input.scalar_type(), "upsample_bilinear2d_nhwc", [&]() {
auto* idata = static_cast<scalar_t*>(input.data_ptr());
auto* odata = static_cast<scalar_t*>(output.data_ptr());
float inverse_scale = output.q_scale() / input.q_scale();
const auto rheight = area_pixel_compute_scale<float>(
input_height, output_height, align_corners, scales_h);
const auto rwidth = area_pixel_compute_scale<float>(
input_width, output_width, align_corners, scales_w);
const int64_t input_q_zero_point = input.q_zero_point();
const int64_t output_q_zero_point = output.q_zero_point();
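        // The interpolation below works directly on the raw integer values
        // with the input zero point subtracted, so the real-valued result is
        // input_scale * result and requantization only needs the scale ratio:
        //   out_q = out_zp + round(result * input_scale / output_scale)
        //         = quantize_val(inverse_scale, out_zp, result).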
for (const auto b : c10::irange(nbatch)) {
auto* i_p = reinterpret_cast<typename scalar_t::underlying*>(
idata + b * input_height * input_width * channels);
auto* o_p = reinterpret_cast<typename scalar_t::underlying*>(
odata + b * output_height * output_width * channels);
for (const auto h2 : c10::irange(output_height)) {
const auto h1r = area_pixel_compute_source_index<float>(
rheight, h2, align_corners, /*cubic=*/false);
const int64_t h1 = h1r;
const int64_t h1p = (h1 < input_height - 1) ? 1 : 0;
const float h1lambda = h1r - h1;
const float h0lambda = static_cast<float>(1.) - h1lambda;
for (const auto w2 : c10::irange(output_width)) {
const auto w1r = area_pixel_compute_source_index<float>(
rwidth, w2, align_corners, /*cubic=*/false);
const int64_t w1 = w1r;
const int64_t w1p = (w1 < input_width - 1) ? 1 : 0;
const float w1lambda = w1r - w1;
const float w0lambda = static_cast<float>(1.) - w1lambda;
int64_t c = 0;
// We use float32 to do the computation
const typename scalar_t::underlying* pos1 =
i_p + (h1 * input_width + w1) * channels;
typename scalar_t::underlying* pos2 =
o_p + (h2 * output_width + w2) * channels;
              // We isolate this in a separate function because MSVC does not
              // expand the macro correctly.
c = do_quantized_bilinear_on_AVX_n<scalar_t>(
pos1,
pos2,
input_height,
input_width,
output_height,
output_width,
channels,
output_q_zero_point,
input_q_zero_point,
inverse_scale,
h0lambda,
h1lambda,
w0lambda,
w1lambda,
h1p,
w1p);
              // 1) The following loop handles the remaining channels
              // 2) It also handles the non-AVX path
for (; c < channels; ++c) {
float result = h0lambda *
(w0lambda * pos1[0] + w1lambda * pos1[w1p * channels]) +
h1lambda *
(w0lambda * pos1[h1p * input_width * channels] +
w1lambda * pos1[(h1p * input_width + w1p) * channels]);
pos2[0] = at::native::quantize_val<scalar_t>(
inverse_scale,
output_q_zero_point,
result - input_q_zero_point)
.val_;
pos1 += 1;
pos2 += 1;
} // c
} // w2
} // h2
} // b
});
}
void qtopk_kernel(Tensor& values,
Tensor& indices,
const Tensor& self,
int64_t k,
int64_t dim,
bool largest,
bool sorted) {
auto sizes = self.sizes();
auto iter = TensorIteratorConfig()
.check_all_same_dtype(false)
.resize_outputs(false)
.declare_static_shape(sizes, /*squash_dims=*/dim)
.add_output(values)
.add_output(indices)
.add_input(self)
.build();
auto mode_values_stride = values.strides()[dim];
auto mode_indices_stride = indices.strides()[dim];
auto tmp_values_stride = self.strides()[dim];
AT_DISPATCH_QINT_TYPES(self.scalar_type(), "qtopk_cpu", [&] {
auto loop = [&](char** data, const int64_t* strides, int64_t n) {
using underlying_t = typename scalar_t::underlying;
static_assert(sizeof(scalar_t) == sizeof(underlying_t), "");
return topk_impl_loop<underlying_t, underlying_t>(
mode_values_stride, mode_indices_stride, tmp_values_stride,
k, sizes[dim], largest, sorted, data, strides, n);
};
int64_t grain_size = internal::GRAIN_SIZE / std::max(int64_t{1}, sizes[dim]);
iter.for_each(loop, /*grain_size=*/grain_size);
});
}
template <typename T>
inline void do_bn_compute(
typename T::underlying* X_ptr,
typename T::underlying* Y_ptr,
Vectorized<float> & fake_scale,
Vectorized<float> & in_zp_vec,
Vectorized<float> & scale_neg_zp_premul,
int64_t out_zero_point,
Vectorized<T> & out_zero_point_v,
float* alpha,
float* beta,
int64_t vec_num,
bool ReluFused,
int64_t kVLen
) {
using Vec = Vectorized<T>;
auto vals_q = Vec::loadu(X_ptr);
// Fake scale of 1.0 here, should not affect performance (FMA in place of sub)
auto vals_dq = vals_q.dequantize(fake_scale, in_zp_vec, scale_neg_zp_premul);
for (const auto idx : c10::irange(vec_num)) {
auto alpha_v = Vectorized<float>::loadu(alpha + idx * kVLen);
auto beta_v = Vectorized<float>::loadu(beta + idx * kVLen);
vals_dq[idx] = vec::fmadd(alpha_v, vals_dq[idx], beta_v);
}
// NOLINTNEXTLINE(bugprone-argument-comment)
auto outputs_q = Vec::quantize(vals_dq, /*output_scale=*/1.0f, out_zero_point, /*inv_output_scale=*/1.0f);
// Fake scale again
if (ReluFused) {
outputs_q = outputs_q.maximum(out_zero_point_v);
}
outputs_q.store(Y_ptr, vec_num * kVLen);
}
template <bool ReluFused>
void q_batch_norm_kernel(
int64_t N,
int64_t C,
int64_t HxW,
int64_t in_zero_point,
int64_t out_zero_point,
const Tensor& input,
const Tensor& a,
const Tensor& b,
Tensor& output) {
AT_DISPATCH_QINT_TYPES(input.scalar_type(), "qbatch_norm", [&]() {
float* alpha = a.data_ptr<float>();
float* beta = b.data_ptr<float>();
auto minimum = std::numeric_limits<scalar_t::underlying>::lowest();
auto maximum = std::numeric_limits<scalar_t::underlying>::max();
scalar_t::underlying* X =
reinterpret_cast<scalar_t::underlying*>(input.data_ptr());
scalar_t::underlying* Y = reinterpret_cast<scalar_t::underlying*>(output.data_ptr());
constexpr int kVLen = Vectorized<float>::size();
const int64_t outer_size = N * HxW;
using Vec = Vectorized<scalar_t>;
// Hoisted variables
auto in_zp_vec = Vectorized<float>(static_cast<float>(in_zero_point));
auto fake_scale = Vectorized<float>(1.0f);
auto scale_neg_zp_premul = fake_scale * in_zp_vec.neg();
auto out_zero_point_v = Vec(scalar_t(out_zero_point));
const auto lanes = static_cast<int64_t>(Vec::float_num_vecs() * kVLen);
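    // Alpha and beta are precomputed per channel by the caller; each output
    // element then reduces to
    //   y_q = clamp(out_zero_point + round(alpha[c] * (x_q - in_zero_point) + beta[c])),
    // optionally followed by ReLU (max with out_zero_point). Both the
    // vectorized and the scalar paths below compute exactly this.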
for (const auto i : c10::irange(outer_size)) {
auto* X_ptr = reinterpret_cast<typename scalar_t::underlying*>(X + i * C);
auto* Y_ptr = reinterpret_cast<typename scalar_t::underlying*>(Y + i * C);
int64_t ch = 0;
      for (; ch + lanes <= C; ch += lanes) {
do_bn_compute<scalar_t>(
X_ptr + ch,
Y_ptr + ch,
fake_scale,
in_zp_vec,
scale_neg_zp_premul,
out_zero_point,
out_zero_point_v,
alpha + ch,
beta + ch,
Vec::float_num_vecs(),
ReluFused,
kVLen
);
}
      // For remaining channel counts between 8 and 32, still use the 32-wide
      // path; benchmarks show it is faster than doing 8 channels at a time.
int64_t elem_size = C - ch;
if ((lanes == 32) && elem_size >= kVLen) {
int64_t vec_num = elem_size / kVLen;
std::vector<typename scalar_t::underlying> buf_in(lanes);
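        // buf_in is zero-initialized to a full `lanes`-wide buffer, so the
        // unconditional full-width load inside do_bn_compute stays in bounds;
        // only vec_num * kVLen results are written back to Y_ptr.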
memcpy(buf_in.data(), X_ptr + ch, vec_num * kVLen); // 3 cycles
do_bn_compute<scalar_t>(
buf_in.data(),
Y_ptr + ch,
fake_scale,
in_zp_vec,
scale_neg_zp_premul,
out_zero_point,
out_zero_point_v,
alpha + ch,
beta + ch,
vec_num,
ReluFused,
kVLen
);
ch += vec_num * kVLen;
}
      // Handle the remaining channels (fewer than kVLen) one at a time.
for (; ch < C; ++ch) {
long quantized_down = out_zero_point +
lrintf(alpha[ch] * (X_ptr[ch] - in_zero_point) +
beta[ch]);
if (ReluFused) { // static if
quantized_down = std::max<long>(quantized_down, out_zero_point);
}
Y_ptr[ch] = std::min<long>(
std::max<long>(quantized_down, minimum), maximum);
}
}
});
}
void _fake_quantize_tensor_helper(
Tensor& output,
Tensor& mask,
const Tensor& input,
int fake_quant_on,
float sc,
int64_t z_point,
int64_t quant_min,
int64_t quant_max) {
float inv_scale = 1.0f / sc;
auto iter_combined = TensorIteratorConfig()
.check_all_same_dtype(false)
.add_output(output)
.add_output(mask)
.add_input(input)
.build();
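  // Forward fake-quant: x -> (clamp(round(x / sc) + z_point, quant_min, quant_max) - z_point) * sc.
  // The mask records whether the quantized value fell inside
  // [quant_min, quant_max]; it is cached for the backward pass.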
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "fake_quantize_tensor_cachemask_kernel_type_handling", [&] {
iter_combined.for_each([&](char** data, const int64_t* strides, int64_t n) {
for (const auto i : c10::irange(n)) {
scalar_t* output_val = (scalar_t*)(data[0] + i * strides[0]);
bool* mask_val = (bool*)(data[1] + i * strides[1]);
scalar_t* input_val = (scalar_t*)(data[2] + i * strides[2]);
const auto qval = static_cast<int64_t>(z_point + std::nearbyint(*input_val * inv_scale));
if (fake_quant_on) {
*output_val = (std::fmin(std::fmax(qval, quant_min), quant_max) - z_point) * sc;
*mask_val = ((quant_min <= qval) && (qval <= quant_max));
} else {
*output_val = *input_val;
*mask_val = 1;
}
}
});
});
}
void fake_quantize_tensor_cachemask_kernel(
Tensor& output,
Tensor& mask,
const Tensor& input,
float sc,
int64_t z_point,
int64_t quant_min,
int64_t quant_max) {
_fake_quantize_tensor_helper(output, mask, input, 1, sc, z_point, quant_min, quant_max);
}
void fake_quantize_tensor_cachemask_tensor_qparams_kernel(
Tensor& output,
Tensor& mask,
const Tensor& input,
const Tensor& sc,
const Tensor& z_point,
const Tensor& fake_quant_enabled,
int64_t quant_min,
int64_t quant_max) {
_fake_quantize_tensor_helper(output, mask, input, fake_quant_enabled.item().toInt(), sc.item().toFloat(), z_point.item().toInt(), quant_min, quant_max);
}
void fake_quantize_learnable_tensor_grad_kernel_cpu(
TensorIterator& iter,
float scale,
float inv_scale,
int64_t zero_point,
int64_t quant_min,
int64_t quant_max,
float grad_factor) {
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
float dscale_small = quant_min - zero_point;
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
float dscale_big = quant_max - zero_point;
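  // Straight-through-style gradients: dX passes dY through only when the
  // quantized value is in range; in range, dScale is dY * (x_fq - x) / scale
  // and dZeroPoint is 0; when clamped, dScale is dY * (quant_min - zero_point)
  // or dY * (quant_max - zero_point) and dZeroPoint is -dY * scale. All are
  // scaled by grad_factor.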
iter.for_each([&](char** data, const int64_t* strides, int64_t n) {
/* When a for_each call is made on a TensorIterator with multiple inputs and outputs,
the order they are accessed follows the order they are built within the iterator.
For example, if an iterator is built in the following order:
auto iter = TensorIteratorConfig().
.add_output(firstOutput)
.add_output(secondOutput)
.add_input(firstInput)
.add_input(secondInput)
.build()
    data will contain 4 pointers, one per tensor, in the following order:
    firstOutput, secondOutput, firstInput, secondInput.
Proper pointer referencing and dereferencing, along with the usage of strides
(to move onto different elements), can allow accessing of the input and assignment
to the right output.
*/
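    // For this kernel the order is: data[0..2] are the dX, dScale and
    // dZeroPoint outputs, data[3..4] are the X and dY inputs.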
for (const auto i : c10::irange(n)) {
float* dXOutput = (float*)(data[0] + i * strides[0]);
float* dScaleOutput = (float*)(data[1] + i * strides[1]);
float* dZeroPointOutput = (float*)(data[2] + i * strides[2]);
float* XInput = (float*)(data[3] + i * strides[3]);
float* dYInput = (float*)(data[4] + i * strides[4]);
// Calculate gradients for X.
int64_t xqi = std::nearbyint(zero_point + (*XInput) * inv_scale);
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
*dXOutput = (*dYInput) * (xqi >= quant_min && xqi <= quant_max);
// Calculate gradients for scale and zero point.
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
float xfqi = static_cast<float>((std::max(std::min(xqi, quant_max), quant_min) - zero_point) * scale);
// Calculate gradients according to the gradient of the clamp function.
if (xqi < quant_min || xqi > quant_max) {
*dZeroPointOutput = (*dYInput) * (-1) * scale * grad_factor;
*dScaleOutput = ((xqi < quant_min) ? ((*dYInput) * dscale_small) : ((*dYInput) * dscale_big)) * grad_factor;
} else {
*dZeroPointOutput = 0;
*dScaleOutput = (*dYInput) * (xfqi - (*XInput)) * inv_scale * grad_factor;
}
}
});
}
template <typename SelfType>
void _fake_quant_per_channel_cachemask_cpu_helper(
TensorIterator& iter,
TensorIterator& iter_mask,
const int64_t quant_min,
const int64_t quant_max) {
const auto& zero_point_dtype = iter.input_dtype(2);
  if (at::isFloatingType(zero_point_dtype)) {
// When zero_point is float, quantize mirroring affine quantizer equation
// Xq = Round(Xf * inv_scale + zero_point)
// where zero_point is in float.
AT_DISPATCH_FLOATING_TYPES_AND_HALF(zero_point_dtype, "fake_quantize_channel_cachemask_cpu_zero_point_handling", [&] {
// write mask
cpu_kernel(iter_mask, [=](SelfType self, float scale, scalar_t zero_point) -> bool {
float inv_scale = 1.0f / scale;
const auto qval = std::lrintf(zero_point + (self * inv_scale));
return ((quant_min <= qval) && (qval <= quant_max));
});
// write fake_quant
cpu_kernel(iter, [=](SelfType self, float scale, scalar_t zero_point) -> SelfType {
float inv_scale = 1.0f / scale;
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
return (std::fmin(
std::fmax(
std::lrintf(zero_point + self * inv_scale),
quant_min),
quant_max) -
zero_point) *
scale;
});
});
} else {
// write mask
cpu_kernel(iter_mask, [=](SelfType self, float scale, int32_t zero_point) -> bool {
float inv_scale = 1.0f / scale;
const auto qval = static_cast<int64_t>(zero_point + std::nearbyint(self * inv_scale));
return ((quant_min <= qval) && (qval <= quant_max));
});
// write fake_quant
cpu_kernel(iter, [=](SelfType self, float scale, int32_t zero_point) -> SelfType {
float inv_scale = 1.0f / scale;
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
return (std::fmin(
std::fmax(
static_cast<int64_t>(
zero_point + std::nearbyint(self * inv_scale)),
quant_min),
quant_max) -
zero_point) *
scale;
});
}
}
void fake_quant_per_channel_cachemask_cpu(
TensorIterator& iter,
TensorIterator& iter_mask,
int64_t quant_min,
int64_t quant_max) {
// TODO(future, optional): read once, write twice. Not done at the moment
// for simplicity, as we do not expect this to be a bottleneck.
AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "fake_quantize_channel_cachemask_cpu_type_handling", [&] {
_fake_quant_per_channel_cachemask_cpu_helper<scalar_t>(iter, iter_mask, quant_min, quant_max);
});
}
void fake_quantize_learnable_channel_grad_kernel_cpu(
TensorIterator& iter,
int64_t quant_min,
int64_t quant_max,
float grad_factor) {
iter.for_each([&](char** data, const int64_t* strides, int64_t n) {
    /* To see how the inputs and outputs are referenced and assigned,
       please see the implementation of
fake_quantize_learnable_tensor_grad_kernel_cpu.
*/
for (const auto i : c10::irange(n)) {
float* dx_output = (float*)(data[0] + i * strides[0]);
float* dscale_output = (float*)(data[1] + i * strides[1]);
float* dzero_point_output = (float*)(data[2] + i * strides[2]);
float* x_input = (float*)(data[3] + i * strides[3]);
float* dy_input = (float*)(data[4] + i * strides[4]);
float* scale_input = (float*)(data[5] + i * strides[5]);
float* zero_point_input = (float*)(data[6] + i * strides[6]);
float inv_scale = 1.0f / (*scale_input);
float dscale_small = quant_min - (*zero_point_input);
float dscale_big = quant_max - (*zero_point_input);
// Calculate gradients for X.
int64_t xqi = std::nearbyint((*zero_point_input) + (*x_input) * inv_scale);
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
*dx_output = (*dy_input) * (xqi >= quant_min && xqi <= quant_max);
// Calculate gradients for scale and zero point.
float xfqi = static_cast<float>((std::max(std::min(xqi, quant_max), quant_min) - (*zero_point_input)) * (*scale_input));
if (xqi < quant_min || xqi > quant_max) {
*dzero_point_output = (*dy_input) * (-1) * (*scale_input) * grad_factor;
*dscale_output = ((xqi < quant_min) ? ((*dy_input) * dscale_small) : ((*dy_input) * dscale_big)) * grad_factor;
} else {
*dzero_point_output = 0;
*dscale_output = (*dy_input) * (xfqi - (*x_input)) * inv_scale * grad_factor;
}
}
});
}
// Assumes X is composed of M groups of N elements. Normalizes each of the
// groups and optionally applies affine scaling. Useful for LayerNorm,
// GroupNorm, InstanceNorm.
void quantized_normalize_kernel(
const Tensor& X, // input tensor
const Tensor& gamma, // weight (optional)
const Tensor& beta, // bias (optional)
bool affine_per_channel, // scaling applied elementwise if false, per channel if true
int num_channels, // only used if affine_per_channel is set
int num_groups, // only used if affine_per_channel is set
int64_t M, // number of groups
int64_t N, // number of elements in each group
double eps,
Tensor* Y) {
AT_DISPATCH_QINT_TYPES(X.scalar_type(), "quantized_layer_norm_kernel_impl_cpu", [&]() {
using qVec = vec::Vectorized<scalar_t>;
using fVec = vec::Vectorized<float>;
TORCH_INTERNAL_ASSERT(X.numel() == M * N, "Unexpected num elements in X");
TORCH_INTERNAL_ASSERT(
!gamma.defined() ||
(!affine_per_channel && gamma.numel() == N) ||
(affine_per_channel && gamma.numel() == num_channels),
"Unexpected size of gamma");
TORCH_INTERNAL_ASSERT(
!beta.defined() ||
(!affine_per_channel && beta.numel() == N) ||
(affine_per_channel && beta.numel() == num_channels),
"Unexpected size of beta");
scalar_t* X_data = X.data_ptr<scalar_t>();
const float* gamma_data = gamma.defined() ? gamma.data_ptr<float>() : nullptr;
const float* beta_data = beta.defined() ? beta.data_ptr<float>() : nullptr;
scalar_t* Y_data = Y->data_ptr<scalar_t>();
const bool gamma_null = gamma_data == nullptr;
const bool beta_null = beta_data == nullptr;
int64_t x_zp = X.q_zero_point();
float x_scale = X.q_scale();
fVec x_zp_vec((float)x_zp);
fVec one_vec(1.0f);
fVec zero_vec(0.0f);
float x_fake_scale = 1.0f;
fVec x_fake_scale_vec(x_fake_scale);
fVec x_fake_scale_zp_neg_premul_vec = x_fake_scale_vec * x_zp_vec.neg();
int64_t y_zp = Y->q_zero_point();
float y_scale = Y->q_scale();
float y_inv_scale = 1.0f / y_scale;
constexpr int kFloatVLen = fVec::size();
int64_t kIntVLen = kFloatVLen * qVec::float_num_vecs();
int64_t kNumIntVecInLayer = N / kIntVLen;
int64_t kNonVecRemInLayer = N % kIntVLen;
int channels_per_group = num_channels / num_groups;
int64_t NPerChannel = N / channels_per_group;
int64_t kNumIntVecInChannel = NPerChannel / kIntVLen;
int64_t kNonVecRemInChannel = NPerChannel % kIntVLen;
at::parallel_for(0, M, 1, [&](int64_t start, int64_t end) {
for (const auto i : c10::irange(start, end)) {
scalar_t* X_ptr = X_data + i * N;
scalar_t* Y_ptr = Y_data + i * N;
// First pass: calculate mean and variance.
scalar_t::underlying* X_ptr_underlying = reinterpret_cast<scalar_t::underlying*>(X_ptr);
auto l_sum_shifted = hsum(X_ptr_underlying, N);
auto l_sum_sq_shifted = hsum_sq(X_ptr_underlying, N);
float l_mean_shifted_div_scale_x = static_cast<float>(l_sum_shifted) / N;
// mean(dqX) / scale_x
float layer_mean_div_scale_x = l_mean_shifted_div_scale_x - x_zp;
// var(dqX) / scale_x^2
float layer_var_div_scale_x_sq =
std::max(static_cast<float>(l_sum_sq_shifted) / N -
l_mean_shifted_div_scale_x * l_mean_shifted_div_scale_x, 0.0f);
// scale_x / sqrt(var(dqX) + eps)
float scale_x_div_layer_std = x_scale /
std::sqrt(layer_var_div_scale_x_sq * x_scale * x_scale + eps);
fVec layer_mean_div_scale_xVec(layer_mean_div_scale_x);
fVec scale_x_div_layer_stdVec(scale_x_div_layer_std);
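        // With dqX = x_scale * (x_q - x_zp):
        //   (dqX - mean(dqX)) / sqrt(var(dqX) + eps)
        //     = (x_q - x_zp - layer_mean_div_scale_x) * scale_x_div_layer_std,
        // which is exactly what the fake-scale dequantized values below
        // compute before applying gamma and beta.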
// Second pass: normalize
// TODO replace with TensorIterator implementation once #33166 is fixed.
if (affine_per_channel) {
// if scaling per channel, scaling parameters can be pre-multiplied
// with normalization parameters
for (const auto chIdx : c10::irange(channels_per_group)) {
int scalingIdx = (i * channels_per_group + chIdx) % (num_channels);
float gamma = gamma_null ? 1.0f : gamma_data[scalingIdx];
// scale_x / layer_std * gamma
float gamma_p = scale_x_div_layer_std * gamma;
float beta = beta_null ? 0.0f : beta_data[scalingIdx];
fVec gamma_p_vec(gamma_p);
fVec beta_vec(beta);
int64_t chStartIdx = chIdx * NPerChannel;
int64_t chEndIdx = chStartIdx + NPerChannel;
for (const auto vecIdx : c10::irange(kNumIntVecInChannel)) {
int64_t vecStartIdx = chStartIdx + vecIdx * kIntVLen;
auto qXVec = qVec::loadu(X_ptr + vecStartIdx);
auto dqXVec = qXVec.dequantize(x_fake_scale_vec, x_zp_vec,
x_fake_scale_zp_neg_premul_vec);
for (auto &dq : dqXVec) {
dq =
(dq - layer_mean_div_scale_xVec) *
gamma_p_vec + beta_vec;
qVec::quantize(dqXVec, y_scale, y_zp, y_inv_scale)
.store(Y_ptr + vecStartIdx);
}
}
for (int64_t remIdx = chEndIdx - kNonVecRemInChannel;
remIdx < chEndIdx;
remIdx++) {
auto qXVal = X_ptr[remIdx];
float dqXVal = at::native::dequantize_val(x_fake_scale, x_zp, qXVal);
float dqY =
(dqXVal - layer_mean_div_scale_x) * gamma_p + beta;
Y_ptr[remIdx] = at::native::quantize_val<scalar_t>(y_scale, y_zp, dqY);
}
} // chIdx
} else {
for (const auto vecIdx : c10::irange(kNumIntVecInLayer)) {
int64_t vecStartIdx = vecIdx * kIntVLen;
auto qXVec = qVec::loadu(X_ptr + vecStartIdx);
auto dqXVec = qXVec.dequantize(x_fake_scale_vec, x_zp_vec,
x_fake_scale_zp_neg_premul_vec);
for (const auto dqXVecIdx : c10::irange(dqXVec.size())) {
int64_t vecVecStartIdx = vecStartIdx + dqXVecIdx * kFloatVLen;
auto gammaVec = gamma_null
? one_vec
: fVec::loadu(gamma_data + vecVecStartIdx);
auto betaVec = beta_null
? zero_vec
: fVec::loadu(beta_data + vecVecStartIdx);
dqXVec[dqXVecIdx] =
(dqXVec[dqXVecIdx] - layer_mean_div_scale_xVec) *
scale_x_div_layer_stdVec * gammaVec + betaVec;
qVec::quantize(dqXVec, y_scale, y_zp, y_inv_scale)
.store(Y_ptr + vecStartIdx);
}
}
for (int64_t remIdx = N - kNonVecRemInLayer; remIdx < N; remIdx++) {
const float gamma_v = gamma_null ? 1.0f : gamma_data[remIdx];
const float beta_v = beta_null ? 0.0f : beta_data[remIdx];
auto qXVal = X_ptr[remIdx];
float dqXVal = at::native::dequantize_val(x_fake_scale, x_zp, qXVal);
float dqY =
((dqXVal - layer_mean_div_scale_x) * scale_x_div_layer_std) * gamma_v + beta_v;
Y_ptr[remIdx] = at::native::quantize_val<scalar_t>(y_scale, y_zp, dqY);
}
}
}
}); // parallel_for
});
}
#ifdef USE_FBGEMM
void quantize_tensor_per_tensor_affine_cpu(
const Tensor& rtensor,
Tensor& qtensor,
double scale,
int64_t zero_point) {
AT_DISPATCH_QINT_TYPES(
qtensor.scalar_type(), "quantize_tensor_per_tensor_affine_cpu", [&]() {
check_tensor_memory_format(rtensor, qtensor);
const float* rd = rtensor.data_ptr<float>();
auto qd = reinterpret_cast<underlying_t*>(qtensor.data_ptr<scalar_t>());
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
fbgemm::TensorQuantizationParams qparams;
qparams.scale = scale;
qparams.zero_point = zero_point;
qparams.precision = CHAR_BIT * sizeof(underlying_t);
int num_tasks = at::get_num_threads();
at::parallel_for(0, num_tasks, 1, [&](int64_t begin, int64_t end) {
for (const auto task_id : c10::irange(begin, end)) {
fbgemm::Quantize<underlying_t, false /*LEGACY*/>(
// NOLINTNEXTLINE(bugprone-argument-comment)
rd, /*src=*/
// NOLINTNEXTLINE(bugprone-argument-comment)
qd, /*dst=*/
rtensor.numel(), /*len*/
// NOLINTNEXTLINE(bugprone-argument-comment)
qparams, /*qparams=*/
task_id, /*thread_id*/
num_tasks /*num_threads*/);
}
});
});
}
void dequantize_tensor_per_tensor_affine_cpu(
const Tensor& qtensor,
Tensor& rtensor,
double scale,
int64_t zero_point) {
AT_DISPATCH_QINT_TYPES(
qtensor.scalar_type(), "dequantize_tensor_per_tensor_affine_cpu", [&]() {
check_tensor_memory_format(qtensor, rtensor);
const auto* qd =
reinterpret_cast<const underlying_t*>(qtensor.data_ptr<scalar_t>());
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
fbgemm::TensorQuantizationParams qparams;
qparams.scale = scale;
qparams.zero_point = zero_point;
qparams.precision = CHAR_BIT * sizeof(underlying_t);
float* rd = rtensor.data_ptr<float>();
int num_tasks = at::get_num_threads();
at::parallel_for(0, num_tasks, 1, [&](int64_t begin, int64_t end) {
for (const auto task_id : c10::irange(begin, end)) {
fbgemm::Dequantize<underlying_t>(
// NOLINTNEXTLINE(bugprone-argument-comment)
qd, /*src=*/
// NOLINTNEXTLINE(bugprone-argument-comment)
rd, /*dst=*/
// NOLINTNEXTLINE(bugprone-argument-comment)
qtensor.numel(), /*len=*/
// NOLINTNEXTLINE(bugprone-argument-comment)
qparams, /*qparams=*/
task_id, /*thread_id*/
num_tasks /*num_threads*/);
}
});
});
}
#else // USE_FBGEMM
#if defined(__ARM_NEON__) || defined(__aarch64__)
const static int PARALLEL_THRESHOLD = 1 << 20;
// Generic template defaults to naive quantize implementation
template <typename T>
void quantize_tensor_arm(
const float* __restrict__ in,
T* __restrict__ out,
const int64_t N,
const float scale,
const int32_t zero_point) {
for (const auto i : c10::irange(N)) {
out[i] = at::native::quantize_val<T>(scale, zero_point, in[i]);
}
}
// Specialized implementation from caffe2::Int8Quantize.
// There may be a slight accuracy difference between this and the
// implementation of quantize_val
// TODO Update quantize_tensor_arm implementation to follow quantize_val,
// i.e. f = Round(value/scale + zero_point)
// TODO Make quantize_tensor_arm work for other datatypes too (int8, int32).
template <>
void quantize_tensor_arm<c10::quint8>(
const float* __restrict__ in,
c10::quint8* __restrict__ out,
const int64_t N,
const float scale,
const int32_t zero_point) {
const float inv_scale = 1.0f / scale;
uint32_t i = 0;
uint8_t* out_underlying = reinterpret_cast<uint8_t*>(out);
const float32x4_t vinv_scale = vdupq_n_f32(inv_scale);
#if defined(__ARM_NEON__)
// magic float and magic int to take care of rounding
// int magic_round(float f): interpret_int32(f + 12582912.0f) - 0x4B400000
// Some detail:
// 12582912.0f is 2**23 + 2**22. The trick is based on the fact that when you
// add a small number to a large number, the result rounds to the precision of
  // the least significant bit of the large number. An IEEE-754
  // single-precision mantissa has 23 bits, so adding 2**23 rounds the small
  // number to the nearest (even, on ties) integer. Then we cast to int and
  // subtract the same number (0x4B400000 is the integer representation of
  // 12582912.0f) to recover the rounded value. This works if
  // -2**22 < x < 2**22, and it preserves the sign of negative numbers.
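  // Worked example (for illustration): with zero_point = 10 and
  // x * inv_scale = 3.4f, 3.4f + 12582912.0f rounds to 12582915.0f, whose bit
  // pattern is 0x4B400003; adding voffset = 10 - 0x4B400000 yields 13.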
const int32x4_t voffset = vdupq_n_s32(zero_point - 0x4B400000);
const float32x4_t vmagic_float = vdupq_n_f32(12582912.0f);
for (i = 0; i + 8 <= N; i += 8) {
const float32x4_t vin0123 = vld1q_f32(in);
in += 4;
const float32x4_t vin4567 = vld1q_f32(in);
in += 4;
const int32x4_t vraw0123 = vaddq_s32(
voffset,
vreinterpretq_s32_f32(
vaddq_f32(vmagic_float, vmulq_f32(vin0123, vinv_scale))));
const int32x4_t vraw4567 = vaddq_s32(
voffset,
vreinterpretq_s32_f32(
vaddq_f32(vmagic_float, vmulq_f32(vin4567, vinv_scale))));
const int16x8_t vraw01234567 =
vcombine_s16(vqmovn_s32(vraw0123), vqmovn_s32(vraw4567));
const uint8x8_t vout01234567 = vqmovun_s16(vraw01234567);
vst1_u8(out_underlying, vout01234567);
out_underlying += 8;
}
for (; i < N; ++i) {
(*out_underlying++) = at::native::quantize_val_arm(scale, zero_point, (*in++));
}
#else
const int16x8_t vzero_point = vdupq_n_s16((int16_t)(uint16_t)zero_point);
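  // aarch64 path: vcvtnq_s32_f32 rounds to the nearest int32 (ties to even),
  // the two halves are narrowed to int16 with saturation, the zero point is
  // added with saturating int16 arithmetic, and the result is narrowed to
  // uint8 with saturation.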
for (i = 0; i + 8 <= N; i += 8) {
const float32x4_t vin0123 = vld1q_f32(in);
in += 4;
const float32x4_t vin4567 = vld1q_f32(in);
in += 4;
const int32x4_t v0123_rounded = vcvtnq_s32_f32(vmulq_f32(vin0123, vinv_scale));
const int32x4_t v4567_rounded = vcvtnq_s32_f32(vmulq_f32(vin4567, vinv_scale));
const int16x8_t v01234567_packed = vqaddq_s16(
vqmovn_high_s32(vqmovn_s32(v0123_rounded), v4567_rounded), vzero_point);
const uint8x8_t vout01234567 = vqmovun_s16(v01234567_packed);
vst1_u8(out_underlying, vout01234567);
out_underlying += 8;
}
for (; i < N; ++i) {
(*out_underlying++) = at::native::quantize_val_arm(scale, zero_point, (*in++));
}
#endif
}
#if defined(__aarch64__)
#define VMOVL_HIGH_U8(x) vmovl_high_u8(x)
#define VMOVL_HIGH_S8(x) vmovl_high_s8(x)
#define VMOVL_HIGH_U16(x) vmovl_high_u16(x)
#define VMOVL_HIGH_S16(x) vmovl_high_s16(x)
#else // vmovl_high intrinsic not supported
#define VMOVL_HIGH_U8(x) vmovl_u8(vget_high_u8(x))
#define VMOVL_HIGH_S8(x) vmovl_s8(vget_high_s8(x))
#define VMOVL_HIGH_U16(x) vmovl_u16(vget_high_u16(x))
#define VMOVL_HIGH_S16(x) vmovl_s16(vget_high_s16(x))
#endif
// Generic template defaults to naive dequantize implementation
template <typename T>
void dequantize_tensor_arm(
const T* __restrict__ in,
float* __restrict__ out,
const int64_t N,
const float scale,
const int32_t zero_point) {
for (int i = 0; i < N; ++i) {
out[i] = dequantize_val<T>(scale, zero_point, in[i]);
}
}
template <>
void dequantize_tensor_arm<c10::qint8>(
const c10::qint8* __restrict__ in,
float* __restrict__ out,
const int64_t N,
const float scale,
const int32_t zero_point) {
const int8_t* in_underlying = reinterpret_cast<const int8_t*>(in);
const float32x4_t scale_fp32x4 = vdupq_n_f32(scale);
// Zero point is restricted to be in bounds of a signed 8 bit integer
const int8x8_t zero_point_s8x8 = vget_low_s8(vdupq_n_s8(static_cast<int8_t>(zero_point)));
int i;
for (i = 0; i + 16 <= N; i += 16) {
const int8x16_t vin_s8 = vld1q_s8(in_underlying);
// Extract upper or lower values to int16x8 and subtract zero point
// Each input element and the zero point are restricted to be in bounds of
// a signed 8 bit integer, so the difference will fit in a signed 16 bit
// integer
const int16x8_t minus_zp_low_s16 = vsubl_s8(vget_low_s8(vin_s8), zero_point_s8x8); // 0 ... 7
const int16x8_t minus_zp_high_s16 = vsubl_s8(vget_high_s8(vin_s8), zero_point_s8x8); // 8 ... 15
const int32x4_t minus_zp_low_low = vmovl_s16(vget_low_s16(minus_zp_low_s16)); // 0 ... 3
const int32x4_t minus_zp_low_high = VMOVL_HIGH_S16(minus_zp_low_s16); // 4 ... 7
const int32x4_t minus_zp_high_low = vmovl_s16(vget_low_s16(minus_zp_high_s16)); // 8 ... 11
const int32x4_t minus_zp_high_high = VMOVL_HIGH_S16(minus_zp_high_s16); // 12 ... 15
// Store * scale int32->fp32
vst1q_f32(out, vmulq_f32(vcvtq_f32_s32(minus_zp_low_low), scale_fp32x4));
vst1q_f32(out + 4, vmulq_f32(vcvtq_f32_s32(minus_zp_low_high), scale_fp32x4));
vst1q_f32(out + 8, vmulq_f32(vcvtq_f32_s32(minus_zp_high_low), scale_fp32x4));
vst1q_f32(out + 12, vmulq_f32(vcvtq_f32_s32(minus_zp_high_high), scale_fp32x4));
out += 16;
in += 16;
in_underlying += 16;
}
for (; i < N; ++i) { // use default dequantize for remaining vals
(*out++) = dequantize_val<c10::qint8>(scale, zero_point, (*in++));
}
}
template <>
void dequantize_tensor_arm<c10::quint8>(
const c10::quint8* __restrict__ in,
float* __restrict__ out,
const int64_t N,
const float scale,
const int32_t zero_point) {
const uint8_t* in_underlying = reinterpret_cast<const uint8_t*>(in);
const float32x4_t scale_fp32x4 = vdupq_n_f32(scale);
// Zero point is restricted to be in bounds of an unsigned 8 bit integer
const uint8x8_t zero_point_u8x8 = vget_low_u8(vdupq_n_u8(static_cast<uint8_t>(zero_point)));
int i;
for (i = 0; i + 16 <= N; i += 16) {
const uint8x16_t vin_u8 = vld1q_u8(in_underlying);
// Extract upper or lower values to uint16x8 and subtract zero point
// Each input element and the zero point are restricted to be in bounds of
// an unsigned 8 bit integer, so the difference will fit in a signed 16 bit
// integer
const int16x8_t minus_zp_low_s16 = vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(vin_u8), zero_point_u8x8)); // 0 ... 7
const int16x8_t minus_zp_high_s16 = vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(vin_u8), zero_point_u8x8)); // 8 ... 15
const int32x4_t minus_zp_low_low = vmovl_s16(vget_low_s16(minus_zp_low_s16)); // 0 ... 3
const int32x4_t minus_zp_low_high = VMOVL_HIGH_S16(minus_zp_low_s16); // 4 ... 7
const int32x4_t minus_zp_high_low = vmovl_s16(vget_low_s16(minus_zp_high_s16)); // 8 ... 11
const int32x4_t minus_zp_high_high = VMOVL_HIGH_S16(minus_zp_high_s16); // 12 ... 15
// Store * scale int32->fp32
vst1q_f32(out, vmulq_f32(vcvtq_f32_s32(minus_zp_low_low), scale_fp32x4));
vst1q_f32(out + 4, vmulq_f32(vcvtq_f32_s32(minus_zp_low_high), scale_fp32x4));
vst1q_f32(out + 8, vmulq_f32(vcvtq_f32_s32(minus_zp_high_low), scale_fp32x4));
vst1q_f32(out + 12, vmulq_f32(vcvtq_f32_s32(minus_zp_high_high), scale_fp32x4));
out += 16;
in += 16;
in_underlying += 16;
}
for (; i < N; ++i) { // use default dequantize for remaining vals
(*out++) = dequantize_val<c10::quint8>(scale, zero_point, (*in++));
}
}
#endif // defined(__ARM_NEON__) || defined(__aarch64__)
void quantize_tensor_per_tensor_affine_cpu(
const Tensor& rtensor,
Tensor& qtensor,
double scale,
int64_t zero_point) {
check_tensor_memory_format(rtensor, qtensor);
const float* rdata = rtensor.data_ptr<float>();
int numel = rtensor.numel();
#if defined(__ARM_NEON__) || defined(__aarch64__)
AT_DISPATCH_QINT_TYPES(
qtensor.scalar_type(), "quantize_tensor_per_tensor_affine_cpu", [&]() {
scalar_t* qdata = qtensor.data_ptr<scalar_t>();
auto quantize_range = [&](int64_t begin, int64_t end) {
quantize_tensor_arm<scalar_t>(
rdata + begin, qdata + begin, end - begin, scale, zero_point);
};
if (numel >= PARALLEL_THRESHOLD) {
at::parallel_for(0, numel, 1, quantize_range);
} else {
quantize_range(0, numel);
}
});
#else
// Fallback path
AT_DISPATCH_QINT_TYPES(
qtensor.scalar_type(), "quantize_tensor_per_tensor_affine_cpu", [&]() {
scalar_t* qdata = qtensor.data_ptr<scalar_t>();
for (const auto i : c10::irange(numel)) {
qdata[i] = quantize_val<scalar_t>(scale, zero_point, rdata[i]);
}
});
#endif // defined(__ARM_NEON__) || defined(__aarch64__)
}
void dequantize_tensor_per_tensor_affine_cpu(
const Tensor& qtensor,
Tensor& rtensor,
double scale,
int64_t zero_point) {
check_tensor_memory_format(qtensor, rtensor);
float* rdata = rtensor.data_ptr<float>();
int numel = qtensor.numel();
#if defined(__ARM_NEON__) || defined(__aarch64__)
AT_DISPATCH_QINT_TYPES(
qtensor.scalar_type(), "dequantize_tensor_per_tensor_affine_cpu", [&]() {
const scalar_t* qdata = qtensor.data_ptr<scalar_t>();
auto dequantize_range = [&](int64_t begin, int64_t end) {
dequantize_tensor_arm<scalar_t>(
qdata + begin, rdata + begin, end - begin, scale, zero_point);
};
if (numel >= PARALLEL_THRESHOLD) {
at::parallel_for(0, numel, 1, dequantize_range);
} else {
dequantize_range(0, numel);
}
});
#else
// Fallback path
AT_DISPATCH_QINT_TYPES(
qtensor.scalar_type(), "dequantize_tensor_per_tensor_affine_cpu", [&]() {
const scalar_t* qdata = qtensor.data_ptr<scalar_t>();
for (const auto i : c10::irange(numel)) {
rdata[i] = dequantize_val<scalar_t>(scale, zero_point, qdata[i]);
}
});
#endif // defined(__ARM_NEON__) || defined(__aarch64__)
}
#endif // USE_FBGEMM
// TODO: add fbgemm for per channel
// Generic template defaults to naive quantize implementation
template <typename T>
void quantize_tensor_per_channel_impl(
const Tensor& rtensor,
Tensor& qtensor,
const Tensor& scales,
const Tensor& zero_points,
int64_t axis) {
// TODO: channels last kernel can be made faster.
// For contiguous tensors, e.g. NCHW, arbitrary axis can be used.
// For channels_last/3d however axis == 0 or 1.
  // Since the current implementation on the channels_last format does not
  // cover per-channel quant with an arbitrary axis value, it is better
  // to check and fail.
int64_t batches = size_to_dim_(axis, rtensor.sizes());
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
int64_t elements_per_channel = size_from_dim_(axis + 1, rtensor.sizes());
int64_t channels = rtensor.size(axis);
auto scales_data = scales.data_ptr<double>();
auto zero_points_data = zero_points.data_ptr<int64_t>();
const float* in = rtensor.data_ptr<float>();
auto out = qtensor.data_ptr<T>();
if (axis == 1 &&
(rtensor.is_contiguous(MemoryFormat::ChannelsLast) ||
rtensor.is_contiguous(MemoryFormat::ChannelsLast3d))) {
// This code handles per channel quant when axis = 1 and
// channels_last contig.
// If axis = 0 and channels_last contig, implementation for channels
// first (NCHW) works.
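    // The loop order matches the channels_last memory order (channels
    // innermost), so `i` advances through `in`/`out` contiguously.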
for (const auto b : c10::irange(batches)) {
for (const auto e : c10::irange(elements_per_channel)) {
for (const auto c : c10::irange(channels)) {
auto i = b * channels * elements_per_channel + e * channels + c;
out[i] = at::native::quantize_val<T>(
scales_data[c], zero_points_data[c], in[i]);
}
}
}
} else {
for (const auto b : c10::irange(batches)) {
for (const auto c : c10::irange(channels)) {
for (const auto e : c10::irange(elements_per_channel)) {
auto i = b * channels * elements_per_channel +
c * elements_per_channel + e;
out[i] = at::native::quantize_val<T>(
scales_data[c], zero_points_data[c], in[i]);
}
}
}
}
}
#if defined(__ARM_NEON__) || defined(__aarch64__)
// Specialized implementation from caffe2::Int8Quantize.
// There may be a slight accuracy difference between this and the
// implementation of quantize_val
// TODO Update quantize_tensor_per_channel_impl implementation to follow
// quantize_val, i.e. f = Round(value/scale + zero_point)
// TODO Make quantize_tensor_per_channel_impl work for other datatypes too
// (int8, int32).
template <>
void quantize_tensor_per_channel_impl<c10::quint8>(
const Tensor& rtensor,
Tensor& qtensor,
const Tensor& scales,
const Tensor& zero_points,
int64_t axis) {
int64_t batches = size_to_dim_(axis, rtensor.sizes());
int64_t elements_per_channel = size_from_dim_(axis + 1, rtensor.sizes());
int64_t channels = rtensor.size(axis);
auto scales_data = scales.data_ptr<double>();
auto zero_points_data = zero_points.data_ptr<int64_t>();
const float* in = rtensor.data_ptr<float>();
auto out = (uint8_t*)qtensor.data_ptr<c10::quint8>();
#if defined(__ARM_NEON__)
// magic float and magic int to take care of rounding
// int magic_round(float f): interpret_int32(f + 12582912.0f) - 0x4B400000
// Some detail:
// 12582912.0f is 2**23 + 2**22. The trick is based on the fact that when you
// add a small number to a large number, the result rounds to the precision of
  // the least significant bit of the large number. An IEEE-754
  // single-precision mantissa has 23 bits, so adding 2**23 rounds the small
  // number to the nearest (even, on ties) integer. Then we cast to int and
  // subtract the same number (0x4B400000 is the integer representation of
  // 12582912.0f) to recover the rounded value. This works if
  // -2**22 < x < 2**22, and it preserves the sign of negative numbers.
const float32x4_t vmagic_float = vdupq_n_f32(12582912.0f);
// Copy reciprocal of scales (double) into float array
// Copy zero_points with magic int (int64_t) into int32_t array
std::vector<float> inv_scales(channels);
std::vector<int32_t> zero_points_int32t(channels);
for (const auto i : c10::irange(channels)) {
inv_scales[i] = 1.0f / (float)scales_data[i];
zero_points_int32t[i] = (int32_t)(uint32_t)zero_points_data[i] - 0x4B400000;
}
if (axis == 1 &&
(rtensor.is_contiguous(MemoryFormat::ChannelsLast) ||
rtensor.is_contiguous(MemoryFormat::ChannelsLast3d))) {
// This code handles per channel quant when axis = 1 and
// channels_last contig.
// If axis = 0 and channels_last contig, implementation for channels
// first (NCHW) works.
for (const auto b : c10::irange(batches)) {
for (const auto e : c10::irange(elements_per_channel)) {
uint32_t c = 0;
while (c + 8 < channels) {
const int32x4_t voffset0123 = vld1q_s32(&zero_points_int32t[c]);
const float32x4_t vinv_scale0123 = vld1q_f32(&inv_scales[c]);
c += 4;
const int32x4_t voffset4567 = vld1q_s32(&zero_points_int32t[c]);
const float32x4_t vinv_scale4567 = vld1q_f32(&inv_scales[c]);
c += 4;
const float32x4_t vin0123 = vld1q_f32(in);
in += 4;
const float32x4_t vin4567 = vld1q_f32(in);
in += 4;
const int32x4_t vraw0123 = vaddq_s32(
voffset0123,
vreinterpretq_s32_f32(
vaddq_f32(vmagic_float, vmulq_f32(vin0123, vinv_scale0123))));
const int32x4_t vraw4567 = vaddq_s32(
voffset4567,
vreinterpretq_s32_f32(
vaddq_f32(vmagic_float, vmulq_f32(vin4567, vinv_scale4567))));
const int16x8_t vraw01234567 =
vcombine_s16(vqmovn_s32(vraw0123), vqmovn_s32(vraw4567));
const uint8x8_t vout01234567 = vqmovun_s16(vraw01234567);
vst1_u8(out, vout01234567);
out += 8;
}
for (; c < channels; ++c) {
(*out++) =
at::native::quantize_val_arm(scales_data[c], zero_points_data[c], (*in++));
}
}
}
} else {
for (const auto b : c10::irange(batches)) {
for (const auto c : c10::irange(channels)) {
uint32_t e = 0;
const int32x4_t voffset = vdupq_n_s32(zero_points_int32t[c]);
const float32x4_t vinv_scale = vdupq_n_f32(inv_scales[c]);
for (; e + 8 < elements_per_channel; e += 8) {
const float32x4_t vin0123 = vld1q_f32(in);
in += 4;
const float32x4_t vin4567 = vld1q_f32(in);
in += 4;
const int32x4_t vraw0123 = vaddq_s32(
voffset,
vreinterpretq_s32_f32(
vaddq_f32(vmagic_float, vmulq_f32(vin0123, vinv_scale))));
const int32x4_t vraw4567 = vaddq_s32(
voffset,
vreinterpretq_s32_f32(
vaddq_f32(vmagic_float, vmulq_f32(vin4567, vinv_scale))));
const int16x8_t vraw01234567 =
vcombine_s16(vqmovn_s32(vraw0123), vqmovn_s32(vraw4567));
const uint8x8_t vout01234567 = vqmovun_s16(vraw01234567);
vst1_u8(out, vout01234567);
out += 8;
}
for (; e < elements_per_channel; ++e) {
(*out++) =
at::native::quantize_val_arm(scales_data[c], zero_points_data[c], (*in++));
}
}
}
}
#else // defined(__ARM_NEON__)
  // Copy reciprocals of scales (double) into a float array
  // Copy zero_points (int64_t) into an int16_t array
std::vector<float> inv_scales(channels);
std::vector<int16_t> zero_points_int16t(channels);
for (const auto i : c10::irange(channels)) {
inv_scales[i] = 1.0f / (float)scales_data[i];
zero_points_int16t[i] = (int16_t)(uint16_t)zero_points_data[i];
}
if (axis == 1 &&
(rtensor.is_contiguous(MemoryFormat::ChannelsLast) ||
rtensor.is_contiguous(MemoryFormat::ChannelsLast3d))) {
// This code handles per channel quant when axis = 1 and
// channels_last contig.
// If axis = 0 and channels_last contig, implementation for channels
// first (NCHW) works.
for (const auto b : c10::irange(batches)) {
for (const auto e : c10::irange(elements_per_channel)) {
uint32_t c = 0;
while (c + 8 < channels) {
const int16x8_t vzero_point = vld1q_s16(&zero_points_int16t[c]);
const float32x4_t vinv_scale0123 = vld1q_f32(&inv_scales[c]);
c += 4;
const float32x4_t vinv_scale4567 = vld1q_f32(&inv_scales[c]);
c += 4;
const float32x4_t vin0123 = vld1q_f32(in);
in += 4;
const float32x4_t vin4567 = vld1q_f32(in);
in += 4;
const int32x4_t v0123_rounded =
vcvtnq_s32_f32(vmulq_f32(vin0123, vinv_scale0123));
const int32x4_t v4567_rounded =
vcvtnq_s32_f32(vmulq_f32(vin4567, vinv_scale4567));
const int16x8_t v01234567_packed = vqaddq_s16(
vqmovn_high_s32(vqmovn_s32(v0123_rounded), v4567_rounded),
vzero_point);
const uint8x8_t vout01234567 = vqmovun_s16(v01234567_packed);
vst1_u8(out, vout01234567);
out += 8;
}
for (; c < channels; ++c) {
(*out++) =
at::native::quantize_val_arm(scales_data[c], zero_points_data[c], (*in++));
}
}
}
} else {
for (const auto b : c10::irange(batches)) {
for (const auto c : c10::irange(channels)) {
uint32_t e = 0;
const int16x8_t vzero_point = vdupq_n_s16(zero_points_int16t[c]);
const float32x4_t vinv_scale = vdupq_n_f32(inv_scales[c]);
for (; e + 8 < elements_per_channel; e += 8) {
const float32x4_t vin0123 = vld1q_f32(in);
in += 4;
const float32x4_t vin4567 = vld1q_f32(in);
in += 4;
const int32x4_t v0123_rounded =
vcvtnq_s32_f32(vmulq_f32(vin0123, vinv_scale));
const int32x4_t v4567_rounded =
vcvtnq_s32_f32(vmulq_f32(vin4567, vinv_scale));
const int16x8_t v01234567_packed = vqaddq_s16(
vqmovn_high_s32(vqmovn_s32(v0123_rounded), v4567_rounded),
vzero_point);
const uint8x8_t vout01234567 = vqmovun_s16(v01234567_packed);
vst1_u8(out, vout01234567);
out += 8;
}
for (; e < elements_per_channel; ++e) {
(*out++) =
at::native::quantize_val_arm(scales_data[c], zero_points_data[c], (*in++));
}
}
}
}
#endif // defined(__ARM_NEON__)
}
#endif // defined(__ARM_NEON__) || defined(__aarch64__)
void quantize_tensor_per_channel_affine_cpu(
const Tensor& rtensor,
Tensor& qtensor,
const Tensor& scales,
const Tensor& zero_points,
int64_t axis) {
TORCH_CHECK(
rtensor.is_contiguous() || (axis <= 1),
"If tensor is channels_last contig then per channel quantization "
"is supported only for axis = 0 or 1.");
AT_DISPATCH_QINT_TYPES(
qtensor.scalar_type(), "quantize_tensor_per_channel_affine_cpu", [&]() {
check_tensor_memory_format(rtensor, qtensor);
quantize_tensor_per_channel_impl<scalar_t>(
rtensor, qtensor, scales, zero_points, axis);
});
}
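// Template parameters of the kernel below:
// T: type of the per-channel scales (double for affine qparams, float for
//    float qparams); N: type of the per-channel zero points (int64_t or
//    float); Q: the quantized scalar type being dispatched (e.g. c10::qint8).
// See the two instantiations further down in this file.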
template <typename T, typename N, typename Q>
void dequantize_per_channel_affine_kernel(
const Tensor& qtensor,
Tensor& rtensor,
const Tensor& scales,
const Tensor& zero_points,
int64_t axis,
int bit_width=8) {
// For contiguous tensors, e.g. NCHW, an arbitrary axis can be used.
// For channels_last/channels_last_3d tensors, however, only axis == 0 or 1
// is supported. Since the current implementation for the channels_last
// format does not cover per-channel quantization with an arbitrary axis,
// it is better to check and fail early.
TORCH_CHECK(rtensor.is_contiguous() || (axis <=1),
"If tensor is channels_last contig then per channel quantization "
"is supported only for axis = 0 or 1.");
int64_t batches = size_to_dim_(axis, rtensor.sizes());
int64_t elements_per_channel =
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
size_from_dim_(axis + 1, rtensor.sizes());
int64_t channel = rtensor.size(axis);
auto scales_data = scales.data_ptr<T>();
auto zero_points_data = zero_points.data_ptr<N>();
check_tensor_memory_format(qtensor, rtensor);
const auto* qd = qtensor.data_ptr<Q>();
float* rd = rtensor.data_ptr<float>();
const auto elem_per_byte = 8 / bit_width;
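// Sub-byte example: with bit_width == 4, elem_per_byte == 2, so logical
// index i lives in byte i / 2, and (i % 2) * 4 is the shift that isolates
// its nibble. For bit_width == 8 the shift/mask below is skipped and each
// byte holds exactly one value.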
if (axis == 1 && (rtensor.is_contiguous(MemoryFormat::ChannelsLast) ||
rtensor.is_contiguous(MemoryFormat::ChannelsLast3d))) {
for (const auto b : c10::irange(batches)) {
for (const auto e : c10::irange(elements_per_channel)) {
for (const auto c : c10::irange(channel)) {
auto i = b * channel * elements_per_channel + e * channel + c;
// Convert the quantized value to float so that the subtraction
// subexpression below is evaluated in floating point.
auto qvalue = qd[i / elem_per_byte].val_;
if (bit_width < 8) {
qvalue >>= (i % elem_per_byte) * bit_width;
qvalue &= (1 << bit_width) - 1;
}
rd[i] = (static_cast<float>(qvalue) - zero_points_data[c]) * scales_data[c];
}
}
}
} else {
for (const auto b : c10::irange(batches)) {
for (const auto c : c10::irange(channel)) {
for (const auto e : c10::irange(elements_per_channel)) {
auto i = b * channel * elements_per_channel +
c * elements_per_channel + e;
// Convert the quantized value to float so that the subtraction
// subexpression below is evaluated in floating point.
// NOLINTNEXTLINE(clang-analyzer-core.DivideZero)
auto qvalue = qd[i / elem_per_byte].val_;
if (bit_width < 8) {
qvalue >>= (i % elem_per_byte) * bit_width;
qvalue &= (1 << bit_width) - 1;
}
rd[i] = (static_cast<float>(qvalue) - zero_points_data[c]) * scales_data[c];
}
}
}
}
}
void dequantize_tensor_per_channel_affine_cpu(
const Tensor& qtensor,
Tensor& rtensor,
const Tensor& scales,
const Tensor& zero_points,
int64_t axis) {
AT_DISPATCH_QINT_TYPES(
qtensor.scalar_type(), "dequantize_tensor_per_channel_affine_cpu", [&]() {
dequantize_per_channel_affine_kernel<double, int64_t, scalar_t>(qtensor, rtensor, scales, zero_points, axis);
});
}
// quantize stubs for floating point scale and zero_point.
void quantize_tensor_per_channel_float_qparams_cpu(
const Tensor& rtensor,
Tensor& qtensor,
const Tensor& scales,
const Tensor& zero_points,
int64_t axis) {
// For contiguous tensors, e.g. NCHW, an arbitrary axis can be used.
// For channels_last/channels_last_3d tensors, however, only axis == 0 or 1
// is supported. Since the current implementation for the channels_last
// format does not cover per-channel quantization with an arbitrary axis,
// it is better to check and fail early.
TORCH_CHECK(rtensor.is_contiguous() || (axis <=1),
"If tensor is channels_last contig then per channel quantization "
"is supported only for axis = 0 or 1.");
AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(
qtensor.scalar_type(), "quantize_tensor_per_channel_float_qparams_cpu", [&]() {
int64_t batches = size_to_dim_(axis, rtensor.sizes());
int64_t elements_per_channel =
size_from_dim_(axis + 1, rtensor.sizes());
int64_t channel = rtensor.size(axis);
auto scales_data = scales.data_ptr<float>();
auto zero_points_data = zero_points.data_ptr<float>();
check_tensor_memory_format(rtensor, qtensor);
const float* rdata = rtensor.data_ptr<float>();
auto qdata = reinterpret_cast<underlying_t*>(qtensor.data_ptr<scalar_t>());
const auto elem_per_byte = CHAR_BIT / bit_width;
int qvalue = 0;
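// The packing below mirrors quantize_tensor_per_tensor_affine_sub_byte_cpu
// further down: the first value of each byte overwrites it, and subsequent
// values are OR-ed in at their bit offset. For bit_width == 8,
// elem_per_byte == 1 and every index takes the first branch, so each value
// simply occupies its own byte.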
if (axis == 1 && (rtensor.is_contiguous(MemoryFormat::ChannelsLast) ||
rtensor.is_contiguous(MemoryFormat::ChannelsLast3d))) {
for (const auto b : c10::irange(batches)) {
for (const auto e : c10::irange(elements_per_channel)) {
for (const auto c : c10::irange(channel)) {
auto i = b * channel * elements_per_channel + e * channel + c;
qvalue = quantize_val_float_qparams(
scales_data[c], zero_points_data[c], rdata[i], quant_min, quant_max);
// NOLINTNEXTLINE(clang-analyzer-core.DivideZero)
if (i % elem_per_byte == 0) {
qdata[i / elem_per_byte] = static_cast<underlying_t>(qvalue);
} else {
qdata[i / elem_per_byte] |= static_cast<underlying_t>(qvalue << ((i % elem_per_byte) * bit_width));
}
}
}
}
} else {
for (const auto b : c10::irange(batches)) {
for (const auto c : c10::irange(channel)) {
for (const auto e : c10::irange(elements_per_channel)) {
auto i = b * channel * elements_per_channel +
c * elements_per_channel + e;
qvalue = quantize_val_float_qparams(
scales_data[c], zero_points_data[c], rdata[i], quant_min, quant_max);
// NOLINTNEXTLINE(clang-analyzer-core.DivideZero)
if (i % elem_per_byte == 0) {
qdata[i / elem_per_byte] = static_cast<underlying_t>(qvalue);
} else {
qdata[i / elem_per_byte] |= static_cast<underlying_t>(qvalue << ((i % elem_per_byte) * bit_width));
}
}
}
}
}
});
}
void dequantize_tensor_per_channel_float_qparams_cpu(
const Tensor& qtensor,
Tensor& rtensor,
const Tensor& scales,
const Tensor& zero_points,
int64_t axis) {
// NOLINTNEXTLINE(clang-diagnostic-unused-variable)
AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(
qtensor.scalar_type(), "dequantize_tensor_per_channel_float_qparams_cpu", [&]() {
dequantize_per_channel_affine_kernel<float, float, scalar_t>(qtensor, rtensor, scales, zero_points, axis, bit_width);
});
}
void quantize_tensor_per_tensor_affine_sub_byte_cpu(
const Tensor& rtensor,
Tensor& qtensor,
float scale,
float zero_point) {
// TODO Use fbgemm kernel to pack values
AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(
qtensor.scalar_type(), "quantize_tensor_per_tensor_affine_sub_byte_cpu", [&]() {
check_tensor_memory_format(rtensor, qtensor);
const float* const rdata = rtensor.data_ptr<float>();
auto qdata = reinterpret_cast<underlying_t*>(qtensor.data_ptr<scalar_t>());
auto numel = rtensor.numel();
const auto elem_per_byte = CHAR_BIT / bit_width;
for (const auto i : c10::irange(numel)) {
float inv_scale = scale == 0 ? 1.0f : 1.0f / scale;
int64_t qvalue = lrintf(std::nearbyint(rdata[i] * inv_scale) + zero_point);
qvalue = std::max(quant_min, std::min(qvalue, quant_max));
// We pack sub-byte values so that they are aligned to byte boundaries,
// e.g. for 4 bits, index 0 is packed into the lower 4 bits of a byte and
// index 1 into the upper 4 bits.
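// Worked example (illustrative): with bit_width == 4, quantized values 3 at
// index 0 and 12 at index 1 pack into a single byte as 3 | (12 << 4) == 0xC3.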
// NOLINTNEXTLINE(clang-analyzer-core.DivideZero)
if (i % elem_per_byte == 0) {
qdata[i / elem_per_byte] = static_cast<underlying_t>(qvalue);
} else {
qdata[i / elem_per_byte] |= static_cast<underlying_t>(qvalue << ((i % elem_per_byte) * bit_width));
}
} // for numel
});
}
void dequantize_tensor_per_tensor_affine_sub_byte_cpu(
const Tensor& qtensor,
Tensor& rtensor,
float scale,
float zero_point) {
// TODO Use fbgemm kernel to pack values
// NOLINTNEXTLINE(clang-diagnostic-unused-variable)
AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(
qtensor.scalar_type(), "dequantize_tensor_per_tensor_affine_sub_byte_cpu", [&]() {
check_tensor_memory_format(rtensor, qtensor);
auto rdata = rtensor.data_ptr<float>();
const underlying_t* qdata = reinterpret_cast<underlying_t*>(qtensor.data_ptr<scalar_t>());
auto numel = rtensor.numel();
const auto elem_per_byte = CHAR_BIT / bit_width;
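// Worked example (illustrative): with bit_width == 4, scale == 0.1f and
// zero_point == 8, the packed byte 0xC3 dequantizes to
// (3 - 8) * 0.1 == -0.5 at index 0 and (12 - 8) * 0.1 == 0.4 at index 1.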
for (const auto i : c10::irange(numel)) {
// NOLINTNEXTLINE(clang-analyzer-core.DivideZero)
underlying_t qvalue = qdata[i / elem_per_byte];
qvalue >>= (i % elem_per_byte) * bit_width;
qvalue &= (1 << bit_width) - 1;
rdata[i] = (static_cast<float>(qvalue) - zero_point) * scale;
}
});
}
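// Assumed usage sketch (not part of this file): the sub-byte kernels above
// are reached through the regular per-tensor quantize/dequantize entry
// points when a sub-byte dtype is requested, e.g. something like
//   auto q = at::quantize_per_tensor(t, /*scale=*/0.1, /*zero_point=*/8,
//                                    at::kQUInt4x2);
//   auto r = q.dequantize();
// where at::kQUInt4x2 selects the 4-bit packed representation.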
} // namespace
// Some quantization tests are flaky on Windows with AVX512. If --continue-through-error
// is used, only one fails. But if that failing test is skipped, another one fails;
// if the second test is also skipped, a third one fails.
// So, until quantization support on Windows is fixed for AVX512,
// the AVX2 kernels are used instead. Ref: GH 56992.
#if defined(CPU_CAPABILITY_AVX512) && defined(_WIN32)
REGISTER_NO_AVX512_DISPATCH(dequantize_tensor_per_channel_affine_stub);
REGISTER_NO_AVX512_DISPATCH(dequantize_tensor_per_tensor_affine_stub);
REGISTER_NO_AVX512_DISPATCH(dequantize_tensor_per_channel_float_qparams_stub);
REGISTER_NO_AVX512_DISPATCH(fake_quant_grad_learnable_tensor_stub);
REGISTER_NO_AVX512_DISPATCH(fake_quant_per_channel_cachemask_stub);
REGISTER_NO_AVX512_DISPATCH(fake_quant_tensor_cachemask_stub);
REGISTER_NO_AVX512_DISPATCH(fake_quant_tensor_cachemask_tensor_qparams_stub);
REGISTER_NO_AVX512_DISPATCH(qadaptive_avg_pool2d_nhwc_stub);
REGISTER_NO_AVX512_DISPATCH(qadaptive_avg_pool3d_ndhwc_stub);
REGISTER_NO_AVX512_DISPATCH(qadd_relu_stub);
REGISTER_NO_AVX512_DISPATCH(qadd_scalar_relu_stub);
REGISTER_NO_AVX512_DISPATCH(qadd_scalar_stub);
REGISTER_NO_AVX512_DISPATCH(qadd_stub);
REGISTER_NO_AVX512_DISPATCH(qavg_pool2d_nhwc_stub);
REGISTER_NO_AVX512_DISPATCH(qavg_pool3d_nhwc_stub);
REGISTER_NO_AVX512_DISPATCH(qbatch_norm_relu_stub);
REGISTER_NO_AVX512_DISPATCH(qbatch_norm_stub);
REGISTER_NO_AVX512_DISPATCH(qcat_nhwc_stub);
REGISTER_NO_AVX512_DISPATCH(qcat_relu_nhwc_stub);
REGISTER_NO_AVX512_DISPATCH(qclamp_stub);
REGISTER_NO_AVX512_DISPATCH(qclamp_min_stub);
REGISTER_NO_AVX512_DISPATCH(qclamp_max_stub);
REGISTER_NO_AVX512_DISPATCH(qelu_stub);
REGISTER_NO_AVX512_DISPATCH(qhardsigmoid_stub);
REGISTER_NO_AVX512_DISPATCH(qhardswish_stub);
REGISTER_NO_AVX512_DISPATCH(qmaxpool_2d_nhwc_stub);
REGISTER_NO_AVX512_DISPATCH(qmul_relu_stub);
REGISTER_NO_AVX512_DISPATCH(qmul_stub);
REGISTER_NO_AVX512_DISPATCH(qrelu_leaky_stub);
REGISTER_NO_AVX512_DISPATCH(qrelu_stub);
REGISTER_NO_AVX512_DISPATCH(qgelu_stub);
REGISTER_NO_AVX512_DISPATCH(qsigmoid_stub);
REGISTER_NO_AVX512_DISPATCH(qtanh_stub);
REGISTER_NO_AVX512_DISPATCH(qthreshold_stub);
REGISTER_NO_AVX512_DISPATCH(qtopk_stub);
REGISTER_NO_AVX512_DISPATCH(fake_quant_grad_learnable_channel_stub);
REGISTER_NO_AVX512_DISPATCH(quantize_tensor_per_tensor_affine_stub);
REGISTER_NO_AVX512_DISPATCH(quantize_tensor_per_channel_affine_stub);
REGISTER_NO_AVX512_DISPATCH(quantize_tensor_per_channel_float_qparams_stub);
REGISTER_NO_AVX512_DISPATCH(quantized_normalize_stub);
REGISTER_NO_AVX512_DISPATCH(qupsample_bilinear2d_nhwc_stub);
REGISTER_NO_AVX512_DISPATCH(quantize_tensor_per_tensor_affine_sub_byte_stub);
REGISTER_NO_AVX512_DISPATCH(dequantize_tensor_per_tensor_affine_sub_byte_stub);
#else
REGISTER_DISPATCH(dequantize_tensor_per_channel_affine_stub,
&dequantize_tensor_per_channel_affine_cpu);
REGISTER_DISPATCH(dequantize_tensor_per_tensor_affine_stub,
&dequantize_tensor_per_tensor_affine_cpu);
REGISTER_DISPATCH(dequantize_tensor_per_channel_float_qparams_stub,
&dequantize_tensor_per_channel_float_qparams_cpu);
REGISTER_DISPATCH(fake_quant_grad_learnable_tensor_stub,
&fake_quantize_learnable_tensor_grad_kernel_cpu);
REGISTER_DISPATCH(fake_quant_per_channel_cachemask_stub, &fake_quant_per_channel_cachemask_cpu);
REGISTER_DISPATCH(fake_quant_tensor_cachemask_stub,
&fake_quantize_tensor_cachemask_kernel);
REGISTER_DISPATCH(fake_quant_tensor_cachemask_tensor_qparams_stub,
&fake_quantize_tensor_cachemask_tensor_qparams_kernel);
REGISTER_DISPATCH(qadaptive_avg_pool2d_nhwc_stub,
&qadaptive_avg_pool2d_nhwc_kernel);
REGISTER_DISPATCH(qadaptive_avg_pool3d_ndhwc_stub,
&qadaptive_avg_pool3d_ndhwc_kernel);
REGISTER_DISPATCH(qadd_relu_stub, &qadd_kernel<true>);
REGISTER_DISPATCH(qadd_scalar_relu_stub, &qadd_scalar_kernel<true>);
REGISTER_DISPATCH(qadd_scalar_stub, &qadd_scalar_kernel<false>);
REGISTER_DISPATCH(qadd_stub, &qadd_kernel<false>);
REGISTER_DISPATCH(qavg_pool2d_nhwc_stub, &qavg_pool2d_nhwc_kernel);
REGISTER_DISPATCH(qavg_pool3d_nhwc_stub, &qavg_pool3d_nhwc_kernel);
REGISTER_DISPATCH(qbatch_norm_relu_stub, &q_batch_norm_kernel<true>);
REGISTER_DISPATCH(qbatch_norm_stub, &q_batch_norm_kernel<false>);
REGISTER_DISPATCH(qcat_nhwc_stub, &qcat_nhwc_kernel<false>);
REGISTER_DISPATCH(qcat_relu_nhwc_stub, &qcat_nhwc_kernel<true>);
REGISTER_DISPATCH(qclamp_stub, &qclamp_kernel);
REGISTER_DISPATCH(qclamp_min_stub, &qclamp_min_kernel);
REGISTER_DISPATCH(qclamp_max_stub, &qclamp_max_kernel);
REGISTER_DISPATCH(qelu_stub, &qelu_kernel);
REGISTER_DISPATCH(qhardsigmoid_stub, &qhardsigmoid_kernel);
REGISTER_DISPATCH(qhardswish_stub, &qhardswish_kernel);
REGISTER_DISPATCH(qmaxpool_2d_nhwc_stub, &qmaxpool_2d_nhwc_kernel);
REGISTER_DISPATCH(qmul_relu_stub, &qmul_kernel<true>);
REGISTER_DISPATCH(qmul_stub, &qmul_kernel<false>);
REGISTER_DISPATCH(qrelu_leaky_stub, &leaky_qrelu_out_kernel);
REGISTER_DISPATCH(qrelu_stub, &qrelu_kernel);
REGISTER_DISPATCH(qgelu_stub, &qgelu_kernel);
REGISTER_DISPATCH(qsigmoid_stub, &qsigmoid_kernel);
REGISTER_DISPATCH(qtanh_stub, &qtanh_kernel);
REGISTER_DISPATCH(qthreshold_stub, &qthreshold_kernel);
REGISTER_DISPATCH(qtopk_stub, &qtopk_kernel);
REGISTER_DISPATCH(fake_quant_grad_learnable_channel_stub,
&fake_quantize_learnable_channel_grad_kernel_cpu);
REGISTER_DISPATCH(
quantize_tensor_per_tensor_affine_stub,
&quantize_tensor_per_tensor_affine_cpu);
REGISTER_DISPATCH(
quantize_tensor_per_channel_affine_stub,
&quantize_tensor_per_channel_affine_cpu);
REGISTER_DISPATCH(
quantize_tensor_per_channel_float_qparams_stub,
&quantize_tensor_per_channel_float_qparams_cpu);
REGISTER_DISPATCH(quantized_normalize_stub, &quantized_normalize_kernel);
REGISTER_DISPATCH(qupsample_bilinear2d_nhwc_stub,
&qupsample_bilinear2d_nhwc_kernel);
REGISTER_DISPATCH(
quantize_tensor_per_tensor_affine_sub_byte_stub,
&quantize_tensor_per_tensor_affine_sub_byte_cpu);
REGISTER_DISPATCH(
dequantize_tensor_per_tensor_affine_sub_byte_stub,
&dequantize_tensor_per_tensor_affine_sub_byte_cpu);
#endif // CPU_CAPABILITY_AVX512 && _WIN32
} // namespace native
} // namespace at