void requantizeForFloatAvx2()

in src/QuantUtilsAvx2.cc [994:1132]


template <
    bool A_SYMMETRIC,
    bool B_SYMMETRIC,
    QuantizationGranularity Q_GRAN,
    bool HAS_BIAS,
    bool FUSE_RELU>
void requantizeForFloatAvx2(
    float* out,
    const int32_t* inp,
    const block_type_t& block,
    int ld_out,
    int ld_in,
    const requantizationForFloatParams_t& r) {
  // Adapted from the implementation in QNNPACK/src/requantization/fp32-sse2.c,
  // using AVX2 instructions.
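  //
  // Each output element is computed (before the optional fused ReLU) as
  //   out[i][j] = A_scale * B_scale *
  //       (inp[i][j] - A_zero_point * col_offsets[j]
  //                  - B_zero_point * row_offsets[i]) + bias[j]
  // with the zero-point corrections skipped when A_SYMMETRIC / B_SYMMETRIC and
  // the bias skipped when !HAS_BIAS.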
  int quant_param_idx = 0;
  if (Q_GRAN == QuantizationGranularity::GROUP) {
    int ncol_per_group = r.ncols / r.groups;
    int g = block.col_start / ncol_per_group;
    quant_param_idx = g;
  }
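  // Combined requantization multiplier A_scale * B_scale for the per-tensor /
  // per-group case; OUT_CHANNEL granularity loads per-column scales in the loop.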
  __m256 multiplier_v = _mm256_set1_ps(r.A_scale * r.B_scale[quant_param_idx]);

  assert(
      (A_SYMMETRIC == (r.A_zero_point == 0)) &&
      "A_SYMMETRIC == true if and only if A_zero_point == 0");
  assert(
      (B_SYMMETRIC ==
       ((Q_GRAN == QuantizationGranularity::TENSOR && r.B_zero_point[0] == 0) ||
        r.row_offsets == nullptr)) &&
      "B_SYMMETRIC == true if and only if B_zero_point == 0 "
      "or r.row_offsets == nullptr");
  assert(
      (HAS_BIAS == (r.bias != nullptr)) &&
      "HAS_BIAS == true if and only if bias != nullptr");

  __m256i A_zero_point_v = _mm256_set1_epi32(r.A_zero_point);

  constexpr int VLEN = 8;
  for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
    // Scale row_offset by B_zero_point
    int32_t row_offset = 0;
    if (B_SYMMETRIC) {
      row_offset = 0;
    } else if (
        Q_GRAN == QuantizationGranularity::TENSOR ||
        Q_GRAN == QuantizationGranularity::GROUP) {
      row_offset =
          r.row_offsets[i - block.row_start] * r.B_zero_point[quant_param_idx];
    } else {
      assert(
          Q_GRAN == QuantizationGranularity::OUT_CHANNEL &&
          "unknown quantization granularity");
    }
    __m256i row_offset_v = _mm256_set1_epi32(row_offset);

    int j = block.col_start;
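    // Vectorized main loop: process VLEN (8) columns per iteration with
    // full-width loads and stores.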
    for (; j < block.col_start + (block.col_size / VLEN * VLEN); j += VLEN) {
      __m256i x_v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
          inp + (i - block.row_start) * ld_in + (j - block.col_start)));

      if (!A_SYMMETRIC) {
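        // Subtract A_zero_point * col_offsets[j] (activation zero-point
        // correction).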
        __m256i col_off_v = _mm256_mullo_epi32(
            A_zero_point_v,
            _mm256_loadu_si256(
                reinterpret_cast<const __m256i*>(r.col_offsets + j)));
        x_v = _mm256_sub_epi32(x_v, col_off_v);
      }

      if (!B_SYMMETRIC) {
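        // Subtract row_offsets[i] * B_zero_point (weight zero-point
        // correction); OUT_CHANNEL granularity uses per-column zero points,
        // loaded here.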
        if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
          row_offset_v = _mm256_mullo_epi32(
              _mm256_set1_epi32(r.row_offsets[i - block.row_start]),
              _mm256_loadu_si256(
                  reinterpret_cast<const __m256i*>(r.B_zero_point + j)));
        }
        x_v = _mm256_sub_epi32(x_v, row_offset_v);
      }

      __m256 x_scaled_v;
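      // Convert to float and scale by A_scale * B_scale (per-column B_scale
      // for OUT_CHANNEL granularity; otherwise the precomputed multiplier_v).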
      if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
        x_scaled_v = _mm256_mul_ps(
            _mm256_cvtepi32_ps(x_v),
            _mm256_mul_ps(
                _mm256_set1_ps(r.A_scale), _mm256_loadu_ps(r.B_scale + j)));
      } else {
        x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
      }

      if (HAS_BIAS) {
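        // The bias is already in float, so it is added after scaling.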
        x_scaled_v = _mm256_add_ps(x_scaled_v, _mm256_loadu_ps(r.bias + j));
      }
      if (FUSE_RELU) {
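        // Fused ReLU: clamp negative outputs to zero.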
        x_scaled_v = _mm256_max_ps(_mm256_setzero_ps(), x_scaled_v);
      }

      _mm256_storeu_ps(out + i * ld_out + j, x_scaled_v);
    } // j loop vectorized

    int remainder = block.col_start + block.col_size - j;
    if (remainder > 0) {
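      // Handle the last (block.col_size % VLEN) columns with masked loads and
      // stores, mirroring the full-width body above.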
      __m256i mask_v = _mm256_load_si256(reinterpret_cast<const __m256i*>(
          internal::avx2_ps_or_epi32_masks[remainder]));

      __m256i x_v = _mm256_maskload_epi32(
          inp + (i - block.row_start) * ld_in + (j - block.col_start), mask_v);

      if (!A_SYMMETRIC) {
        __m256i col_off_v = _mm256_mullo_epi32(
            A_zero_point_v, _mm256_maskload_epi32(r.col_offsets + j, mask_v));
        x_v = _mm256_sub_epi32(x_v, col_off_v);
      }

      if (!B_SYMMETRIC) {
        if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
          row_offset_v = _mm256_mullo_epi32(
              _mm256_set1_epi32(r.row_offsets[i - block.row_start]),
              _mm256_maskload_epi32(r.B_zero_point + j, mask_v));
        }
        x_v = _mm256_sub_epi32(x_v, row_offset_v);
      }

      __m256 x_scaled_v;
      if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
        x_scaled_v = _mm256_mul_ps(
            _mm256_cvtepi32_ps(x_v),
            _mm256_mul_ps(
                _mm256_set1_ps(r.A_scale),
                _mm256_maskload_ps(r.B_scale + j, mask_v)));
      } else {
        x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
      }

      if (HAS_BIAS) {
        x_scaled_v =
            _mm256_add_ps(x_scaled_v, _mm256_maskload_ps(r.bias + j, mask_v));
      }
      if (FUSE_RELU) {
        x_scaled_v = _mm256_max_ps(_mm256_setzero_ps(), x_scaled_v);
      }

      _mm256_maskstore_ps(out + i * ld_out + j, mask_v, x_scaled_v);
    } // j loop remainder
  } // i loop
}
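
For reference, the per-element arithmetic that the vectorized kernel implements can be written as a small scalar function. The sketch below is illustrative only and not part of FBGEMM; its parameter names mirror the fields used above, and the caller is assumed to pass the B_scale / B_zero_point entry that matches the quantization granularity (tensor, group, or output channel).

#include <algorithm>
#include <cstdint>

// Hypothetical scalar reference for one output element. `acc` is the int32
// accumulator inp[i][j]; `row_offset` is the sum of the quantized A row and
// `col_offset` the sum of the quantized B column, as in the kernel above.
inline float requantizeForFloatScalar(
    int32_t acc,
    float A_scale,
    float B_scale,        // per-tensor, per-group, or per-column scale
    int32_t A_zero_point,
    int32_t B_zero_point, // per-tensor, per-group, or per-column zero point
    int32_t row_offset,
    int32_t col_offset,
    float bias,
    bool fuse_relu) {
  // Remove the cross terms introduced by the zero points of A and B.
  int32_t corrected = acc - A_zero_point * col_offset - B_zero_point * row_offset;
  // Scale back to real values and add the float bias.
  float y = static_cast<float>(corrected) * A_scale * B_scale + bias;
  return fuse_relu ? std::max(y, 0.0f) : y;
}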