void requantizeOutputProcessingAvx2()

in src/QuantUtilsAvx2.cc [456:986]
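
Per output element, the kernel below computes, in scalar terms: subtract the
A_zero_point * col_offset and B_zero_point * row_offset cross terms from the
int32 accumulator, add the bias (divided by act_times_w_scale when the bias is
float), multiply by C_multiplier, round to the nearest integer, add
C_zero_point, and saturate to uint8. The following is a minimal scalar sketch
of that math, assuming tensor-level quantization granularity, a float bias,
no fused ReLU, and the non-DIRECT col_offsets layout; the function name
requantize_one_scalar and its flattened parameters are illustrative only, not
part of the FBGEMM API.

#include <algorithm>
#include <cmath>
#include <cstdint>

inline uint8_t requantize_one_scalar(
    int32_t acc,              // raw int32 accumulator, inp[i][j]
    int32_t A_zero_point,     // r.A_zero_point
    int32_t B_zero_point,     // r.B_zero_point[0]
    int32_t col_offset,       // r.col_offsets[j]
    int32_t row_offset,       // r.row_offsets[i]
    float bias,               // r.bias[j]
    float act_times_w_scale,  // r.act_times_w_scale[0]
    float C_multiplier,       // r.C_multiplier[0]
    int32_t C_zero_point) {   // r.C_zero_point
  // Remove the cross terms introduced by asymmetric quantization of A and B.
  acc -= A_zero_point * col_offset;
  acc -= B_zero_point * row_offset;
  // A float bias is added in the accumulator's scale, hence the division by
  // act_times_w_scale (the kernel precomputes the reciprocal when it can).
  float x = static_cast<float>(acc) + bias / act_times_w_scale;
  // Scale, round to nearest with ties to even (matching CVTPS2DQ under the
  // default MXCSR rounding mode), add the output zero point, clamp to uint8.
  int32_t rounded = static_cast<int32_t>(std::nearbyintf(x * C_multiplier));
  int32_t shifted = rounded + C_zero_point;
  return static_cast<uint8_t>(std::min(255, std::max(0, shifted)));
}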


template <
    bool A_SYMMETRIC,
    bool B_SYMMETRIC,
    QuantizationGranularity Q_GRAN,
    bool HAS_BIAS,
    bool FUSE_RELU,
    typename BIAS_TYPE,
    bool DIRECT>
void requantizeOutputProcessingAvx2(
    uint8_t* out,
    const int32_t* inp,
    const block_type_t& block,
    int ld_out,
    int ld_in,
    const requantizationParams_t<BIAS_TYPE>& r) {
  // Adaptation of the implementation in QNNPACK/src/requantization/fp32-sse2.c,
  // using AVX2 instructions.
  int quant_param_idx = 0;
  if (Q_GRAN == QuantizationGranularity::GROUP) {
    int ncol_per_group = r.ncols / r.groups;
    int g = block.col_start / ncol_per_group;
    quant_param_idx = g;
  }
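  // GROUP granularity: one C_multiplier / B_zero_point entry per group of
  // ncols/groups output columns; the group index is taken from the block's
  // starting column.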
  __m256 multiplier_v = _mm256_set1_ps(r.C_multiplier[quant_param_idx]);

  // Broadcasted reciprocal of act_times_w_scale
  __m256 act_times_w_rcp_v;
  if (!(Q_GRAN == QuantizationGranularity::OUT_CHANNEL)) {
    if (is_same<BIAS_TYPE, float>::value) {
      act_times_w_rcp_v =
          _mm256_set1_ps(1.0 / r.act_times_w_scale[quant_param_idx]);
    }
  }
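  // A float bias gets divided by act_times_w_scale before being added to the
  // int32 accumulator; precompute the reciprocal here so the inner loops can
  // multiply instead of divide. Not applicable for OUT_CHANNEL granularity,
  // where the scale differs per output column.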

  __m256i min_v = _mm256_set1_epi8(static_cast<uint8_t>(0));
  __m256i max_v = _mm256_set1_epi8(static_cast<uint8_t>(255));

  assert(
      (A_SYMMETRIC == (r.A_zero_point == 0)) &&
      "A_SYMMETRIC == true if and only if A_zero_point == 0");
  assert(
      (B_SYMMETRIC ==
       ((Q_GRAN == QuantizationGranularity::TENSOR && r.B_zero_point[0] == 0) ||
        r.row_offsets == nullptr)) &&
      "B_SYMMETRIC == true if and only if B_zero_point == 0 "
      "or r.row_offsets == nullptr");
  assert(
      (HAS_BIAS == (r.bias != nullptr)) &&
      "HAS_BIAS == true if and only if bias != nullptr");

  __m256i A_zero_point_v = _mm256_set1_epi32(r.A_zero_point);
  __m256i C_zero_point_epi16_v = _mm256_set1_epi16(r.C_zero_point);
  __m256i C_zero_point_epi8_v = _mm256_set1_epi8(r.C_zero_point);

  __m256i permute_mask_v =
      _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00);

  constexpr int VLEN = 8;
  for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
    // Scale row_offset with Bq_zero_point
    int32_t row_offset = 0;
    if (B_SYMMETRIC) {
      row_offset = 0;
    } else if (
        Q_GRAN == QuantizationGranularity::TENSOR ||
        Q_GRAN == QuantizationGranularity::GROUP) {
      row_offset =
          r.row_offsets[i - block.row_start] * r.B_zero_point[quant_param_idx];
    } else {
      assert(
          Q_GRAN == QuantizationGranularity::OUT_CHANNEL &&
          "unknown quantization granularity");
    }
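    // For OUT_CHANNEL granularity, B_zero_point differs per column, so
    // row_offset_v is recomputed for each column chunk inside the j loops
    // below.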
    __m256i row_offset_v = _mm256_set1_epi32(row_offset);

    int j = block.col_start;
    for (; j < block.col_start + (block.col_size / (VLEN * 4) * (VLEN * 4));
         j += (VLEN * 4)) {
      __m256i x_v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
          inp + (i - block.row_start) * ld_in + (j - block.col_start)));
      __m256i y_v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
          inp + (i - block.row_start) * ld_in + (j - block.col_start) +
          1 * VLEN));
      __m256i z_v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
          inp + (i - block.row_start) * ld_in + (j - block.col_start) +
          2 * VLEN));
      __m256i w_v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
          inp + (i - block.row_start) * ld_in + (j - block.col_start) +
          3 * VLEN));

      if (!A_SYMMETRIC) {
        __m256i col_off_v;
        if (DIRECT == false) {
          col_off_v = _mm256_mullo_epi32(
              A_zero_point_v,
              _mm256_loadu_si256(
                  reinterpret_cast<const __m256i*>(r.col_offsets + j)));
        } else {
          col_off_v = _mm256_mullo_epi32(
              A_zero_point_v,
              _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
                  r.col_offsets + j + i * block.col_size)));
        }

        x_v = _mm256_sub_epi32(x_v, col_off_v);

        if (DIRECT == false) {
          col_off_v = _mm256_mullo_epi32(
              A_zero_point_v,
              _mm256_loadu_si256(
                  reinterpret_cast<const __m256i*>(r.col_offsets + j + VLEN)));
        } else {
          col_off_v = _mm256_mullo_epi32(
              A_zero_point_v,
              _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
                  r.col_offsets + j + VLEN + i * block.col_size)));
        }

        y_v = _mm256_sub_epi32(y_v, col_off_v);

        if (DIRECT == false) {
          col_off_v = _mm256_mullo_epi32(
              A_zero_point_v,
              _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
                  r.col_offsets + j + 2 * VLEN)));
        } else {
          col_off_v = _mm256_mullo_epi32(
              A_zero_point_v,
              _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
                  r.col_offsets + j + 2 * VLEN + i * block.col_size)));
        }

        z_v = _mm256_sub_epi32(z_v, col_off_v);

        if (DIRECT == false) {
          col_off_v = _mm256_mullo_epi32(
              A_zero_point_v,
              _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
                  r.col_offsets + j + 3 * VLEN)));
        } else {
          col_off_v = _mm256_mullo_epi32(
              A_zero_point_v,
              _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
                  r.col_offsets + j + 3 * VLEN + i * block.col_size)));
        }

        w_v = _mm256_sub_epi32(w_v, col_off_v);
      }

      if (!B_SYMMETRIC) {
        if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
          row_offset_v = _mm256_mullo_epi32(
              _mm256_set1_epi32(r.row_offsets[i - block.row_start]),
              _mm256_loadu_si256(
                  reinterpret_cast<const __m256i*>(r.B_zero_point + j)));
        }
        x_v = _mm256_sub_epi32(x_v, row_offset_v);
        if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
          row_offset_v = _mm256_mullo_epi32(
              _mm256_set1_epi32(r.row_offsets[i - block.row_start]),
              _mm256_loadu_si256(
                  reinterpret_cast<const __m256i*>(r.B_zero_point + j + VLEN)));
        }
        y_v = _mm256_sub_epi32(y_v, row_offset_v);
        if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
          row_offset_v = _mm256_mullo_epi32(
              _mm256_set1_epi32(r.row_offsets[i - block.row_start]),
              _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
                  r.B_zero_point + j + 2 * VLEN)));
        }
        z_v = _mm256_sub_epi32(z_v, row_offset_v);
        if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
          row_offset_v = _mm256_mullo_epi32(
              _mm256_set1_epi32(r.row_offsets[i - block.row_start]),
              _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
                  r.B_zero_point + j + 3 * VLEN)));
        }
        w_v = _mm256_sub_epi32(w_v, row_offset_v);
      }
      __m256 xf_v, yf_v, zf_v, wf_v;
      if (HAS_BIAS) {
        if (is_same<BIAS_TYPE, float>::value) {
          __m256 x_bias_v, y_bias_v, z_bias_v, w_bias_v;
          if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
            x_bias_v = _mm256_div_ps(
                _mm256_loadu_ps(
                    reinterpret_cast<const float*>(r.bias + j + 0 * VLEN)),
                _mm256_loadu_ps(r.act_times_w_scale + j + 0 * VLEN));
            y_bias_v = _mm256_div_ps(
                _mm256_loadu_ps(
                    reinterpret_cast<const float*>(r.bias + j + 1 * VLEN)),
                _mm256_loadu_ps(r.act_times_w_scale + j + 1 * VLEN));
            z_bias_v = _mm256_div_ps(
                _mm256_loadu_ps(
                    reinterpret_cast<const float*>(r.bias + j + 2 * VLEN)),
                _mm256_loadu_ps(r.act_times_w_scale + j + 2 * VLEN));
            w_bias_v = _mm256_div_ps(
                _mm256_loadu_ps(
                    reinterpret_cast<const float*>(r.bias + j + 3 * VLEN)),
                _mm256_loadu_ps(r.act_times_w_scale + j + 3 * VLEN));
          } else {
            x_bias_v = _mm256_mul_ps(
                _mm256_loadu_ps(
                    reinterpret_cast<const float*>(r.bias + j + 0 * VLEN)),
                act_times_w_rcp_v);
            y_bias_v = _mm256_mul_ps(
                _mm256_loadu_ps(
                    reinterpret_cast<const float*>(r.bias + j + 1 * VLEN)),
                act_times_w_rcp_v);
            z_bias_v = _mm256_mul_ps(
                _mm256_loadu_ps(
                    reinterpret_cast<const float*>(r.bias + j + 2 * VLEN)),
                act_times_w_rcp_v);
            w_bias_v = _mm256_mul_ps(
                _mm256_loadu_ps(
                    reinterpret_cast<const float*>(r.bias + j + 3 * VLEN)),
                act_times_w_rcp_v);
          }
          xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v);
          yf_v = _mm256_add_ps(_mm256_cvtepi32_ps(y_v), y_bias_v);
          zf_v = _mm256_add_ps(_mm256_cvtepi32_ps(z_v), z_bias_v);
          wf_v = _mm256_add_ps(_mm256_cvtepi32_ps(w_v), w_bias_v);
        } else {
          x_v = _mm256_add_epi32(
              x_v,
              _mm256_loadu_si256(
                  reinterpret_cast<const __m256i*>(r.bias + j + 0 * VLEN)));
          y_v = _mm256_add_epi32(
              y_v,
              _mm256_loadu_si256(
                  reinterpret_cast<const __m256i*>(r.bias + j + 1 * VLEN)));
          z_v = _mm256_add_epi32(
              z_v,
              _mm256_loadu_si256(
                  reinterpret_cast<const __m256i*>(r.bias + j + 2 * VLEN)));
          w_v = _mm256_add_epi32(
              w_v,
              _mm256_loadu_si256(
                  reinterpret_cast<const __m256i*>(r.bias + j + 3 * VLEN)));
          xf_v = _mm256_cvtepi32_ps(x_v);
          yf_v = _mm256_cvtepi32_ps(y_v);
          zf_v = _mm256_cvtepi32_ps(z_v);
          wf_v = _mm256_cvtepi32_ps(w_v);
        }
      } else {
        xf_v = _mm256_cvtepi32_ps(x_v);
        yf_v = _mm256_cvtepi32_ps(y_v);
        zf_v = _mm256_cvtepi32_ps(z_v);
        wf_v = _mm256_cvtepi32_ps(w_v);
      }

      /*
       * Convert int32_t input to FP32 and multiply by FP32 scale.
       * Both operations involve statistically unbiased roundings (with
       * default MXCSR rounding mode):
       * - Large int32_t values can't be exactly represented as FP32.
       * The CVTDQ2PS instruction on x86 rounds them to the nearest
       * FP32 value with ties to even (assuming the default MXCSR
       * rounding mode).
       * - The product of two FP32 values is generally not exactly
       * representable as an FP32 value, and will be rounded to the
       * nearest FP32 value with ties to even under the default MXCSR
       * rounding mode.
       */
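      // Example: INT32_MAX (2147483647) is not representable in FP32; it
      // converts to the nearest representable value 2147483648.0f (2**31).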
      __m256 x_scaled_v, y_scaled_v, z_scaled_v, w_scaled_v;
      if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
        x_scaled_v =
            _mm256_mul_ps(xf_v, _mm256_loadu_ps(r.C_multiplier + j + 0 * VLEN));
        y_scaled_v =
            _mm256_mul_ps(yf_v, _mm256_loadu_ps(r.C_multiplier + j + 1 * VLEN));
        z_scaled_v =
            _mm256_mul_ps(zf_v, _mm256_loadu_ps(r.C_multiplier + j + 2 * VLEN));
        w_scaled_v =
            _mm256_mul_ps(wf_v, _mm256_loadu_ps(r.C_multiplier + j + 3 * VLEN));
      } else {
        x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
        y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v);
        z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v);
        w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v);
      }

      /*
       * Convert scaled FP32 result to int32_t using CVTPS2DQ instruction.
       * The CVTPS2DQ instruction rounds the result to the nearest integer
       * with ties to even (assuming default MXCSR rounding mode). However,
       * when conversion overflows, it produces INT32_MIN as a result. For
       * large positive inputs the result of conversion can become negative,
       * which affects the final requantization result. Note that on x86
       * SSE2 we have e.g. int32_t(float(INT32_MAX)) == INT32_MIN! This
       * happens because float(INT32_MAX) rounds to 2**31, which overflows
       * int32_t when it is converted back to integer.
       *
       * Thankfully, we can prove that overflow never happens in this
       * requantization scheme. The largest positive input is INT32_MAX
       * (2**31 - 1), which turns into 2**31 when converted to float. The
       * largest scale value is 0x1.FFFFFEp-1. When multiplied together, the
       * result is 2147483520 (compare to INT32_MAX = 2147483647), which
       * fits into int32_t without overflow.
       */
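      // Worked out: 2**31 * 0x1.FFFFFEp-1 = 2**31 * (1 - 2**-24)
      //           = 2**31 - 2**7 = 2147483520, which is < INT32_MAX.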
      __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);
      __m256i y_rounded_v = _mm256_cvtps_epi32(y_scaled_v);
      __m256i z_rounded_v = _mm256_cvtps_epi32(z_scaled_v);
      __m256i w_rounded_v = _mm256_cvtps_epi32(w_scaled_v);

      /*
       * Standard final sequence on x86 AVX2:
       * - Pack to int16_t and saturate
       * - Add zero point
       * - Pack to uint8_t and saturate
       * - Clamp between qmin and qmax
       */
      __m256i xy_packed_v = _mm256_adds_epi16(
          _mm256_packs_epi32(x_rounded_v, y_rounded_v), C_zero_point_epi16_v);
      __m256i zw_packed_v = _mm256_adds_epi16(
          _mm256_packs_epi32(z_rounded_v, w_rounded_v), C_zero_point_epi16_v);
      __m256i xyzw_packed_v = _mm256_packus_epi16(xy_packed_v, zw_packed_v);
      __m256i xyzw_clamped_v = _mm256_max_epu8(
          FUSE_RELU ? C_zero_point_epi8_v : min_v,
          _mm256_min_epu8(xyzw_packed_v, max_v));

      /*
       * xyzw_clamped_v has results in the following layout so we need to
       * permute: x0-3 y0-3 z0-3 w0-3 x4-7 y4-7 z4-7 w4-7
       */
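      // The interleaving comes from PACKSSDW/PACKUSWB operating on each
      // 128-bit lane independently. VPERMD with the mask
      // {0, 4, 1, 5, 2, 6, 3, 7} (element 0 first) restores contiguous
      // x0-7 y0-7 z0-7 w0-7 byte order before the 32-byte store.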
      xyzw_clamped_v =
          _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v);

      /*
       * 4x CVTDQ2PS
       * 4x MULPS
       * 4x CVTPS2DQ
       * 2x PACKSSDW
       * 1x PACKUSWB
       * 2x PADDW
       * 1x PMAXUB
       * 1x PMINUB
       * 1x PERMD
       * ---------------------
       * 20 instructions total
       */
      _mm256_storeu_si256(
          reinterpret_cast<__m256i*>(out + i * ld_out + j), xyzw_clamped_v);
    } // j loop vectorized and unrolled 4x

    for (; j < block.col_start + (block.col_size / VLEN * VLEN); j += VLEN) {
      __m256i x_v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
          inp + (i - block.row_start) * ld_in + (j - block.col_start)));

      if (!A_SYMMETRIC) {
        __m256i col_off_v;
        if (DIRECT == false) {
          col_off_v = _mm256_mullo_epi32(
              A_zero_point_v,
              _mm256_loadu_si256(
                  reinterpret_cast<const __m256i*>(r.col_offsets + j)));
        } else {
          col_off_v = _mm256_mullo_epi32(
              A_zero_point_v,
              _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
                  r.col_offsets + j + i * block.col_size)));
        }
        x_v = _mm256_sub_epi32(x_v, col_off_v);
      }

      if (!B_SYMMETRIC) {
        if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
          row_offset_v = _mm256_mullo_epi32(
              _mm256_set1_epi32(r.row_offsets[i - block.row_start]),
              _mm256_loadu_si256(
                  reinterpret_cast<const __m256i*>(r.B_zero_point + j)));
        }
        x_v = _mm256_sub_epi32(x_v, row_offset_v);
      }
      __m256 xf_v;
      if (HAS_BIAS) {
        if (is_same<BIAS_TYPE, float>::value) {
          __m256 x_bias_v;
          if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
            x_bias_v = _mm256_div_ps(
                _mm256_loadu_ps(reinterpret_cast<const float*>(r.bias + j)),
                _mm256_loadu_ps(r.act_times_w_scale + j));
          } else {
            x_bias_v = _mm256_mul_ps(
                _mm256_loadu_ps(reinterpret_cast<const float*>(r.bias + j)),
                act_times_w_rcp_v);
          }
          xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v);
        } else {
          x_v = _mm256_add_epi32(
              x_v,
              _mm256_loadu_si256(reinterpret_cast<const __m256i*>(r.bias + j)));
          xf_v = _mm256_cvtepi32_ps(x_v);
        }
      } else {
        xf_v = _mm256_cvtepi32_ps(x_v);
      }

      __m256 x_scaled_v;
      if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
        x_scaled_v = _mm256_mul_ps(xf_v, _mm256_loadu_ps(r.C_multiplier + j));
      } else {
        x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
      }
      __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);

      __m256i x_packed_v = _mm256_adds_epi16(
          _mm256_packs_epi32(x_rounded_v, _mm256_setzero_si256()),
          C_zero_point_epi16_v);
      x_packed_v = _mm256_packus_epi16(x_packed_v, _mm256_setzero_si256());
      __m256i x_clamped_v = _mm256_max_epu8(
          FUSE_RELU ? C_zero_point_epi8_v : min_v,
          _mm256_min_epu8(x_packed_v, max_v));

      /*
       * x_clamped_v has results in the following layout so we need to
       * permute: x0-3 garbage0-11 x4-7 garbage12-23
       */
      x_clamped_v = _mm256_permutevar8x32_epi32(x_clamped_v, permute_mask_v);

      /*
       * 1x CVTDQ2PS
       * 1x MULPS
       * 1x CVTPS2DQ
       * 1x PACKSSDW
       * 1x PACKUSWB
       * 1x PADDW
       * 1x PMAXUB
       * 1x PMINUB
       * 1x PERMD
       * ---------------------
       * 9 instructions total
       */
      _mm_storel_epi64(
          reinterpret_cast<__m128i*>(out + i * ld_out + j),
          _mm256_castsi256_si128(x_clamped_v));
    } // j loop vectorized

    int remainder = block.col_start + block.col_size - j;
    if (remainder > 0) {
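      // Fewer than VLEN columns remain: masked loads read only the first
      // `remainder` 32-bit lanes, and the result is written out byte by byte
      // through a small aligned staging buffer.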
      __m256i mask_v = _mm256_load_si256(reinterpret_cast<const __m256i*>(
          internal::avx2_ps_or_epi32_masks[remainder]));

      __m256i x_v = _mm256_maskload_epi32(
          inp + (i - block.row_start) * ld_in + (j - block.col_start), mask_v);

      if (!A_SYMMETRIC) {
        __m256i col_off_v;
        if (DIRECT == false) {
          col_off_v = _mm256_mullo_epi32(
              A_zero_point_v, _mm256_maskload_epi32(r.col_offsets + j, mask_v));
        } else {
          col_off_v = _mm256_mullo_epi32(
              A_zero_point_v,
              _mm256_maskload_epi32(
                  r.col_offsets + j + i * block.col_size, mask_v));
        }
        x_v = _mm256_sub_epi32(x_v, col_off_v);
      }

      if (!B_SYMMETRIC) {
        if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
          row_offset_v = _mm256_mullo_epi32(
              _mm256_set1_epi32(r.row_offsets[i - block.row_start]),
              _mm256_maskload_epi32(r.B_zero_point + j, mask_v));
        }
        x_v = _mm256_sub_epi32(x_v, row_offset_v);
      }

      __m256 xf_v;
      if (HAS_BIAS) {
        if (is_same<BIAS_TYPE, float>::value) {
          __m256 x_bias_v;
          if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
            x_bias_v = _mm256_div_ps(
                _mm256_maskload_ps(
                    reinterpret_cast<const float*>(r.bias + j), mask_v),
                _mm256_maskload_ps(r.act_times_w_scale + j, mask_v));
          } else {
            x_bias_v = _mm256_mul_ps(
                _mm256_maskload_ps(
                    reinterpret_cast<const float*>(r.bias + j), mask_v),
                act_times_w_rcp_v);
          }
          xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v);
        } else {
          x_v = _mm256_add_epi32(
              x_v,
              _mm256_maskload_epi32(
                  reinterpret_cast<const int*>(r.bias + j), mask_v));
          xf_v = _mm256_cvtepi32_ps(x_v);
        }
      } else {
        xf_v = _mm256_cvtepi32_ps(x_v);
      }

      __m256 x_scaled_v;
      if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
        x_scaled_v =
            _mm256_mul_ps(xf_v, _mm256_maskload_ps(r.C_multiplier + j, mask_v));
      } else {
        x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
      }
      __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);

      __m256i x_packed_v = _mm256_adds_epi16(
          _mm256_packs_epi32(x_rounded_v, _mm256_setzero_si256()),
          C_zero_point_epi16_v);
      x_packed_v = _mm256_packus_epi16(x_packed_v, _mm256_setzero_si256());
      __m256i x_clamped_v = _mm256_max_epu8(
          FUSE_RELU ? C_zero_point_epi8_v : min_v,
          _mm256_min_epu8(x_packed_v, max_v));

      /*
       * x_clamped_v has results in the following layout so we need to
       * permute: x0-3 garbage0-11 x4-7 garbage12-23
       */
      x_clamped_v = _mm256_permutevar8x32_epi32(x_clamped_v, permute_mask_v);

      /*
       * 1x CVTDQ2PS
       * 1x MULPS
       * 1x CVTPS2DQ
       * 1x PACKSSDW
       * 1x PACKUSWB
       * 1x PADDW
       * 1x PMAXUB
       * 1x PMINUB
       * 1x PERMD
       * ---------------------
       * 9 instructions total
       */
      alignas(64) uint8_t x_clamped_buffer[32];
      _mm256_store_si256(
          reinterpret_cast<__m256i*>(x_clamped_buffer), x_clamped_v);
      for (int k = 0; k < remainder; ++k) {
        out[i * ld_out + j + k] = x_clamped_buffer[k];
      }
    } // j loop remainder
  } // i loop
}