in src/QuantUtilsAvx2.cc [1562:1754]
// Row-wise quantization of a (input_rows x input_columns) matrix of float or
// float16 values into the "fused N-bit row-wise" layout: each output row holds
// the packed BIT_RATE-bit quantized elements, immediately followed by two
// 16-bit words storing the per-row scale and bias (row minimum) as fp16.
//
// Template parameters (declared outside this chunk -- TODO confirm against the
// declaration): InputType is float or float16 (enforced by the static_assert
// below); BIT_RATE presumably divides 8 (2, 4, or 8), since
// NUM_ELEM_PER_BYTE = 8 / BIT_RATE elements are packed into each output byte
// and only BIT_RATE == 2 and 4 take the vectorized packing path.
//
// Dequantization contract: x ~= quantized * scale + bias, where scale and bias
// are the fp16 values stored at the end of the row.
void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfAvx2(
    const InputType* input,
    size_t input_rows,
    int input_columns,
    std::uint8_t* output) {
  static_assert(
      std::is_same<InputType, float>() || std::is_same<InputType, float16>(),
      "Only float and float16 types are allowed.");
  constexpr int VLEN = 8; // floats per 256-bit AVX2 vector
  constexpr int NUM_ELEM_PER_BYTE = 8 / BIT_RATE;
  // Bytes per output row: ceil(input_columns / NUM_ELEM_PER_BYTE) packed data
  // bytes plus two 16-bit words (fp16 scale, fp16 bias).
  int output_columns =
      (input_columns + NUM_ELEM_PER_BYTE - 1) / NUM_ELEM_PER_BYTE +
      2 * sizeof(std::uint16_t);
  // Scratch row used only on the fp16 path: each row is converted to float
  // once (during the min/max pass) and re-read from here when quantizing.
  // Intentionally left unallocated/unused when InputType == float.
  float* input_row_float_for_fp16;
  if (std::is_same<InputType, float16>()) {
    // 64-byte alignment so the aligned _mm256_store_ps below is valid.
    input_row_float_for_fp16 = static_cast<float*>(
        fbgemmAlignedAlloc(64, input_columns * sizeof(float)));
  }
  for (size_t row = 0; row < input_rows; ++row) {
    const InputType* input_row = input + row * input_columns;
    const float* input_row_float;
    if (std::is_same<InputType, float>()) {
      // NOTE: this reinterpret_cast is only to workaround c++
      // type requirements -- it is not for fp16 case and `input_row` HAS to be
      // float* type. Remove it and use constexpr when pytorch allows C++17.
      input_row_float = reinterpret_cast<const float*>(input_row);
    } else {
      input_row_float = input_row_float_for_fp16;
    }
    std::uint8_t* output_row = output + row * output_columns;
    // Scale/bias live right after the packed payload of this row:
    // output_row_scale_bias[0] = fp16 scale, [1] = fp16 bias (row minimum).
    std::uint16_t* output_row_scale_bias = reinterpret_cast<std::uint16_t*>(
        output_row +
        (input_columns + NUM_ELEM_PER_BYTE - 1) / NUM_ELEM_PER_BYTE);

    // --- Pass 1: find the row's min and max (vectorized, then scalar tail).
    float minimum_element = FLT_MAX;
    float maximum_element = -FLT_MAX;
    __m256 min_v = _mm256_set1_ps(minimum_element);
    __m256 max_v = _mm256_set1_ps(maximum_element);
    int col;
    for (col = 0; col < input_columns / VLEN * VLEN; col += VLEN) {
      __m256 in_v;
      if (std::is_same<InputType, float>()) {
        in_v = _mm256_loadu_ps(input_row_float + col);
      } else {
        // fp16 -> fp32 conversion; stash the converted values so pass 2 can
        // read plain floats from the scratch buffer.
        __m128i in_half_v =
            _mm_loadu_si128(reinterpret_cast<const __m128i*>(input_row + col));
        in_v = _mm256_cvtph_ps(in_half_v);
        _mm256_store_ps(input_row_float_for_fp16 + col, in_v);
      }
      min_v = _mm256_min_ps(min_v, in_v);
      max_v = _mm256_max_ps(max_v, in_v);
    }
    // Horizontal reduction of the 8 vector lanes into scalar min/max.
    alignas(64) float min_buf[VLEN], max_buf[VLEN];
    _mm256_store_ps(min_buf, min_v);
    _mm256_store_ps(max_buf, max_v);
    for (int i = 0; i < VLEN; ++i) {
      minimum_element = std::min(minimum_element, min_buf[i]);
      maximum_element = std::max(maximum_element, max_buf[i]);
    }
    // Scalar tail of the min/max pass (also fills the fp16 scratch buffer).
    for (; col < input_columns; ++col) {
      if (std::is_same<InputType, float>()) {
        minimum_element = std::min(minimum_element, input_row_float[col]);
        maximum_element = std::max(maximum_element, input_row_float[col]);
      } else {
        float element = halfToFloat(input_row[col]);
        input_row_float_for_fp16[col] = element;
        minimum_element = std::min(minimum_element, element);
        maximum_element = std::max(maximum_element, element);
      }
    }

    // Store the bias and immediately round-trip it through fp16 so that
    // quantization below uses exactly the value a dequantizer will read.
    output_row_scale_bias[1] = floatToHalf(minimum_element);
    minimum_element = halfToFloat(output_row_scale_bias[1]);
    const float range = maximum_element - minimum_element;

    // Map [min, max] onto the (2^BIT_RATE - 1) quantization steps; same
    // fp16 round-trip as the bias so quant/dequant agree on the scale.
    float scale = range == 0 ? 1.0f : range / ((1 << BIT_RATE) - 1);
    std::uint16_t scale_fp16 = floatToHalf(scale);
    scale = halfToFloat(scale_fp16);
    if (scale == 0) {
      // Corner case handling when maximum_element == minimum_element
      // Any scale would work because maximum_element - minimum_element will be
      // 0 for all X
      scale = 1.0f;
    }
    float inverse_scale = 1.0f / scale;
    if (std::isinf(inverse_scale)) {
      // Guard against a denormal/tiny scale whose reciprocal overflows.
      scale = 1.0f;
      inverse_scale = 1.0f;
    }
    output_row_scale_bias[0] = floatToHalf(scale);

    // --- Pass 2: quantize and pack. Vectorized path handles 32 elements
    // (4 vectors) per iteration for BIT_RATE 2 and 4; everything else falls
    // through to the scalar loop below.
    col = 0;
    if (BIT_RATE == 2 || BIT_RATE == 4) {
      // Undoes the lane interleaving introduced by the in-lane pack
      // instructions so bytes come out in source order.
      __m256i permute_mask1_v =
          _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00);
      __m256 inverse_scale_v = _mm256_set1_ps(inverse_scale);
      min_v = _mm256_set1_ps(minimum_element);
      for (; col + 4 * VLEN <= input_columns; col += 4 * VLEN) {
        // q = round((x - min) * inverse_scale) for four consecutive vectors.
        __m256i x_rounded_v = _mm256_cvtps_epi32(_mm256_mul_ps(
            _mm256_sub_ps(_mm256_loadu_ps(input_row_float + col), min_v),
            inverse_scale_v));
        __m256i y_rounded_v = _mm256_cvtps_epi32(_mm256_mul_ps(
            _mm256_sub_ps(_mm256_loadu_ps(input_row_float + col + VLEN), min_v),
            inverse_scale_v));
        __m256i z_rounded_v = _mm256_cvtps_epi32(_mm256_mul_ps(
            _mm256_sub_ps(
                _mm256_loadu_ps(input_row_float + col + 2 * VLEN), min_v),
            inverse_scale_v));
        __m256i w_rounded_v = _mm256_cvtps_epi32(_mm256_mul_ps(
            _mm256_sub_ps(
                _mm256_loadu_ps(input_row_float + col + 3 * VLEN), min_v),
            inverse_scale_v));
        // An instruction sequence to save 32 32-bit integers as 8-bit integers
        __m256i xy_packed_v = _mm256_packs_epi32(x_rounded_v, y_rounded_v);
        __m256i zw_packed_v = _mm256_packs_epi32(z_rounded_v, w_rounded_v);
        __m256i xyzw_packed_v = _mm256_packus_epi16(xy_packed_v, zw_packed_v);
        xyzw_packed_v =
            _mm256_permutevar8x32_epi32(xyzw_packed_v, permute_mask1_v);
        // saturate to BIT_RATE
        xyzw_packed_v = _mm256_min_epu8(
            xyzw_packed_v,
            _mm256_set1_epi8(static_cast<char>((1 << BIT_RATE) - 1)));
        if (BIT_RATE == 4) {
          // pack into lower 8-bit of each 16-bit
          xyzw_packed_v = _mm256_and_si256(
              _mm256_or_si256(
                  xyzw_packed_v, _mm256_srli_epi16(xyzw_packed_v, 4)),
              _mm256_set1_epi16(0x00ff));
        } else {
          // pack into lower 8-bit of each 32-bit
          xyzw_packed_v = _mm256_and_si256(
              _mm256_or_si256(
                  _mm256_or_si256(
                      xyzw_packed_v, _mm256_srli_epi32(xyzw_packed_v, 6)),
                  _mm256_or_si256(
                      _mm256_srli_epi32(xyzw_packed_v, 8 + 4),
                      _mm256_srli_epi32(xyzw_packed_v, 2 * 8 + 2))),
              _mm256_set1_epi32(0x00ff));
        }
        __m128i out_v;
        if (BIT_RATE == 4) {
          // avx2 doesn't have _mm256_cvtepi16_epi8
          out_v = _mm_packus_epi16(
              _mm256_castsi256_si128(xyzw_packed_v),
              _mm256_extractf128_si256(xyzw_packed_v, 1));
          // 32 4-bit values -> 16 output bytes.
          _mm_storeu_si128(
              reinterpret_cast<__m128i*>(output_row + col / NUM_ELEM_PER_BYTE),
              out_v);
        } else {
          // avx2 doesn't have _mm256_cvtepi32_epi8
          out_v = _mm_packus_epi32(
              _mm256_castsi256_si128(xyzw_packed_v),
              _mm256_extractf128_si256(xyzw_packed_v, 1));
          out_v = _mm_packus_epi16(out_v, out_v);
          // 32 2-bit values -> 8 output bytes (low half of out_v).
          _mm_storel_epi64(
              reinterpret_cast<__m128i*>(output_row + col / NUM_ELEM_PER_BYTE),
              out_v);
        }
      }
    }
    // Scalar tail: quantize, clamp to [0, 2^BIT_RATE - 1], and OR each value
    // into its sub-byte slot (first element of a byte overwrites, the rest
    // OR into the previously written byte).
    for (; col < input_columns; ++col) {
      float X = input_row_float[col];
      std::uint8_t quantized = std::max(
          0,
          std::min<int>(
              std::lrintf((X - minimum_element) * inverse_scale),
              (1 << BIT_RATE) - 1));
      if (col % NUM_ELEM_PER_BYTE == 0) {
        output_row[col / NUM_ELEM_PER_BYTE] = quantized;
      } else {
        output_row[col / NUM_ELEM_PER_BYTE] |=
            (quantized << ((col % NUM_ELEM_PER_BYTE) * BIT_RATE));
      }
    }
  } // for each row
  if (std::is_same<InputType, float16>()) {
    fbgemmAlignedFree(input_row_float_for_fp16);
  }
}