void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfAvx2()

in src/QuantUtilsAvx2.cc [1893:2126]


template <typename OutputType, int BIT_RATE>
void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfAvx2(
    const std::uint8_t* input,
    size_t input_rows,
    int input_columns,
    OutputType* output) {
  static_assert(
      std::is_same<OutputType, float>() || std::is_same<OutputType, float16>(),
      "Only float and float16 types are allowed.");
  constexpr int VLEN = 8;
  constexpr int NUM_ELEM_PER_BYTE = 8 / BIT_RATE;
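  // Layout of a fused row: ceil(output_columns / NUM_ELEM_PER_BYTE) bytes of
  // packed BIT_RATE-bit values, followed by an fp16 scale and an fp16 bias.
  // Stripping the 2 * sizeof(uint16_t) footer therefore recovers the number
  // of dequantized output columns.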
  int output_columns =
      (input_columns - 2 * sizeof(uint16_t)) * NUM_ELEM_PER_BYTE;

  // Compute the remainder for the vector load.
  // Since every row is followed by 2 fp16 values (scale and bias), we luckily
  // don't need a mask at bit-rate granularity but only at 32-bit
  // granularity.
  constexpr int NUM_ELEM_PER_32BIT = 32 / BIT_RATE;
  // multiplied by 4 because each iteration handles 4 * VLEN elements
  constexpr int NUM_OF_32BIT_PER_VLOAD = VLEN * 4 / NUM_ELEM_PER_32BIT;

  int remainder_32bit_granularity, remainder;
  __m128i vmask_load;
  __m256i vmask_store0, vmask_store1, vmask_store2, vmask_store3;
  if (BIT_RATE == 4 || BIT_RATE == 2) {
    remainder_32bit_granularity = (output_columns + NUM_ELEM_PER_32BIT - 1) /
        NUM_ELEM_PER_32BIT % NUM_OF_32BIT_PER_VLOAD;
    vmask_load = _mm_lddqu_si128(reinterpret_cast<const __m128i*>(
        internal::avx2_ps_or_epi32_combined_mask + NUM_OF_32BIT_PER_VLOAD +
        (NUM_OF_32BIT_PER_VLOAD - remainder_32bit_granularity) %
            NUM_OF_32BIT_PER_VLOAD));
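    // vmask_load limits the final partial 128-bit load of a row's packed
    // payload to whole 32-bit words; per the note above, the trailing fp16
    // scale/bias make 32-bit granularity sufficient.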
    remainder = output_columns % (4 * VLEN);
    int remainder_ratio = 1;
    if (std::is_same<OutputType, float16>()) {
      // For fp16 we only need half of the mask.
      //
      // For instance, if the remainder is 2, for FP32 the masks are
      // {-1, -1, 0, ..., 0}, {0, ..., 0}, {0, ..., 0}, {0, ..., 0}
      // (8 32-bit integers for each mask)
      // while for FP16 we only need
      // {-1, 0, 0, 0}, {0, ..., 0}, {0, ..., 0}, {0, ..., 0}
      // (4 32-bit integers for each mask)
      // since we reinterpret 2 FP16 numbers as one 32-bit number.
      // NOTE: for bit rate 4 or 2, remainders are always a multiple of 2 or 4,
      // so we don't have to worry about an odd number of FP16 numbers.
      //
      // Or, if the remainder is 30, for FP32 the masks are
      // {-1, ..., -1}, {-1, ..., -1}, {-1, ..., -1}, {-1, .., -1, 0, 0}
      // while for FP16 we only need
      // {-1, ..., -1}, {-1, ..., -1}, {-1, ..., -1}, {-1, -1, -1, 0}
      remainder_ratio = 2;
    }
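    // vmask_store0..3 mask the four tail output stores so that exactly
    // `remainder` elements are written (for fp16, two elements per 32-bit
    // lane, hence remainder_ratio == 2).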
    vmask_store0 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
        internal::avx2_ps_or_epi32_combined_mask +
        (VLEN - std::min(remainder, VLEN) / remainder_ratio % (VLEN + 1))));
    vmask_store1 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
        internal::avx2_ps_or_epi32_combined_mask +
        (VLEN -
         std::max(0, std::min(remainder - VLEN, VLEN) / remainder_ratio) %
             (VLEN + 1))));
    vmask_store2 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
        internal::avx2_ps_or_epi32_combined_mask +
        (VLEN -
         std::max(0, std::min(remainder - 2 * VLEN, VLEN) / remainder_ratio) %
             (VLEN + 1))));
    vmask_store3 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
        internal::avx2_ps_or_epi32_combined_mask +
        (VLEN -
         std::max(0, std::min(remainder - 3 * VLEN, VLEN) / remainder_ratio) %
             (VLEN + 1))));
  }

  for (size_t row = 0; row < input_rows; ++row) {
    const std::uint8_t* input_row = input + row * input_columns;
    const uint16_t* input_row_scale_bias = reinterpret_cast<const uint16_t*>(
        input_row +
        (output_columns + NUM_ELEM_PER_BYTE - 1) / NUM_ELEM_PER_BYTE);
    float scale = halfToFloat(input_row_scale_bias[0]);
    float bias = halfToFloat(input_row_scale_bias[1]);
    OutputType* output_row = output + row * output_columns;
    float* output_row_float;
    if (std::is_same<OutputType, float>()) {
      // NOTE: this reinterpret_cast only works around C++ type requirements --
      // this branch is never taken in the fp16 case, and `output_row` HAS to
      // be of float* type. Remove it and use `if constexpr` once PyTorch
      // allows C++17.
      output_row_float = reinterpret_cast<float*>(output_row);
    }

    int col = 0;
    if (BIT_RATE == 4 || BIT_RATE == 2) {
      __m256 vscale = _mm256_set1_ps(scale);
      __m256 vbias = _mm256_set1_ps(bias);
      for (; col + 4 * VLEN <= output_columns; col += 4 * VLEN) {
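        // Main vectorized loop: 4 * VLEN = 32 dequantized outputs per
        // iteration.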
        __m256i vinq;
        // unpack to 8-bit integers
        if (BIT_RATE == 4) {
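          // Zero-extend each packed byte into a 16-bit lane, then spread its
          // two nibbles into the two bytes of that lane:
          // (x | (x << 4)) & 0x0f0f.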
          vinq = _mm256_cvtepu8_epi16(
              _mm_loadu_si128(reinterpret_cast<const __m128i*>(
                  input_row + col / NUM_ELEM_PER_BYTE)));
          vinq = _mm256_and_si256(
              _mm256_or_si256(vinq, _mm256_slli_epi32(vinq, 4)),
              _mm256_set1_epi16(0x0f0f));
        } else {
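          // Zero-extend each packed byte into a 32-bit lane, then shift copies
          // by 18/12/6/0 bits so each 2-bit field lands in the low bits of its
          // own byte before masking with 0x03030303.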
          vinq = _mm256_cvtepu8_epi32(
              _mm_loadl_epi64(reinterpret_cast<const __m128i*>(
                  input_row + col / NUM_ELEM_PER_BYTE)));
          vinq = _mm256_and_si256(
              _mm256_or_si256(
                  _mm256_or_si256(
                      _mm256_slli_epi32(vinq, 2 * 8 + 2),
                      _mm256_slli_epi32(vinq, 8 + 4)),
                  _mm256_or_si256(_mm256_slli_epi32(vinq, 6), vinq)),
              _mm256_set1_epi32(0x03030303));
        }
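        // vinq now holds 32 quantized values as bytes; widen each group of 8
        // to 32-bit integers and convert to float.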
        __m256 vinq0 = _mm256_cvtepi32_ps(
            _mm256_cvtepi8_epi32(_mm256_castsi256_si128(vinq)));
        __m256 vinq1 = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(
            _mm_set1_epi64x(_mm256_extract_epi64(vinq, 1))));
        __m256 vinq2 = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(
            _mm_set1_epi64x(_mm256_extract_epi64(vinq, 2))));
        __m256 vinq3 = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(
            _mm_set1_epi64x(_mm256_extract_epi64(vinq, 3))));
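        // Dequantize with a fused multiply-add: output = scale * q + bias.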
        vinq0 = _mm256_fmadd_ps(vscale, vinq0, vbias);
        vinq1 = _mm256_fmadd_ps(vscale, vinq1, vbias);
        vinq2 = _mm256_fmadd_ps(vscale, vinq2, vbias);
        vinq3 = _mm256_fmadd_ps(vscale, vinq3, vbias);

        if (std::is_same<OutputType, float>()) {
          _mm256_storeu_ps(output_row_float + col, vinq0);
          _mm256_storeu_ps(output_row_float + col + VLEN, vinq1);
          _mm256_storeu_ps(output_row_float + col + 2 * VLEN, vinq2);
          _mm256_storeu_ps(output_row_float + col + 3 * VLEN, vinq3);
        } else {
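          // fp16 output: round each vector of 8 floats to half precision and
          // store the resulting 128 bits.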
          _mm_storeu_si128(
              reinterpret_cast<__m128i*>(output_row + col),
              _mm256_cvtps_ph(
                  vinq0, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
          _mm_storeu_si128(
              reinterpret_cast<__m128i*>(output_row + col + VLEN),
              _mm256_cvtps_ph(
                  vinq1, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
          _mm_storeu_si128(
              reinterpret_cast<__m128i*>(output_row + col + 2 * VLEN),
              _mm256_cvtps_ph(
                  vinq2, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
          _mm_storeu_si128(
              reinterpret_cast<__m128i*>(output_row + col + 3 * VLEN),
              _mm256_cvtps_ph(
                  vinq3, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
        }
      }

      if (remainder) {
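        // Tail: same unpack + dequantize as the main loop, but with a masked
        // load of the packed payload and masked stores of the outputs.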
        __m256i vinq;
        if (BIT_RATE == 4) {
          vinq = _mm256_cvtepu8_epi16(_mm_maskload_epi32(
              reinterpret_cast<const int*>(input_row + col / NUM_ELEM_PER_BYTE),
              vmask_load));
          vinq = _mm256_and_si256(
              _mm256_or_si256(vinq, _mm256_slli_epi32(vinq, 4)),
              _mm256_set1_epi16(0x0f0f));
        } else {
          vinq = _mm256_cvtepu8_epi32(_mm_maskload_epi32(
              reinterpret_cast<const int*>(input_row + col / NUM_ELEM_PER_BYTE),
              vmask_load));
          vinq = _mm256_and_si256(
              _mm256_or_si256(
                  _mm256_or_si256(
                      _mm256_slli_epi32(vinq, 2 * 8 + 2),
                      _mm256_slli_epi32(vinq, 8 + 4)),
                  _mm256_or_si256(_mm256_slli_epi32(vinq, 6), vinq)),
              _mm256_set1_epi32(0x03030303));
        }

        __m256 vinq0 = _mm256_cvtepi32_ps(
            _mm256_cvtepi8_epi32(_mm256_castsi256_si128(vinq)));
        __m256 vinq1 = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(
            _mm_set1_epi64x(_mm256_extract_epi64(vinq, 1))));
        __m256 vinq2 = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(
            _mm_set1_epi64x(_mm256_extract_epi64(vinq, 2))));
        __m256 vinq3 = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(
            _mm_set1_epi64x(_mm256_extract_epi64(vinq, 3))));

        vinq0 = _mm256_fmadd_ps(vscale, vinq0, vbias);
        vinq1 = _mm256_fmadd_ps(vscale, vinq1, vbias);
        vinq2 = _mm256_fmadd_ps(vscale, vinq2, vbias);
        vinq3 = _mm256_fmadd_ps(vscale, vinq3, vbias);

        if (std::is_same<OutputType, float>()) {
          _mm256_maskstore_ps(output_row_float + col, vmask_store0, vinq0);
          _mm256_maskstore_ps(
              output_row_float + col + VLEN, vmask_store1, vinq1);
          _mm256_maskstore_ps(
              output_row_float + col + 2 * VLEN, vmask_store2, vinq2);
          _mm256_maskstore_ps(
              output_row_float + col + 3 * VLEN, vmask_store3, vinq3);
        } else {
          _mm_maskstore_epi32(
              reinterpret_cast<int*>(output_row + col),
              _mm256_castsi256_si128(vmask_store0),
              _mm256_cvtps_ph(
                  vinq0, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
          _mm_maskstore_epi32(
              reinterpret_cast<int*>(output_row + col + VLEN),
              _mm256_castsi256_si128(vmask_store1),
              _mm256_cvtps_ph(
                  vinq1, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
          _mm_maskstore_epi32(
              reinterpret_cast<int*>(output_row + col + 2 * VLEN),
              _mm256_castsi256_si128(vmask_store2),
              _mm256_cvtps_ph(
                  vinq2, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
          _mm_maskstore_epi32(
              reinterpret_cast<int*>(output_row + col + 3 * VLEN),
              _mm256_castsi256_si128(vmask_store3),
              _mm256_cvtps_ph(
                  vinq3, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
        }
      }
    } else {
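      // Generic scalar path for bit rates without a vectorized kernel:
      // extract each BIT_RATE-bit field from its byte and dequantize it.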
      for (; col < output_columns; ++col) {
        std::uint8_t quantized = input_row[col / NUM_ELEM_PER_BYTE];
        quantized >>= (col % NUM_ELEM_PER_BYTE) * BIT_RATE;
        quantized &= (1 << BIT_RATE) - 1;
        float output_value = scale * quantized + bias;
        if (std::is_same<OutputType, float>()) {
          output_row[col] = output_value;
        } else {
          output_row[col] = cpu_float2half_rn(output_value);
        }
      }
    }
  } // for each row
}
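
For reference, a minimal scalar sketch of the same per-row dequantization
(assuming a hypothetical fp16-to-float helper `half_to_float`; the real code
uses halfToFloat and cpu_float2half_rn defined elsewhere in fbgemm):

#include <cstdint>

// Hypothetical helper: converts an IEEE fp16 bit pattern to float.
float half_to_float(std::uint16_t h);

template <int BIT_RATE>
void dequantize_fused_row_ref(
    const std::uint8_t* input_row, // packed values, then fp16 scale and bias
    int output_columns,
    float* output_row) {
  constexpr int NUM_ELEM_PER_BYTE = 8 / BIT_RATE;
  const std::uint16_t* scale_bias = reinterpret_cast<const std::uint16_t*>(
      input_row +
      (output_columns + NUM_ELEM_PER_BYTE - 1) / NUM_ELEM_PER_BYTE);
  const float scale = half_to_float(scale_bias[0]);
  const float bias = half_to_float(scale_bias[1]);
  for (int col = 0; col < output_columns; ++col) {
    std::uint8_t q = input_row[col / NUM_ELEM_PER_BYTE];
    q >>= (col % NUM_ELEM_PER_BYTE) * BIT_RATE;
    q &= (1 << BIT_RATE) - 1;
    output_row[col] = scale * q + bias; // same math as the AVX2 FMA path
  }
}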