in tensorflow/tensorflow/lite/experimental/ruy/kernel_avx512.cc [358:808]
// Float GEMM micro-kernel for AVX-512: computes a destination block of up to
// 16 rows x 16 columns from packed LHS/RHS panels described by `params`.
// Layout (established by the pointer arithmetic below): both packed panels
// advance 16 floats per depth step; the LHS slice covers all 16 rows of the
// block, and the RHS slice covers all 16 columns, consumed in two halves of
// 8 columns each ("mmm" = 0 or 1). Results are accumulated with FMA on top of
// the bias, clamped to [params.clamp_min, params.clamp_max], and stored.
// Partial edge blocks (fewer than 16 rows and/or columns) are handled with
// AVX-512 masked loads/stores.
// NOTE(review): the depth loops run `depth - 1` iterations with one final
// iteration peeled off below — this assumes params.depth >= 1; confirm that
// callers guarantee this precondition.
void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params) {
gemmlowp::ScopedProfilingLabel label("Kernel kAvx512");
// As parameters are defined, we need to scale by sizeof(float): the strides
// in `params` are in bytes, and all pointer arithmetic here is on float*.
const std::int64_t lhs_stride = params.lhs_stride >> 2;
const std::int64_t dst_stride = params.dst_stride >> 2;
const std::int64_t rhs_stride = params.rhs_stride >> 2;
// With a bias vector present, step through it one entry per row; without
// one, keep re-reading the same (presumably zeroed) location.
int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0;
// Clamp the last block's extent to the actual destination dimensions so the
// residual-row / residual-column paths below know how much is real.
const int end_row = std::min(params.dst_rows, params.last_row + 16);
const int end_col = std::min(params.dst_cols, params.last_col + 16);
// Bias the panel base pointers by -start_row / -start_col so that indexing
// with the absolute `row` / `col` loop variables lands on the right data.
const float* adj_rhs_col_ptr =
params.rhs_base_ptr - params.start_col * rhs_stride;
float* adj_dst_col_ptr =
params.dst_base_ptr - params.start_col * dst_stride - params.start_row;
const float* adj_lhs_col_ptr =
params.lhs_base_ptr - params.start_row * lhs_stride;
const float* bias_col_ptr = params.bias;
// Broadcast the activation clamp bounds once, outside all loops.
const __m512 clamp_max_v = _mm512_set1_ps(params.clamp_max);
const __m512 clamp_min_v = _mm512_set1_ps(params.clamp_min);
int col = params.start_col;
// Main loop over full 16-column blocks; the leftover (< 16 cols) is handled
// after this loop.
for (; col <= end_col - 16; col += 16) {
const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride;
float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride;
int row = params.start_row;
// Full 16-row blocks within this column block.
for (; row <= end_row - 16; row += 16) {
const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
float* dst_ptr = dst_col_ptr + row;
const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
// Initialize with bias: one bias value per row, shared by all 16 columns.
const __m512 initial_accum_data = _mm512_loadu_ps(bias_ptr);
// Process block in two halves, split by columns. The two halves are
// manually unrolled (constexpr mmm = 0 and 1) rather than looped;
// presumably a deliberate codegen choice — each half keeps its 8
// accumulators live in registers.
{
constexpr int mmm = 0;
// Eight accumulators: one 16-row column vector per destination column
// in this half-block.
__m512 accum_data_v0 = initial_accum_data;
__m512 accum_data_v1 = initial_accum_data;
__m512 accum_data_v2 = initial_accum_data;
__m512 accum_data_v3 = initial_accum_data;
__m512 accum_data_v4 = initial_accum_data;
__m512 accum_data_v5 = initial_accum_data;
__m512 accum_data_v6 = initial_accum_data;
__m512 accum_data_v7 = initial_accum_data;
const float* lhs_ptr = lhs_col_ptr;
// This half reads RHS columns [8*mmm, 8*mmm + 8) of each depth slice.
const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
// Accumulate over depth - 1 steps; the final step is peeled off below
// so it can be fused with clamping and the stores.
for (int d = 0; d < (params.depth - 1); ++d) {
const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
const __m256 rhs_data = _mm256_loadu_ps(rhs_ptr);
// Both packed panels are laid out 16 floats per depth step.
lhs_ptr += 16;
rhs_ptr += 16;
{
// rhs_data[j] subscripts the __m256 via the GCC/Clang vector
// extension: broadcast RHS element j, then FMA against the whole
// 16-row LHS column.
const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]);
accum_data_v0 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]);
accum_data_v1 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]);
accum_data_v2 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]);
accum_data_v3 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]);
accum_data_v4 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]);
accum_data_v5 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]);
accum_data_v6 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]);
accum_data_v7 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
}
}
// Peeled final depth iteration, fused with clamp + store.
{
const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
const __m256 rhs_data = _mm256_loadu_ps(rhs_ptr);
{
const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]);
accum_data_v0 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]);
accum_data_v1 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]);
accum_data_v2 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]);
accum_data_v3 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]);
accum_data_v4 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]);
accum_data_v5 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]);
accum_data_v6 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]);
accum_data_v7 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
}
// Clamp each accumulator into [clamp_min, clamp_max] and store the
// 8 destination columns of this half-block.
{
float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride;
accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v);
accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v);
_mm512_storeu_ps(block_ptr + 0 * dst_stride, accum_data_v0);
accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v);
accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v);
_mm512_storeu_ps(block_ptr + 1 * dst_stride, accum_data_v1);
accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v);
accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v);
_mm512_storeu_ps(block_ptr + 2 * dst_stride, accum_data_v2);
accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v);
accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v);
_mm512_storeu_ps(block_ptr + 3 * dst_stride, accum_data_v3);
accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v);
accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v);
_mm512_storeu_ps(block_ptr + 4 * dst_stride, accum_data_v4);
accum_data_v5 = _mm512_min_ps(accum_data_v5, clamp_max_v);
accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v);
_mm512_storeu_ps(block_ptr + 5 * dst_stride, accum_data_v5);
accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v);
accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v);
_mm512_storeu_ps(block_ptr + 6 * dst_stride, accum_data_v6);
accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v);
accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v);
_mm512_storeu_ps(block_ptr + 7 * dst_stride, accum_data_v7);
}
}
} // Inner half-block loop, unrolled, first iteration.
// Second half-block: identical to the first except mmm = 1, i.e. RHS
// columns [8, 16) and destination columns [col + 8, col + 16).
{
constexpr int mmm = 1;
__m512 accum_data_v0 = initial_accum_data;
__m512 accum_data_v1 = initial_accum_data;
__m512 accum_data_v2 = initial_accum_data;
__m512 accum_data_v3 = initial_accum_data;
__m512 accum_data_v4 = initial_accum_data;
__m512 accum_data_v5 = initial_accum_data;
__m512 accum_data_v6 = initial_accum_data;
__m512 accum_data_v7 = initial_accum_data;
const float* lhs_ptr = lhs_col_ptr;
const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
for (int d = 0; d < (params.depth - 1); ++d) {
const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
const __m256 rhs_data = _mm256_loadu_ps(rhs_ptr);
lhs_ptr += 16;
rhs_ptr += 16;
{
const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]);
accum_data_v0 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]);
accum_data_v1 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]);
accum_data_v2 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]);
accum_data_v3 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]);
accum_data_v4 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]);
accum_data_v5 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]);
accum_data_v6 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]);
accum_data_v7 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
}
}
// Peeled final depth iteration, fused with clamp + store.
{
const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
const __m256 rhs_data = _mm256_loadu_ps(rhs_ptr);
{
const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]);
accum_data_v0 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]);
accum_data_v1 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]);
accum_data_v2 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]);
accum_data_v3 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]);
accum_data_v4 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]);
accum_data_v5 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]);
accum_data_v6 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]);
accum_data_v7 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
}
{
float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride;
accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v);
accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v);
_mm512_storeu_ps(block_ptr + 0 * dst_stride, accum_data_v0);
accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v);
accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v);
_mm512_storeu_ps(block_ptr + 1 * dst_stride, accum_data_v1);
accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v);
accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v);
_mm512_storeu_ps(block_ptr + 2 * dst_stride, accum_data_v2);
accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v);
accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v);
_mm512_storeu_ps(block_ptr + 3 * dst_stride, accum_data_v3);
accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v);
accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v);
_mm512_storeu_ps(block_ptr + 4 * dst_stride, accum_data_v4);
accum_data_v5 = _mm512_min_ps(accum_data_v5, clamp_max_v);
accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v);
_mm512_storeu_ps(block_ptr + 5 * dst_stride, accum_data_v5);
accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v);
accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v);
_mm512_storeu_ps(block_ptr + 6 * dst_stride, accum_data_v6);
accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v);
accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v);
_mm512_storeu_ps(block_ptr + 7 * dst_stride, accum_data_v7);
}
}
} // Inner half-block loop, unrolled, second iteration.
} // End row-block loop.
// The unrolling within this conditional may be somewhat pointless. It
// depends on the kinds of models.
// Residual rows (< 16) within a full 16-column block: same computation as
// above, but bias load and destination stores are masked to the live rows.
if (row < end_row) {
const int residual_rows = end_row - row;
const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
float* dst_ptr = dst_col_ptr + row;
const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
// Initialize with bias. The mask has the low `residual_rows` bits set;
// lanes beyond the live rows are zeroed by the maskz load and never
// stored.
const __mmask16 row_mask =
(static_cast<std::uint32_t>(1) << residual_rows) - 1;
const __m512 initial_accum_data =
_mm512_maskz_loadu_ps(row_mask, bias_ptr);
// Process block in two halves, split by columns. Here the two halves are
// a runtime loop rather than unrolled.
for (int mmm = 0; mmm < 2; ++mmm) {
__m512 accum_data_v0 = initial_accum_data;
__m512 accum_data_v1 = initial_accum_data;
__m512 accum_data_v2 = initial_accum_data;
__m512 accum_data_v3 = initial_accum_data;
__m512 accum_data_v4 = initial_accum_data;
__m512 accum_data_v5 = initial_accum_data;
__m512 accum_data_v6 = initial_accum_data;
__m512 accum_data_v7 = initial_accum_data;
const float* lhs_ptr = lhs_col_ptr;
const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
// Depth - 1 accumulation steps; final step peeled below.
// NOTE(review): the unmasked full-width LHS loads here rely on the
// packed LHS panel being padded to 16 rows — the dead lanes are
// computed but never stored.
for (int d = 0; d < (params.depth - 1); ++d) {
const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
const __m256 rhs_data = _mm256_loadu_ps(rhs_ptr);
lhs_ptr += 16;
rhs_ptr += 16;
{
const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]);
accum_data_v0 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]);
accum_data_v1 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]);
accum_data_v2 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]);
accum_data_v3 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]);
accum_data_v4 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]);
accum_data_v5 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]);
accum_data_v6 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]);
accum_data_v7 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
}
}
// Peeled final depth iteration, fused with clamp + masked store.
{
const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
const __m256 rhs_data = _mm256_loadu_ps(rhs_ptr);
{
const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]);
accum_data_v0 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]);
accum_data_v1 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]);
accum_data_v2 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]);
accum_data_v3 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]);
accum_data_v4 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]);
accum_data_v5 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]);
accum_data_v6 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]);
accum_data_v7 =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
}
// Clamp and store, masking each column's store to the live rows.
{
float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride;
accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v);
accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v);
_mm512_mask_storeu_ps(block_ptr + 0 * dst_stride, row_mask,
accum_data_v0);
accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v);
accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v);
_mm512_mask_storeu_ps(block_ptr + 1 * dst_stride, row_mask,
accum_data_v1);
accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v);
accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v);
_mm512_mask_storeu_ps(block_ptr + 2 * dst_stride, row_mask,
accum_data_v2);
accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v);
accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v);
_mm512_mask_storeu_ps(block_ptr + 3 * dst_stride, row_mask,
accum_data_v3);
accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v);
accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v);
_mm512_mask_storeu_ps(block_ptr + 4 * dst_stride, row_mask,
accum_data_v4);
accum_data_v5 = _mm512_min_ps(accum_data_v5, clamp_max_v);
accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v);
_mm512_mask_storeu_ps(block_ptr + 5 * dst_stride, row_mask,
accum_data_v5);
accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v);
accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v);
_mm512_mask_storeu_ps(block_ptr + 6 * dst_stride, row_mask,
accum_data_v6);
accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v);
accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v);
_mm512_mask_storeu_ps(block_ptr + 7 * dst_stride, row_mask,
accum_data_v7);
}
}
} // Inner half-block loop.
} // Residual rows, main col-block loop.
} // End col-block loop.
// Residual columns (< 16 remaining): a rolled-up version of the main loop
// that also handles residual rows, masking stores and skipping dead columns.
if (col < end_col) {
RUY_DCHECK_GE(end_col - col, 0);
RUY_DCHECK_LT(end_col - col, 16);
__m512 accum_data_v[8];
const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride;
float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride;
for (int row = params.start_row; row < end_row; row += 16) {
const int residual_rows = std::min(end_row - row, 16);
const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
float* dst_ptr = dst_col_ptr + row;
const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
// Initialize with bias, masked to the live rows. For residual_rows == 16
// the shift result 0x10000 - 1 wraps to an all-ones 16-bit mask, as
// intended (the shift is done in uint32_t, so it is well-defined).
const __mmask16 row_mask =
(static_cast<std::uint32_t>(1) << residual_rows) - 1;
const __m512 initial_accum_data =
_mm512_maskz_loadu_ps(row_mask, bias_ptr);
// Process block in two halves, split by columns.
for (int mmm = 0; mmm < 2; ++mmm) {
for (int j = 0; j < 8; ++j) {
accum_data_v[j] = initial_accum_data;
}
const float* lhs_ptr = lhs_col_ptr;
const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
// Full depth loop here (no peeled last iteration) — all 8 columns are
// accumulated even when some are dead; only live columns are stored.
for (int d = 0; d < params.depth; ++d) {
const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
const __m256 rhs_data = _mm256_loadu_ps(rhs_ptr);
for (int j = 0; j < 8; ++j) {
const __m512 dup_rhs_element_j = _mm512_set1_ps(rhs_data[j]);
accum_data_v[j] =
_mm512_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v[j]);
}
lhs_ptr += 16;
rhs_ptr += 16;
}
// Number of live columns in this half. May be <= 0 for mmm == 1 when
// fewer than 9 columns remain; the store loops below then do nothing
// (the accumulation above was wasted but harmless).
const int residual_cols = std::min(end_col - col - 8 * mmm, 8);
if (residual_rows == 16) {
if (residual_cols == 8) {
// Full 16x8 half-block: unmasked stores.
for (int j = 0; j < 8; ++j) {
float* block_ptr = dst_ptr + (mmm * 8 + j) * dst_stride;
accum_data_v[j] = _mm512_min_ps(accum_data_v[j], clamp_max_v);
accum_data_v[j] = _mm512_max_ps(accum_data_v[j], clamp_min_v);
_mm512_storeu_ps(block_ptr, accum_data_v[j]);
}
} else {
// All 16 rows live, only some columns: unmasked stores on the
// live columns only.
for (int j = 0; j < residual_cols; ++j) {
float* block_ptr = dst_ptr + (mmm * 8 + j) * dst_stride;
accum_data_v[j] = _mm512_min_ps(accum_data_v[j], clamp_max_v);
accum_data_v[j] = _mm512_max_ps(accum_data_v[j], clamp_min_v);
_mm512_storeu_ps(block_ptr, accum_data_v[j]);
}
}
} else {
// Partial rows: masked stores on the live columns only.
for (int j = 0; j < residual_cols; ++j) {
float* block_ptr = dst_ptr + (mmm * 8 + j) * dst_stride;
accum_data_v[j] = _mm512_min_ps(accum_data_v[j], clamp_max_v);
accum_data_v[j] = _mm512_max_ps(accum_data_v[j], clamp_min_v);
_mm512_mask_storeu_ps(block_ptr, row_mask, accum_data_v[j]);
}
}
} // Inner half-block loop.
} // End row-block loop.
} // Residual cols.
}