void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfAvx2()

in src/QuantUtilsAvx2.cc [1562:1754]


template <typename InputType, int BIT_RATE>
void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfAvx2(
    const InputType* input,
    size_t input_rows,
    int input_columns,
    std::uint8_t* output) {
  static_assert(
      std::is_same<InputType, float>() || std::is_same<InputType, float16>(),
      "Only float and float16 types are allowed.");
  constexpr int VLEN = 8;
  constexpr int NUM_ELEM_PER_BYTE = 8 / BIT_RATE;
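  // Each output row stores the packed quantized values followed by an fp16
  // scale and an fp16 bias, hence the extra 2 * sizeof(std::uint16_t) bytes.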
  int output_columns =
      (input_columns + NUM_ELEM_PER_BYTE - 1) / NUM_ELEM_PER_BYTE +
      2 * sizeof(std::uint16_t);

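  // For fp16 input, a per-row scratch buffer holds the values converted to
  // float so the quantization pass below can reuse them.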
  float* input_row_float_for_fp16 = nullptr;
  if (std::is_same<InputType, float16>()) {
    input_row_float_for_fp16 = static_cast<float*>(
        fbgemmAlignedAlloc(64, input_columns * sizeof(float)));
  }

  for (size_t row = 0; row < input_rows; ++row) {
    const InputType* input_row = input + row * input_columns;
    const float* input_row_float;
    if (std::is_same<InputType, float>()) {
      // NOTE: this reinterpret_cast is only to work around C++ type
      // requirements -- this branch is never taken in the fp16 case, and
      // `input_row` is guaranteed to be of float* type here. Remove it and
      // use `if constexpr` once pytorch allows C++17.
      input_row_float = reinterpret_cast<const float*>(input_row);
    } else {
      input_row_float = input_row_float_for_fp16;
    }

    std::uint8_t* output_row = output + row * output_columns;
    std::uint16_t* output_row_scale_bias = reinterpret_cast<std::uint16_t*>(
        output_row +
        (input_columns + NUM_ELEM_PER_BYTE - 1) / NUM_ELEM_PER_BYTE);

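    // First pass: find the row's min and max to derive the quantization
    // bias and scale.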
    float minimum_element = FLT_MAX;
    float maximum_element = -FLT_MAX;
    __m256 min_v = _mm256_set1_ps(minimum_element);
    __m256 max_v = _mm256_set1_ps(maximum_element);

    int col;
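    // Vectorized min/max over full VLEN-element chunks; the fp16 path
    // converts each chunk to float and caches it in the scratch buffer.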
    for (col = 0; col < input_columns / VLEN * VLEN; col += VLEN) {
      __m256 in_v;
      if (std::is_same<InputType, float>()) {
        in_v = _mm256_loadu_ps(input_row_float + col);
      } else {
        __m128i in_half_v =
            _mm_loadu_si128(reinterpret_cast<const __m128i*>(input_row + col));
        in_v = _mm256_cvtph_ps(in_half_v);
        _mm256_store_ps(input_row_float_for_fp16 + col, in_v);
      }

      min_v = _mm256_min_ps(min_v, in_v);
      max_v = _mm256_max_ps(max_v, in_v);
    }
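    // Horizontally reduce the vector min/max into scalar values.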
    alignas(64) float min_buf[VLEN], max_buf[VLEN];
    _mm256_store_ps(min_buf, min_v);
    _mm256_store_ps(max_buf, max_v);
    for (int i = 0; i < VLEN; ++i) {
      minimum_element = std::min(minimum_element, min_buf[i]);
      maximum_element = std::max(maximum_element, max_buf[i]);
    }

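    // Scalar tail for the remaining columns; the fp16 path also fills the
    // scratch buffer here.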
    for (; col < input_columns; ++col) {
      if (std::is_same<InputType, float>()) {
        minimum_element = std::min(minimum_element, input_row_float[col]);
        maximum_element = std::max(maximum_element, input_row_float[col]);
      } else {
        float element = halfToFloat(input_row[col]);
        input_row_float_for_fp16[col] = element;
        minimum_element = std::min(minimum_element, element);
        maximum_element = std::max(maximum_element, element);
      }
    }

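    // Store the bias (row minimum) as fp16, then read it back so
    // quantization below uses exactly the rounded value a dequantizer sees.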
    output_row_scale_bias[1] = floatToHalf(minimum_element);
    minimum_element = halfToFloat(output_row_scale_bias[1]);
    const float range = maximum_element - minimum_element;

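    // Round-trip the scale through fp16 for the same reason, then guard
    // against a zero or non-invertible scale.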
    float scale = range == 0 ? 1.0f : range / ((1 << BIT_RATE) - 1);
    std::uint16_t scale_fp16 = floatToHalf(scale);
    scale = halfToFloat(scale_fp16);
    if (scale == 0) {
      // Corner case handling when maximum_element == minimum_element:
      // any scale would work because X - minimum_element will be 0 for all X.
      scale = 1.0f;
    }
    float inverse_scale = 1.0f / scale;
    if (std::isinf(inverse_scale)) {
      scale = 1.0f;
      inverse_scale = 1.0f;
    }

    output_row_scale_bias[0] = floatToHalf(scale);

    col = 0;
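    // Vectorized quantization for 2- and 4-bit rates, 4 * VLEN = 32 elements
    // per iteration; the 8-bit case and any leftover columns are handled by
    // the scalar loop below.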
    if (BIT_RATE == 2 || BIT_RATE == 4) {
      __m256i permute_mask1_v =
          _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00);
      __m256 inverse_scale_v = _mm256_set1_ps(inverse_scale);
      min_v = _mm256_set1_ps(minimum_element);

      for (; col + 4 * VLEN <= input_columns; col += 4 * VLEN) {
        __m256i x_rounded_v = _mm256_cvtps_epi32(_mm256_mul_ps(
            _mm256_sub_ps(_mm256_loadu_ps(input_row_float + col), min_v),
            inverse_scale_v));
        __m256i y_rounded_v = _mm256_cvtps_epi32(_mm256_mul_ps(
            _mm256_sub_ps(_mm256_loadu_ps(input_row_float + col + VLEN), min_v),
            inverse_scale_v));
        __m256i z_rounded_v = _mm256_cvtps_epi32(_mm256_mul_ps(
            _mm256_sub_ps(
                _mm256_loadu_ps(input_row_float + col + 2 * VLEN), min_v),
            inverse_scale_v));
        __m256i w_rounded_v = _mm256_cvtps_epi32(_mm256_mul_ps(
            _mm256_sub_ps(
                _mm256_loadu_ps(input_row_float + col + 3 * VLEN), min_v),
            inverse_scale_v));

        // An instruction sequence to save 32 32-bit integers as 8-bit integers
        __m256i xy_packed_v = _mm256_packs_epi32(x_rounded_v, y_rounded_v);
        __m256i zw_packed_v = _mm256_packs_epi32(z_rounded_v, w_rounded_v);
        __m256i xyzw_packed_v = _mm256_packus_epi16(xy_packed_v, zw_packed_v);
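        // packs/packus operate within 128-bit lanes, so the bytes come out in
        // lane-interleaved order; the cross-lane permute restores sequential
        // order x0..x7 y0..y7 z0..z7 w0..w7.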
        xyzw_packed_v =
            _mm256_permutevar8x32_epi32(xyzw_packed_v, permute_mask1_v);

        // saturate to the BIT_RATE range [0, (1 << BIT_RATE) - 1]
        xyzw_packed_v = _mm256_min_epu8(
            xyzw_packed_v,
            _mm256_set1_epi8(static_cast<char>((1 << BIT_RATE) - 1)));

        if (BIT_RATE == 4) {
          // pack into lower 8-bit of each 16-bit
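          // each 16-bit lane holds two quantized values (q0 in the low byte,
          // q1 in the high byte); OR-ing with the value shifted right by 4
          // and masking leaves (q1 << 4) | q0 in the low byte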
          xyzw_packed_v = _mm256_and_si256(
              _mm256_or_si256(
                  xyzw_packed_v, _mm256_srli_epi16(xyzw_packed_v, 4)),
              _mm256_set1_epi16(0x00ff));
        } else {
          // pack into lower 8-bit of each 32-bit
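          // each 32-bit lane holds four quantized values, one per byte; the
          // shifts by 6, 12, and 18 move them into bits 2-3, 4-5, and 6-7,
          // so the mask leaves (q3 << 6) | (q2 << 4) | (q1 << 2) | q0 in the
          // low byte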
          xyzw_packed_v = _mm256_and_si256(
              _mm256_or_si256(
                  _mm256_or_si256(
                      xyzw_packed_v, _mm256_srli_epi32(xyzw_packed_v, 6)),
                  _mm256_or_si256(
                      _mm256_srli_epi32(xyzw_packed_v, 8 + 4),
                      _mm256_srli_epi32(xyzw_packed_v, 2 * 8 + 2))),
              _mm256_set1_epi32(0x00ff));
        }

        __m128i out_v;
        if (BIT_RATE == 4) {
          // avx2 doesn't have _mm256_cvtepi16_epi8
          out_v = _mm_packus_epi16(
              _mm256_castsi256_si128(xyzw_packed_v),
              _mm256_extractf128_si256(xyzw_packed_v, 1));
          _mm_storeu_si128(
              reinterpret_cast<__m128i*>(output_row + col / NUM_ELEM_PER_BYTE),
              out_v);
        } else {
          // avx2 doesn't have _mm256_cvtepi32_epi8
          out_v = _mm_packus_epi32(
              _mm256_castsi256_si128(xyzw_packed_v),
              _mm256_extractf128_si256(xyzw_packed_v, 1));
          out_v = _mm_packus_epi16(out_v, out_v);
          _mm_storel_epi64(
              reinterpret_cast<__m128i*>(output_row + col / NUM_ELEM_PER_BYTE),
              out_v);
        }
      }
    }

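    // Scalar tail: quantize any remaining columns (the whole row when
    // BIT_RATE == 8), clamping to [0, (1 << BIT_RATE) - 1] and packing
    // NUM_ELEM_PER_BYTE values per output byte.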
    for (; col < input_columns; ++col) {
      float X = input_row_float[col];
      std::uint8_t quantized = std::max(
          0,
          std::min<int>(
              std::lrintf((X - minimum_element) * inverse_scale),
              (1 << BIT_RATE) - 1));
      if (col % NUM_ELEM_PER_BYTE == 0) {
        output_row[col / NUM_ELEM_PER_BYTE] = quantized;
      } else {
        output_row[col / NUM_ELEM_PER_BYTE] |=
            (quantized << ((col % NUM_ELEM_PER_BYTE) * BIT_RATE));
      }
    }
  } // for each row

  if (std::is_same<InputType, float16>()) {
    fbgemmAlignedFree(input_row_float_for_fp16);
  }
}
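
Usage sketch (not from the source file): a hypothetical caller that sizes the
output buffer the same way the kernel computes output_columns and invokes a
float/4-bit instantiation. The wrapper name QuantizeRows4Bit and the template
argument order <InputType, BIT_RATE> are assumptions for illustration.

#include <cstdint>
#include <vector>

#include "fbgemm/QuantUtilsAvx2.h" // header assumed to declare the kernel

void QuantizeRows4Bit(
    const float* src,
    size_t rows,
    int cols,
    std::vector<std::uint8_t>& dst) {
  constexpr int kBitRate = 4; // 2 quantized values per output byte
  constexpr int kElemsPerByte = 8 / kBitRate;
  // Mirror the kernel's output_columns: packed data + fp16 scale and bias.
  const int out_cols =
      (cols + kElemsPerByte - 1) / kElemsPerByte + 2 * sizeof(std::uint16_t);
  dst.resize(rows * out_cols);
  fbgemm::FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfAvx2<float, kBitRate>(
      src, rows, cols, dst.data());
}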