void KernelFloatAvx512()

in tensorflow/tensorflow/lite/experimental/ruy/kernel_avx512.cc [358:808]


// Float GEMM kernel for AVX-512 operating on 16x16 destination blocks.
//
// For every destination element in the region
//   rows [params.start_row, end_row) x cols [params.start_col, end_col)
// it computes
//   dst(r, c) = clamp(bias[r] + sum_d lhs_packed(d, r) * rhs_packed(d, c),
//                     params.clamp_min, params.clamp_max)
// where the packed LHS/RHS hold 16 consecutive row (resp. column) values per
// depth step (the pointers advance by 16 floats per step). The destination is
// addressed as dst_ptr + row + col * dst_stride, i.e. rows are contiguous.
//
// Structure:
//   1. Full 16-column blocks: full 16-row blocks (fully unrolled), then one
//      partial row block using masked bias load and masked stores.
//   2. A final partial column block (< 16 columns), processed in halves of up
//      to 8 columns; halves with no destination columns are skipped.
void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params) {
  gemmlowp::ScopedProfilingLabel label("Kernel kAvx512");

  // As parameters are defined, we need to scale by sizeof(float).
  const std::int64_t lhs_stride = params.lhs_stride >> 2;
  const std::int64_t dst_stride = params.dst_stride >> 2;
  const std::int64_t rhs_stride = params.rhs_stride >> 2;

  // With a bias, each row block reads its own slice of params.bias. Without
  // one, every block re-reads the same 16 floats at params.bias, which
  // presumably point at zeros supplied by the caller -- TODO confirm.
  const int bias_ptr_block_increment =
      (params.flags & RUY_ASM_FLAG_HAS_BIAS) ? 1 : 0;
  // params.last_row / last_col are the starts of the last blocks, so the
  // region ends 16 past them, capped by the destination size.
  const int end_row = std::min(params.dst_rows, params.last_row + 16);
  const int end_col = std::min(params.dst_cols, params.last_col + 16);

  // Base pointers pre-offset by the block start so that adding
  // row/col * stride below lands on the right block.
  const float* adj_rhs_col_ptr =
      params.rhs_base_ptr - params.start_col * rhs_stride;
  float* adj_dst_col_ptr =
      params.dst_base_ptr - params.start_col * dst_stride - params.start_row;
  const float* adj_lhs_col_ptr =
      params.lhs_base_ptr - params.start_row * lhs_stride;
  const float* bias_col_ptr = params.bias;

  const __m512 clamp_max_v = _mm512_set1_ps(params.clamp_max);
  const __m512 clamp_min_v = _mm512_set1_ps(params.clamp_min);

  int col = params.start_col;
  for (; col <= end_col - 16; col += 16) {
    const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride;
    float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride;

    int row = params.start_row;
    for (; row <= end_row - 16; row += 16) {
      const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
      float* dst_ptr = dst_col_ptr + row;
      const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;

      // Initialize with bias.
      const __m512 initial_accum_data = _mm512_loadu_ps(bias_ptr);

      // Process block in two halves, split by columns. Each half keeps 8
      // accumulators (one per destination column) of 16 rows each.
      //
      // In both halves the last depth step is peeled out of the loop and fused
      // with the clamp/store. That step always executes, so this path
      // presumably requires params.depth >= 1 -- TODO confirm with callers.
      {
        constexpr int mmm = 0;

        __m512 accum_data_v0 = initial_accum_data;
        __m512 accum_data_v1 = initial_accum_data;
        __m512 accum_data_v2 = initial_accum_data;
        __m512 accum_data_v3 = initial_accum_data;
        __m512 accum_data_v4 = initial_accum_data;
        __m512 accum_data_v5 = initial_accum_data;
        __m512 accum_data_v6 = initial_accum_data;
        __m512 accum_data_v7 = initial_accum_data;

        const float* lhs_ptr = lhs_col_ptr;
        const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
        for (int d = 0; d < (params.depth - 1); ++d) {
          const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
          const __m256 rhs_data = _mm256_loadu_ps(rhs_ptr);
          lhs_ptr += 16;
          rhs_ptr += 16;

          {
            const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]);
            accum_data_v0 =
                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
            const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]);
            accum_data_v1 =
                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
            const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]);
            accum_data_v2 =
                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
            const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]);
            accum_data_v3 =
                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
            const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]);
            accum_data_v4 =
                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
            const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]);
            accum_data_v5 =
                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
            const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]);
            accum_data_v6 =
                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
            const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]);
            accum_data_v7 =
                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
          }
        }
        {
          // Peeled final depth step, followed by clamp and store.
          const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
          const __m256 rhs_data = _mm256_loadu_ps(rhs_ptr);
          {
            const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]);
            accum_data_v0 =
                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
            const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]);
            accum_data_v1 =
                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
            const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]);
            accum_data_v2 =
                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
            const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]);
            accum_data_v3 =
                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
            const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]);
            accum_data_v4 =
                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
            const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]);
            accum_data_v5 =
                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
            const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]);
            accum_data_v6 =
                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
            const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]);
            accum_data_v7 =
                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
          }
          {
            float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride;
            accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v);
            accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 0 * dst_stride, accum_data_v0);
            accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v);
            accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 1 * dst_stride, accum_data_v1);
            accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v);
            accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 2 * dst_stride, accum_data_v2);
            accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v);
            accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 3 * dst_stride, accum_data_v3);
            accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v);
            accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 4 * dst_stride, accum_data_v4);
            accum_data_v5 = _mm512_min_ps(accum_data_v5, clamp_max_v);
            accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 5 * dst_stride, accum_data_v5);
            accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v);
            accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 6 * dst_stride, accum_data_v6);
            accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v);
            accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 7 * dst_stride, accum_data_v7);
          }
        }
      }  // Inner half-block loop, unrolled, first iteration.
      {
        constexpr int mmm = 1;

        __m512 accum_data_v0 = initial_accum_data;
        __m512 accum_data_v1 = initial_accum_data;
        __m512 accum_data_v2 = initial_accum_data;
        __m512 accum_data_v3 = initial_accum_data;
        __m512 accum_data_v4 = initial_accum_data;
        __m512 accum_data_v5 = initial_accum_data;
        __m512 accum_data_v6 = initial_accum_data;
        __m512 accum_data_v7 = initial_accum_data;

        const float* lhs_ptr = lhs_col_ptr;
        const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
        for (int d = 0; d < (params.depth - 1); ++d) {
          const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
          const __m256 rhs_data = _mm256_loadu_ps(rhs_ptr);
          lhs_ptr += 16;
          rhs_ptr += 16;
          {
            const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]);
            accum_data_v0 =
                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
            const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]);
            accum_data_v1 =
                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
            const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]);
            accum_data_v2 =
                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
            const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]);
            accum_data_v3 =
                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
            const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]);
            accum_data_v4 =
                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
            const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]);
            accum_data_v5 =
                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
            const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]);
            accum_data_v6 =
                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
            const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]);
            accum_data_v7 =
                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
          }
        }
        {
          // Peeled final depth step, followed by clamp and store.
          const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
          const __m256 rhs_data = _mm256_loadu_ps(rhs_ptr);
          {
            const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]);
            accum_data_v0 =
                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
            const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]);
            accum_data_v1 =
                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
            const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]);
            accum_data_v2 =
                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
            const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]);
            accum_data_v3 =
                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
            const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]);
            accum_data_v4 =
                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
            const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]);
            accum_data_v5 =
                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
            const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]);
            accum_data_v6 =
                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
            const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]);
            accum_data_v7 =
                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
          }
          {
            float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride;
            accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v);
            accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 0 * dst_stride, accum_data_v0);
            accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v);
            accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 1 * dst_stride, accum_data_v1);
            accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v);
            accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 2 * dst_stride, accum_data_v2);
            accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v);
            accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 3 * dst_stride, accum_data_v3);
            accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v);
            accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 4 * dst_stride, accum_data_v4);
            accum_data_v5 = _mm512_min_ps(accum_data_v5, clamp_max_v);
            accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 5 * dst_stride, accum_data_v5);
            accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v);
            accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 6 * dst_stride, accum_data_v6);
            accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v);
            accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 7 * dst_stride, accum_data_v7);
          }
        }
      }  // Inner half-block loop, unrolled, second iteration.
    }    // End row-block loop.

    // The unrolling within this conditional may be somewhat pointless. It
    // depends on the kinds of models.
    if (row < end_row) {
      const int residual_rows = end_row - row;

      const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
      float* dst_ptr = dst_col_ptr + row;
      const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;

      // Initialize with bias. Lanes at or beyond residual_rows load as zero
      // and are never stored.
      const __mmask16 row_mask = static_cast<__mmask16>(
          (static_cast<std::uint32_t>(1) << residual_rows) - 1);
      const __m512 initial_accum_data =
          _mm512_maskz_loadu_ps(row_mask, bias_ptr);

      // Process block in two halves, split by columns. LHS is loaded with
      // full 16-lane loads even though fewer rows remain; presumably the
      // packed LHS is padded to a multiple of 16 rows -- TODO confirm.
      for (int mmm = 0; mmm < 2; ++mmm) {
        __m512 accum_data_v0 = initial_accum_data;
        __m512 accum_data_v1 = initial_accum_data;
        __m512 accum_data_v2 = initial_accum_data;
        __m512 accum_data_v3 = initial_accum_data;
        __m512 accum_data_v4 = initial_accum_data;
        __m512 accum_data_v5 = initial_accum_data;
        __m512 accum_data_v6 = initial_accum_data;
        __m512 accum_data_v7 = initial_accum_data;

        const float* lhs_ptr = lhs_col_ptr;
        const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
        for (int d = 0; d < (params.depth - 1); ++d) {
          const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
          const __m256 rhs_data = _mm256_loadu_ps(rhs_ptr);
          lhs_ptr += 16;
          rhs_ptr += 16;
          {
            const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]);
            accum_data_v0 =
                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
            const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]);
            accum_data_v1 =
                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
            const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]);
            accum_data_v2 =
                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
            const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]);
            accum_data_v3 =
                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
            const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]);
            accum_data_v4 =
                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
            const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]);
            accum_data_v5 =
                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
            const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]);
            accum_data_v6 =
                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
            const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]);
            accum_data_v7 =
                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
          }
        }
        {
          // Peeled final depth step, followed by clamp and masked store.
          const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
          const __m256 rhs_data = _mm256_loadu_ps(rhs_ptr);
          {
            const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]);
            accum_data_v0 =
                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
            const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]);
            accum_data_v1 =
                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
            const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]);
            accum_data_v2 =
                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
            const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]);
            accum_data_v3 =
                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
            const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]);
            accum_data_v4 =
                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
            const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]);
            accum_data_v5 =
                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
            const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]);
            accum_data_v6 =
                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
            const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]);
            accum_data_v7 =
                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
          }
          {
            float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride;
            accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v);
            accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v);
            _mm512_mask_storeu_ps(block_ptr + 0 * dst_stride, row_mask,
                                  accum_data_v0);
            accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v);
            accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v);
            _mm512_mask_storeu_ps(block_ptr + 1 * dst_stride, row_mask,
                                  accum_data_v1);
            accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v);
            accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v);
            _mm512_mask_storeu_ps(block_ptr + 2 * dst_stride, row_mask,
                                  accum_data_v2);
            accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v);
            accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v);
            _mm512_mask_storeu_ps(block_ptr + 3 * dst_stride, row_mask,
                                  accum_data_v3);
            accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v);
            accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v);
            _mm512_mask_storeu_ps(block_ptr + 4 * dst_stride, row_mask,
                                  accum_data_v4);
            accum_data_v5 = _mm512_min_ps(accum_data_v5, clamp_max_v);
            accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v);
            _mm512_mask_storeu_ps(block_ptr + 5 * dst_stride, row_mask,
                                  accum_data_v5);
            accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v);
            accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v);
            _mm512_mask_storeu_ps(block_ptr + 6 * dst_stride, row_mask,
                                  accum_data_v6);
            accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v);
            accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v);
            _mm512_mask_storeu_ps(block_ptr + 7 * dst_stride, row_mask,
                                  accum_data_v7);
          }
        }
      }  // Inner half-block loop.
    }    // Residual rows, main col-block loop.
  }      // End col-block loop.

  // Residual columns: the column loop above exits with end_col - col < 16.
  if (col < end_col) {
    RUY_DCHECK_GE(end_col - col, 0);
    RUY_DCHECK_LT(end_col - col, 16);

    // Number of destination columns left in this final, partial block.
    const int block_cols = end_col - col;

    const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride;
    float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride;

    for (int row = params.start_row; row < end_row; row += 16) {
      const int residual_rows = std::min(end_row - row, 16);

      const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
      float* dst_ptr = dst_col_ptr + row;
      const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;

      // Initialize with bias. Lanes at or beyond residual_rows load as zero
      // and are never stored.
      const __mmask16 row_mask = static_cast<__mmask16>(
          (static_cast<std::uint32_t>(1) << residual_rows) - 1);
      const __m512 initial_accum_data =
          _mm512_maskz_loadu_ps(row_mask, bias_ptr);
      // With all 16 rows present a plain store writes the same lanes as the
      // masked one.
      const bool all_rows_present = residual_rows == 16;

      // Process block in halves of up to 8 columns. A half that has no
      // destination columns is skipped entirely: its results would never be
      // stored, so computing them is wasted work.
      for (int mmm = 0; mmm < 2 && 8 * mmm < block_cols; ++mmm) {
        const int residual_cols = std::min(block_cols - 8 * mmm, 8);

        // All 8 accumulators are always updated (constant trip count keeps
        // them in registers); only the first residual_cols are stored.
        __m512 accum_data_v[8];
        for (int j = 0; j < 8; ++j) {
          accum_data_v[j] = initial_accum_data;
        }

        const float* lhs_ptr = lhs_col_ptr;
        const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
        for (int d = 0; d < params.depth; ++d) {
          const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
          const __m256 rhs_data = _mm256_loadu_ps(rhs_ptr);

          for (int j = 0; j < 8; ++j) {
            const __m512 dup_rhs_element_j = _mm512_set1_ps(rhs_data[j]);
            accum_data_v[j] =
                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v[j]);
          }
          lhs_ptr += 16;
          rhs_ptr += 16;
        }

        for (int j = 0; j < residual_cols; ++j) {
          float* block_ptr = dst_ptr + (mmm * 8 + j) * dst_stride;
          __m512 result = _mm512_min_ps(accum_data_v[j], clamp_max_v);
          result = _mm512_max_ps(result, clamp_min_v);
          if (all_rows_present) {
            _mm512_storeu_ps(block_ptr, result);
          } else {
            _mm512_mask_storeu_ps(block_ptr, row_mask, result);
          }
        }
      }  // Inner half-block loop.
    }    // End row-block loop.
  }      // Residual cols.
}