in src/QuantUtilsAvx2.cc [1893:2126]
void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfAvx2(
const std::uint8_t* input,
size_t input_rows,
int input_columns,
OutputType* output) {
static_assert(
std::is_same<OutputType, float>() || std::is_same<OutputType, float16>(),
"Only float and float16 types are allowed.");
constexpr int VLEN = 8;
constexpr int NUM_ELEM_PER_BYTE = 8 / BIT_RATE;
int output_columns =
(input_columns - 2 * sizeof(uint16_t)) * NUM_ELEM_PER_BYTE;
// Compute a remainder for the masked vector load.
// Since every row is followed by 2 fp16 values (scale and bias), we
// conveniently don't need a mask at bit-rate granularity, only at 32-bit
// granularity.
constexpr int NUM_ELEM_PER_32BIT = 32 / BIT_RATE;
// Multiply by 4 because each iteration of the main loop handles 4 * VLEN elements.
constexpr int NUM_OF_32BIT_PER_VLOAD = VLEN * 4 / NUM_ELEM_PER_32BIT;
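// For example, with BIT_RATE == 4 these work out to NUM_ELEM_PER_BYTE == 2,
// NUM_ELEM_PER_32BIT == 8, and NUM_OF_32BIT_PER_VLOAD == 4; with
// BIT_RATE == 2 they are 4, 16, and 2, respectively.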
int remainder_32bit_granularity, remainder;
__m128i vmask_load;
__m256i vmask_store0, vmask_store1, vmask_store2, vmask_store3;
if (BIT_RATE == 4 || BIT_RATE == 2) {
remainder_32bit_granularity = (output_columns + NUM_ELEM_PER_32BIT - 1) /
NUM_ELEM_PER_32BIT % NUM_OF_32BIT_PER_VLOAD;
vmask_load = _mm_lddqu_si128(reinterpret_cast<const __m128i*>(
internal::avx2_ps_or_epi32_combined_mask + NUM_OF_32BIT_PER_VLOAD +
(NUM_OF_32BIT_PER_VLOAD - remainder_32bit_granularity) %
NUM_OF_32BIT_PER_VLOAD));
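// avx2_ps_or_epi32_combined_mask appears to be laid out as a run of
// all-ones 32-bit words followed by zero words, so starting the 128-bit
// mask load deeper into the table enables exactly the leading
// remainder_32bit_granularity lanes needed by the tail load below.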
remainder = output_columns % (4 * VLEN);
int remainder_ratio = 1;
if (std::is_same<OutputType, float16>()) {
// For fp16 we only need half of the mask.
//
// For instance, if the remainder is 2, for FP32 the masks are
// {-1, -1, 0, ..., 0}, {0, ..., 0}, {0, ..., 0}, {0, ..., 0}
// (8 32-bit integers for each mask)
// for FP16 we only need
// {-1, 0, 0, 0}, {0, ..., 0}, {0, ..., 0}, {0, ..., 0}
// (4 32-bit integers for each mask)
// since we reinterpret 2 FP16 numbers as one 32-bit number.
// NOTE: for bit_rate 4 or 2, remainders are always a multiple of 2 or 4,
// so we don't have to worry about an odd number of FP16 numbers.
//
// Or, if the remainder is 30, for FP32 the masks are
// {-1, ..., -1}, {-1, ..., -1}, {-1, ..., -1}, {-1, .., -1, 0, 0}
// for FP16 we only need
// {-1, ..., -1}, {-1, ..., -1}, {-1, ..., -1}, {-1, -1, -1, 0}
remainder_ratio = 2;
}
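// Each vmask_storeN below guards one VLEN-wide chunk of the partial tail
// store. E.g. with remainder == 10 and float output, vmask_store0 enables
// all 8 lanes, vmask_store1 enables 2 lanes, and vmask_store2/3 none.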
vmask_store0 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
internal::avx2_ps_or_epi32_combined_mask +
(VLEN - std::min(remainder, VLEN) / remainder_ratio % (VLEN + 1))));
vmask_store1 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
internal::avx2_ps_or_epi32_combined_mask +
(VLEN -
std::max(0, std::min(remainder - VLEN, VLEN) / remainder_ratio) %
(VLEN + 1))));
vmask_store2 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
internal::avx2_ps_or_epi32_combined_mask +
(VLEN -
std::max(0, std::min(remainder - 2 * VLEN, VLEN) / remainder_ratio) %
(VLEN + 1))));
vmask_store3 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
internal::avx2_ps_or_epi32_combined_mask +
(VLEN -
std::max(0, std::min(remainder - 3 * VLEN, VLEN) / remainder_ratio) %
(VLEN + 1))));
}
for (size_t row = 0; row < input_rows; ++row) {
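// Each input row holds ceil(output_columns / NUM_ELEM_PER_BYTE) packed
// bytes followed by an fp16 scale and an fp16 bias.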
const std::uint8_t* input_row = input + row * input_columns;
const uint16_t* input_row_scale_bias = reinterpret_cast<const uint16_t*>(
input_row +
(output_columns + NUM_ELEM_PER_BYTE - 1) / NUM_ELEM_PER_BYTE);
float scale = halfToFloat(input_row_scale_bias[0]);
float bias = halfToFloat(input_row_scale_bias[1]);
OutputType* output_row = output + row * output_columns;
float* output_row_float = nullptr;
if (std::is_same<OutputType, float>()) {
// NOTE: this reinterpret_cast is only to work around C++ type
// requirements -- the branch is never taken in the fp16 case, and in the
// float case `output_row` already has float* type, so the cast is a
// no-op. Remove it and use `if constexpr` once PyTorch allows C++17.
output_row_float = reinterpret_cast<float*>(output_row);
}
int col = 0;
if (BIT_RATE == 4 || BIT_RATE == 2) {
__m256 vscale = _mm256_set1_ps(scale);
__m256 vbias = _mm256_set1_ps(bias);
for (; col + 4 * VLEN <= output_columns; col += 4 * VLEN) {
__m256i vinq;
// unpack to 8-bit integers
if (BIT_RATE == 4) {
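// Two 4-bit values per byte: zero-extend 16 bytes to 16-bit lanes, OR in
// a copy shifted left by 4, and mask with 0x0f0f so each 8-bit lane ends
// up holding one 4-bit value in the original element order.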
vinq = _mm256_cvtepu8_epi16(
_mm_loadu_si128(reinterpret_cast<const __m128i*>(
input_row + col / NUM_ELEM_PER_BYTE)));
vinq = _mm256_and_si256(
_mm256_or_si256(vinq, _mm256_slli_epi32(vinq, 4)),
_mm256_set1_epi16(0x0f0f));
} else {
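// Four 2-bit values per byte: zero-extend 8 bytes to 32-bit lanes, spread
// the four 2-bit fields to the four byte positions of each lane (shifts
// by 18, 12, and 6), and mask with 0x03030303.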
vinq = _mm256_cvtepu8_epi32(
_mm_loadl_epi64(reinterpret_cast<const __m128i*>(
input_row + col / NUM_ELEM_PER_BYTE)));
vinq = _mm256_and_si256(
_mm256_or_si256(
_mm256_or_si256(
_mm256_slli_epi32(vinq, 2 * 8 + 2),
_mm256_slli_epi32(vinq, 8 + 4)),
_mm256_or_si256(_mm256_slli_epi32(vinq, 6), vinq)),
_mm256_set1_epi32(0x03030303));
}
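// Sign-extend each group of 8 bytes to 32-bit integers (the values fit in
// 0..15, so sign extension is safe), convert to float, and dequantize
// with a fused multiply-add: out = scale * q + bias.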
__m256 vinq0 = _mm256_cvtepi32_ps(
_mm256_cvtepi8_epi32(_mm256_castsi256_si128(vinq)));
__m256 vinq1 = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(
_mm_set1_epi64x(_mm256_extract_epi64(vinq, 1))));
__m256 vinq2 = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(
_mm_set1_epi64x(_mm256_extract_epi64(vinq, 2))));
__m256 vinq3 = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(
_mm_set1_epi64x(_mm256_extract_epi64(vinq, 3))));
vinq0 = _mm256_fmadd_ps(vscale, vinq0, vbias);
vinq1 = _mm256_fmadd_ps(vscale, vinq1, vbias);
vinq2 = _mm256_fmadd_ps(vscale, vinq2, vbias);
vinq3 = _mm256_fmadd_ps(vscale, vinq3, vbias);
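// Store the 4 x 8 dequantized values: directly as floats, or converted to
// fp16 with round-to-nearest (each group of 8 halves packs into one
// 128-bit store).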
if (std::is_same<OutputType, float>()) {
_mm256_storeu_ps(output_row_float + col, vinq0);
_mm256_storeu_ps(output_row_float + col + VLEN, vinq1);
_mm256_storeu_ps(output_row_float + col + 2 * VLEN, vinq2);
_mm256_storeu_ps(output_row_float + col + 3 * VLEN, vinq3);
} else {
_mm_storeu_si128(
reinterpret_cast<__m128i*>(output_row + col),
_mm256_cvtps_ph(
vinq0, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
_mm_storeu_si128(
reinterpret_cast<__m128i*>(output_row + col + VLEN),
_mm256_cvtps_ph(
vinq1, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
_mm_storeu_si128(
reinterpret_cast<__m128i*>(output_row + col + 2 * VLEN),
_mm256_cvtps_ph(
vinq2, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
_mm_storeu_si128(
reinterpret_cast<__m128i*>(output_row + col + 3 * VLEN),
_mm256_cvtps_ph(
vinq3, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
}
}
if (remainder) {
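// Tail: same unpack/dequantize sequence as the main loop, but with a
// masked load of the remaining packed 32-bit words and masked stores of
// the remaining output columns.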
__m256i vinq;
if (BIT_RATE == 4) {
vinq = _mm256_cvtepu8_epi16(_mm_maskload_epi32(
reinterpret_cast<const int*>(input_row + col / NUM_ELEM_PER_BYTE),
vmask_load));
vinq = _mm256_and_si256(
_mm256_or_si256(vinq, _mm256_slli_epi32(vinq, 4)),
_mm256_set1_epi16(0x0f0f));
} else {
vinq = _mm256_cvtepu8_epi32(_mm_maskload_epi32(
reinterpret_cast<const int*>(input_row + col / NUM_ELEM_PER_BYTE),
vmask_load));
vinq = _mm256_and_si256(
_mm256_or_si256(
_mm256_or_si256(
_mm256_slli_epi32(vinq, 2 * 8 + 2),
_mm256_slli_epi32(vinq, 8 + 4)),
_mm256_or_si256(_mm256_slli_epi32(vinq, 6), vinq)),
_mm256_set1_epi32(0x03030303));
}
__m256 vinq0 = _mm256_cvtepi32_ps(
_mm256_cvtepi8_epi32(_mm256_castsi256_si128(vinq)));
__m256 vinq1 = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(
_mm_set1_epi64x(_mm256_extract_epi64(vinq, 1))));
__m256 vinq2 = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(
_mm_set1_epi64x(_mm256_extract_epi64(vinq, 2))));
__m256 vinq3 = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(
_mm_set1_epi64x(_mm256_extract_epi64(vinq, 3))));
vinq0 = _mm256_fmadd_ps(vscale, vinq0, vbias);
vinq1 = _mm256_fmadd_ps(vscale, vinq1, vbias);
vinq2 = _mm256_fmadd_ps(vscale, vinq2, vbias);
vinq3 = _mm256_fmadd_ps(vscale, vinq3, vbias);
if (std::is_same<OutputType, float>()) {
_mm256_maskstore_ps(output_row_float + col, vmask_store0, vinq0);
_mm256_maskstore_ps(
output_row_float + col + VLEN, vmask_store1, vinq1);
_mm256_maskstore_ps(
output_row_float + col + 2 * VLEN, vmask_store2, vinq2);
_mm256_maskstore_ps(
output_row_float + col + 3 * VLEN, vmask_store3, vinq3);
} else {
_mm_maskstore_epi32(
reinterpret_cast<int*>(output_row + col),
_mm256_castsi256_si128(vmask_store0),
_mm256_cvtps_ph(
vinq0, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
_mm_maskstore_epi32(
reinterpret_cast<int*>(output_row + col + VLEN),
_mm256_castsi256_si128(vmask_store1),
_mm256_cvtps_ph(
vinq1, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
_mm_maskstore_epi32(
reinterpret_cast<int*>(output_row + col + 2 * VLEN),
_mm256_castsi256_si128(vmask_store2),
_mm256_cvtps_ph(
vinq2, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
_mm_maskstore_epi32(
reinterpret_cast<int*>(output_row + col + 3 * VLEN),
_mm256_castsi256_si128(vmask_store3),
_mm256_cvtps_ph(
vinq3, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
}
}
} else {
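// Scalar fallback for bit rates without a vectorized path: extract each
// BIT_RATE-bit field from its byte and dequantize one element at a time.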
for (; col < output_columns; ++col) {
std::uint8_t quantized = input_row[col / NUM_ELEM_PER_BYTE];
quantized >>= (col % NUM_ELEM_PER_BYTE) * BIT_RATE;
quantized &= (1 << BIT_RATE) - 1;
float output_value = scale * quantized + bias;
if (std::is_same<OutputType, float>()) {
output_row[col] = output_value;
} else {
output_row[col] = cpu_float2half_rn(output_value);
}
}
}
} // for each row
}