bool BatchMatMulDNNLowPOp::RunOnDevice()

in caffe2/quantization/server/batch_matmul_dnnlowp_op.cc [45:744]


bool BatchMatMulDNNLowPOp<T>::RunOnDevice() {
  // Quantized batched matrix multiply: Y = op(A) * op(B), where op()
  // transposes the last two dimensions when trans_a_ / trans_b_ is set.
  // Leading (batch) dimensions are broadcast when broadcast_ is set; a 1-D A
  // or B is treated as a matrix with an extra leading / trailing 1 that is
  // not added to the output shape.
  //
  // Two execution paths are used:
  //  * fast path (T == uint8_t, AVX2 available, is_B_constant_): B is
  //    quantized to int8 and packed for fbgemm while Bq_packed_ is empty, and
  //    the packed copy is reused by later invocations; every sub-batch GEMM
  //    then runs through fbgemmPacked.
  //  * slow path: reference int32 GEMM with explicit zero-point corrections,
  //    followed by per-element requantization.
  //
  // Returns true on success; shape errors are reported via CAFFE_ENFORCE.
  this->ParseDNNLowPOperatorArguments_();

  const auto& A = InputTensorCPU_(0);
  const auto& B = InputTensorCPU_(1);
  auto* Y = OutputTensorCPU_(0);

  auto ndims_A = A.ndim();
  auto dims_A = A.sizes().vec();
  auto ndims_B = B.ndim();
  auto dims_B = B.sizes().vec();

  auto noBroadcastErrorMsg = [](size_t dim1, size_t dim2) {
    std::stringstream ss;
    ss << "Inputs with dimensions A = ";
    ss << dim1;
    ss << " and B = ";
    ss << dim2;
    ss << " is not supported with broadcast=0. Did you forget to set the "
          "broadcast flag?";
    return ss.str();
  };

  // These should all be false if we're not broadcasting.
  bool dimMismatch = ndims_A != ndims_B;
  bool dimsLessThan1D = ndims_A < 2;
  CAFFE_ENFORCE(
      broadcast_ || (!dimMismatch && !dimsLessThan1D),
      noBroadcastErrorMsg(ndims_A, ndims_B));

  auto dimMismatchErrorString = [](size_t dimnum1,
                                   size_t dim1,
                                   size_t dimnum2,
                                   size_t dim2,
                                   bool trans_a,
                                   bool trans_b) {
    std::stringstream ss;
    ss << "Expected dimension ";
    ss << dimnum1;
    ss << " of tensor A with value ";
    ss << dim1;
    ss << " to match dimension ";
    ss << dimnum2;
    ss << " of tensor B with value ";
    ss << dim2;
    ss << ". trans_a = ";
    ss << trans_a;
    ss << " trans_b = ";
    ss << trans_b;
    return ss.str();
  };

  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int num_sub_batches, num_outer_batches;
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  size_t M, N, K;
  size_t A_stride = 1; // How far to increment A pointer each itr
  size_t B_stride = 1; // How far to increment B pointer each itr
  size_t Y_stride = 1; // How far to increment Y pointer each itr
  if (ndims_A == 1 && ndims_B == 1) {
    // vector-vector
    CAFFE_ENFORCE_EQ(
        dims_A[0],
        dims_B[0],
        "Vector-vector product requires each of the vectors to "
        "be the same size.");
    Y->Resize(1);
    num_sub_batches = 1;
    num_outer_batches = 1;
    M = 1;
    N = 1;
    K = dims_A[0];
  } else {
    bool A_broadcasted = false, B_broadcasted = false;
    if (ndims_A == 1) {
      dims_A.insert(dims_A.begin(), 1);
      ndims_A = 2;
      A_broadcasted = true;
    }
    if (ndims_B == 1) {
      dims_B.push_back(1);
      ndims_B = 2;
      B_broadcasted = true;
    }
    // matrix-matrix with batches
    // [B1..., M, K] * [B2..., K, N] -> [B..., M, N]
    // In the event that A or B are one-dimensional, the trailing or leading
    // 1 is not added to the output tensor's size.

    // First step: partition the tensors into inner and outer blocks.
    // Ignoring the last two dimensions of A and B, ensure that one of the
    // tensors' dimensions is a suffix of the other. For example,
    // [4, x, x] is a suffix of [2, 3, 4, x, x]. In this example, the
    // dimensions of size 2 and 3 will be broadcasted, so we partition into
    // 2*3=6 individual instances of batched GEMM with A and B \in [4, x, x].
    size_t num_inner_dims = std::min(ndims_A, ndims_B);
    for (size_t i = 2; i < num_inner_dims; ++i) {
      auto first_r_itr = dims_A.rbegin();
      auto second_r_itr = dims_B.rbegin();
      CAFFE_ENFORCE_EQ(
          *(first_r_itr + i),
          *(second_r_itr + i),
          dimMismatchErrorString(
              ndims_A - i - 1,
              *(first_r_itr + i),
              ndims_B - i - 1,
              *(second_r_itr + i),
              trans_a_,
              trans_b_));
    }
    size_t num_outer_dims = std::max(ndims_A, ndims_B) - num_inner_dims;

    // Standard M, N, and K parameters respecting GEMM API and transpose
    // flags
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    size_t K_dim;
    if (trans_a_) {
      M = dims_A[ndims_A - 1];
      K = dims_A[ndims_A - 2];
      K_dim = ndims_A - 2;
    } else {
      M = dims_A[ndims_A - 2];
      K = dims_A[ndims_A - 1];
      K_dim = ndims_A - 1;
    }
    if (trans_b_) {
      N = dims_B[ndims_B - 2];
      CAFFE_ENFORCE_EQ(
          K,
          dims_B[ndims_B - 1],
          dimMismatchErrorString(
              K_dim, K, ndims_B - 1, dims_B[ndims_B - 1], trans_a_, trans_b_));
    } else {
      N = dims_B[ndims_B - 1];
      CAFFE_ENFORCE_EQ(
          K,
          dims_B[ndims_B - 2],
          dimMismatchErrorString(
              K_dim, K, ndims_B - 2, dims_B[ndims_B - 2], trans_a_, trans_b_));
    }

    // Calculate output tensor shapes [B..., (M), (N)]
    // Batch dimensions will be broadcasted out to those of the longer tensor
    // A or B. Either M or N are optional if A or B, respectively are 1-D.
    std::vector<int64_t> new_dims;
    if (ndims_A >= ndims_B) {
      new_dims.assign(dims_A.begin(), dims_A.end() - 2);
    } else {
      new_dims.assign(dims_B.begin(), dims_B.end() - 2);
    }
    if (!A_broadcasted) {
      new_dims.push_back(M);
    } else {
      new_dims.push_back(1);
    }
    if (!B_broadcasted) {
      new_dims.push_back(N);
    } else {
      new_dims.push_back(1);
    }

    // Calculate strides. Continuing our example above,
    //   [4, M, K] * [2, 3, 4, K, N] = [2, 3, 4, M, N]
    // We calculate this as follows:
    //   1) Treat the outer batch dimensions as flattened, i.e. view the B
    //      tensor here as [6, 4, K, N] and Y as [6, 4, M, N]. The same
    //      reasoning is analogous for the case where # dims A >= # dims B.
    //   2) Perform this operation:
    //        for i in range(6):
    //          Y[i, :, :, :] = BatchMatMul(A, B[i, :, :, :])
    A_stride = 1; // How far to increment A pointer each itr
    B_stride = 1; // How far to increment B pointer each itr
    Y_stride = 1; // How far to increment Y pointer each itr
    // How many "inner batches" we have. That is, the product of sizes for
    // the slices excluding M, K, and N, for their respective matrices.
    num_sub_batches = 1;
    if (ndims_A >= ndims_B) {
      auto first_r_itr = dims_A.rbegin();
      auto output_r_itr = new_dims.rbegin();
      for (size_t i = 0; i < num_inner_dims; ++i) {
        A_stride *= *(first_r_itr + i);
        Y_stride *= *(output_r_itr + i);
        if (i >= 2) {
          num_sub_batches *= *(first_r_itr + i);
        }
      }
      // B has only the inner dims, so the same B is reused for every outer
      // batch.
      B_stride = 0;
    } else {
      // A has only the inner dims, so the same A is reused for every outer
      // batch.
      A_stride = 0;
      auto second_r_itr = dims_B.rbegin();
      auto output_r_itr = new_dims.rbegin();
      for (size_t i = 0; i < num_inner_dims; ++i) {
        B_stride *= *(second_r_itr + i);
        Y_stride *= *(output_r_itr + i);
        if (i >= 2) {
          num_sub_batches *= *(second_r_itr + i);
        }
      }
    }

    num_outer_batches = 1;
    for (size_t i = 0; i < num_outer_dims; ++i) {
      num_outer_batches *= new_dims[i];
    }

    // Mutually exclusive since otherwise we would've taken the vector-vector
    // path above
    if (A_broadcasted) {
      new_dims.erase(new_dims.end() - 2);
    } else if (B_broadcasted) {
      new_dims.erase(new_dims.end() - 1);
    }

    // Allocate output tensor
    Y->Resize(new_dims);

    // Optimize case num_sub_batches == 1 where we can combine batched gemms
    // into a single gemm: A's outer batches are contiguous [*, M, K] slabs
    // multiplied by the same B, so they can be stacked into one taller A.
    if (num_sub_batches == 1 && num_outer_batches > 1) {
      if (ndims_A > ndims_B && !trans_a_) {
        M *= num_outer_batches;
        num_outer_batches = 1;
      }
    }
  }

  // An output with no elements (a zero-sized batch dimension, M == 0 or
  // N == 0) needs no computation: just materialize the output with the right
  // dtype. Checking Y->numel() covers the zero batch counts and also avoids
  // issuing empty GEMMs and dividing by K * N below.
  if (Y->numel() == 0) {
    if (dequantize_output_) {
      Y->template mutable_data<float>();
    } else {
      Y->template mutable_data<T>();
    }
    return true;
  }
  // K * N is used as a divisor below; reject degenerate shapes with a clear
  // error instead of an integer division by zero.
  CAFFE_ENFORCE(
      K > 0 && N > 0,
      "BatchMatMul DNNLOWP requires K > 0 and N > 0 for a non-empty output, "
      "got K = ",
      K,
      " N = ",
      N);

  // Choose quantization parameters for A (input 0)
  in_qparams_[0] = GetInputTensorQuantizationParamsOf(this, 0, qfactory_.get());
  // Number of [K, N] matrices stored in B.
  int num_batches_B = B.numel() / (K * N);
  // A change in B's batch count after packing means B is not actually
  // constant: drop the packed copy and fall back to the slow path.
  if (!first_invocation_ && !Bq_packed_.empty() &&
      static_cast<size_t>(num_batches_B) * N != column_offsets_.size()) {
    LOG(INFO) << "Operator with output " << this->debug_def().output(0)
              << " does not have constant B";
    is_B_constant_ = false;
    Bq_packed_.clear();
  }
  bool fast_path =
      std::is_same<T, uint8_t>::value && GetCpuId().avx2() && is_B_constant_;

  if (fast_path) {
    // Quantize B
    if (Bq_packed_.empty()) {
      int signed_min = -(1 << (qfactory_->GetWeightPrecision() - 1));
      vector<int8_t> B_quantized_temp(K * N);
      column_offsets_.resize(num_batches_B * N);
      // Entries are appended below and then addressed as B_qparams_[i], which
      // is only correct when starting from an empty vector.
      B_qparams_.clear();
      for (int i = 0; i < num_batches_B; ++i) {
        if (this->template InputIsType<int8::Int8TensorCPU>(1)) {
          // NOLINTNEXTLINE(modernize-use-emplace)
          B_qparams_.push_back(TensorQuantizationParams());
          B_qparams_[i].scale =
              this->template Input<int8::Int8TensorCPU>(1).scale;
          B_qparams_[i].zero_point =
              this->template Input<int8::Int8TensorCPU>(1).zero_point +
              signed_min;

          // Shift the already-quantized T values into the signed int8 range
          // expected by the packed B matrix.
          const T* B_data = B.template data<T>() + i * B_quantized_temp.size();
          for (size_t j = 0; j < B_quantized_temp.size(); ++j) {
            B_quantized_temp[j] = B_data[j] + signed_min;
          }
        } else {
          B_qparams_.emplace_back(qfactory_->ChooseQuantizationParams(
              B.template data<float>() + i * B_quantized_temp.size(),
              B_quantized_temp.size(),
              true /* weight */));

          // B_qparams_[i] is computed for unsigned type.
          // Adjust for the fact that B will actually use signed.
          B_qparams_[i].zero_point += signed_min;

          fbgemm::Quantize<int8_t>(
              B.template data<float>() + i * B_quantized_temp.size(),
              B_quantized_temp.data(),
              B_quantized_temp.size(),
              B_qparams_[i]);
        }

        Bq_packed_.emplace_back(new fbgemm::PackBMatrix<int8_t>(
            trans_b_ ? fbgemm::matrix_op_t::Transpose
                     : fbgemm::matrix_op_t::NoTranspose,
            K,
            N,
            B_quantized_temp.data(),
            trans_b_ ? K : N,
            nullptr /*pmat*/,
            1)); /*groups*/

        // Pre-compute column_offset: per-column sum of quantized B minus the
        // zero-point contribution, consumed by the fbgemm output processors.
        // NOLINTNEXTLINE(clang-diagnostic-sign-compare)
        for (int j = 0; j < N; ++j) {
          int32_t sum = 0;
          if (trans_b_) {
            // NOLINTNEXTLINE(clang-diagnostic-sign-compare)
            for (int k = 0; k < K; ++k) {
              sum += B_quantized_temp[j * K + k];
            }
          } else {
            // NOLINTNEXTLINE(clang-diagnostic-sign-compare)
            for (int k = 0; k < K; ++k) {
              sum += B_quantized_temp[k * N + j];
            }
          }
          column_offsets_[i * N + j] = sum - B_qparams_[i].zero_point * K;
        }
      } // for each input in the batch
    } // Bq_packed_.empty()

    if (!dequantize_output_) {
      GetOutputQuantizationParams_();

      // Recompute on every invocation: in_qparams_[0] and out_qparams_ are
      // chosen per call. Appending here (instead of overwriting) would grow
      // the vector forever while the kernels kept reading the stale entries
      // computed on the first call.
      requantization_params_.resize(num_batches_B);
      for (int i = 0; i < num_batches_B; ++i) {
        float real_multiplier =
            in_qparams_[0].scale * B_qparams_[i].scale / out_qparams_.scale;
        requantization_params_[i] = qfactory_->ChooseRequantizationMultiplier(
            real_multiplier, out_qparams_);
      }
    } else {
      if (measure_quantization_error_) {
        // to measure quantization error, run ref impl.
        Fp32Op_()->DequantizeInput();
        Fp32Op_()->Get()->RunOnDevice();
      }
    }
  } else {
    // slow path
    if (first_invocation_) {
      string reason;
      if (!is_same<T, uint8_t>::value) {
        reason = "fbgemm only supports 8-bit integers";
      } else if (!GetCpuId().avx2()) {
        reason = "fbgemm only supports AVX2";
      } else if (!is_B_constant_) {
        reason = "B is not constant";
      } else {
        assert(false);
      }
      LOG(WARNING) << "BatchMatMul with output " << this->debug_def().output(0)
                   << " falls back to slow path because " << reason;
    }
    B_qparams_.resize(1);
    requantization_params_.resize(1);

    B_qparams_[0] =
        GetInputTensorQuantizationParamsOf(this, 1, qfactory_.get());

    GetOutputQuantizationParams_();

    float real_multiplier =
        in_qparams_[0].scale * B_qparams_[0].scale / out_qparams_.scale;
    requantization_params_[0] = qfactory_->ChooseRequantizationMultiplier(
        real_multiplier, out_qparams_);
  }

  first_invocation_ = false;

  vector<T> A_temp, B_temp;
  if (!Bq_packed_.empty()) {
    // fast path
    using namespace fbgemm;

    const T* A_quantized = nullptr;
    if (A.template IsType<T>() || !dequantize_output_) {
      // Only when input and output are float, we don't need input to be
      // quantized.
      A_quantized = QuantizeInputIfNeeded<T>(this, 0, in_qparams_[0], A_temp);
    }

#ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
    chrono::time_point<chrono::system_clock> t_begin, t_end;
    t_begin = chrono::system_clock::now();
#endif

    if (!dequantize_output_) {
      auto Y_data = Y->template mutable_data<T>();

      // Per-thread scratch buffers, indexed by dnnlowp_get_thread_num().
      auto row_offset_len_per_thread =
          PackAWithRowOffset<uint8_t>::rowOffsetBufferSize();
      row_offsets_.resize(
          row_offset_len_per_thread * dnnlowp_get_max_threads());
      auto A_pack_buf_len_per_thread =
          PackAWithRowOffset<uint8_t>::packedBufferSize();
      A_pack_buf_.resize(A_pack_buf_len_per_thread * dnnlowp_get_max_threads());
      Y_int32_.resize(Y->numel());

#ifdef _OPENMP
#ifdef _MSC_VER
#pragma omp parallel for
#else
#pragma omp parallel for collapse(2)
#endif
#endif
      for (int p = 0; p < num_outer_batches; ++p) {
        for (int i = 0; i < num_sub_batches; ++i) {
          int tid = dnnlowp_get_thread_num();

          PackAWithRowOffset<uint8_t> packA(
              trans_a_ ? matrix_op_t::Transpose : matrix_op_t::NoTranspose,
              M,
              K,
              reinterpret_cast<const uint8_t*>(A_quantized) + p * A_stride +
                  i * M * K,
              trans_a_ ? M : K,
              A_pack_buf_.data() +
                  tid * A_pack_buf_len_per_thread, // buffer for packed matrix
              1, // group
              row_offsets_.data() + tid * row_offset_len_per_thread);

          // When A has at least as many dims as B, B holds only the inner
          // (sub) batches and is shared by every outer batch; otherwise each
          // (outer, sub) pair has its own B slice.
          int B_batch_idx = ndims_A >= ndims_B ? i : p * num_sub_batches + i;
          DoNothing<> doNothingObj{};
          ReQuantizeOutput<false /* FUSE_RELU */> outputProcObj(
              doNothingObj,
              &requantization_params_[B_batch_idx].real_multiplier,
              out_qparams_.zero_point,
              in_qparams_[0].zero_point,
              &B_qparams_[B_batch_idx].zero_point,
              packA.getRowOffsetBuffer(),
              column_offsets_.data() + B_batch_idx * N,
              nullptr, // bias
              N); // ncols per quant group

          fbgemmPacked(
              packA,
              *Bq_packed_[B_batch_idx],
              reinterpret_cast<uint8_t*>(Y_data) + p * Y_stride + i * M * N,
              Y_int32_.data() + p * Y_stride + i * M * N,
              N,
              outputProcObj,
              0, // thread_id
              1); // num_threads
        } // for each input in batch
      }

      PropagateOutputTensorQuantizationParams(this, 0, out_qparams_);
    } else {
      // dequantize_output
      float* Y_data = Y->template mutable_data<float>();

      if (!A.template IsType<T>()) {
        // Both input and output are float: A is quantized on the fly while
        // being packed.
        int row_offset_len_per_thread =
            PackAWithQuantRowOffset<uint8_t>::rowOffsetBufferSize();
        row_offsets_.resize(
            row_offset_len_per_thread * dnnlowp_get_max_threads());
        int A_pack_len_per_thread =
            PackAWithQuantRowOffset<uint8_t>::packedBufferSize();
        A_pack_buf_.resize(A_pack_len_per_thread * dnnlowp_get_max_threads());

#ifdef _OPENMP
#ifdef _MSC_VER
#pragma omp parallel for
#else
#pragma omp parallel for collapse(2)
#endif
#endif
        for (int p = 0; p < num_outer_batches; ++p) {
          for (int i = 0; i < num_sub_batches; ++i) {
            int tid = dnnlowp_get_thread_num();

            PackAWithQuantRowOffset<uint8_t> packA(
                trans_a_ ? matrix_op_t::Transpose : matrix_op_t::NoTranspose,
                M,
                K,
                A.template data<float>() + p * A_stride + i * M * K,
                trans_a_ ? M : K,
                A_pack_buf_.data() +
                    tid * A_pack_len_per_thread, // buffer for packed matrix
                in_qparams_[0].scale,
                in_qparams_[0].zero_point,
                1, // groups
                row_offsets_.data() + tid * row_offset_len_per_thread);

            int B_batch_idx = ndims_A >= ndims_B ? i : p * num_sub_batches + i;
            DoNothing<float, float> doNothingObj{};
            ReQuantizeForFloat<false /* FUSE_RELU*/> outputProcObj(
                doNothingObj,
                in_qparams_[0].scale,
                &B_qparams_[B_batch_idx].scale,
                in_qparams_[0].zero_point,
                &B_qparams_[B_batch_idx].zero_point,
                packA.getRowOffsetBuffer(),
                column_offsets_.data() + B_batch_idx * N,
                nullptr, // bias
                N); // ncols per quant group

            // The float output buffer doubles as the int32 accumulation
            // buffer.
            fbgemmPacked(
                packA,
                *Bq_packed_[B_batch_idx],
                Y_data + p * Y_stride + i * M * N,
                reinterpret_cast<int32_t*>(Y_data) + p * Y_stride + i * M * N,
                N,
                outputProcObj,
                0, // thread_id
                1); // num_threads
          } // for each input in batch
        }
      } else {
        // Input quantized and output float
        auto row_offset_len_per_thread =
            PackAWithRowOffset<uint8_t>::rowOffsetBufferSize();
        row_offsets_.resize(
            row_offset_len_per_thread * dnnlowp_get_max_threads());
        auto A_pack_buf_len_per_thread =
            PackAWithRowOffset<uint8_t>::packedBufferSize();
        A_pack_buf_.resize(
            A_pack_buf_len_per_thread * dnnlowp_get_max_threads());

#ifdef _OPENMP
#ifdef _MSC_VER
#pragma omp parallel for
#else
#pragma omp parallel for collapse(2)
#endif
#endif
        for (int p = 0; p < num_outer_batches; ++p) {
          for (int i = 0; i < num_sub_batches; ++i) {
            int tid = dnnlowp_get_thread_num();

            PackAWithRowOffset<uint8_t> packA(
                trans_a_ ? matrix_op_t::Transpose : matrix_op_t::NoTranspose,
                M,
                K,
                reinterpret_cast<const uint8_t*>(A_quantized) + p * A_stride +
                    i * M * K,
                trans_a_ ? M : K,
                A_pack_buf_.data() +
                    tid * A_pack_buf_len_per_thread, // buffer for packed matrix
                1, // group
                row_offsets_.data() + tid * row_offset_len_per_thread);

            int B_batch_idx = ndims_A >= ndims_B ? i : p * num_sub_batches + i;
            DoNothing<float, float> doNothingObj{};
            ReQuantizeForFloat<false /* FUSE_RELU*/> outputProcObj(
                doNothingObj,
                in_qparams_[0].scale,
                &B_qparams_[B_batch_idx].scale,
                in_qparams_[0].zero_point,
                &B_qparams_[B_batch_idx].zero_point,
                packA.getRowOffsetBuffer(),
                column_offsets_.data() + B_batch_idx * N,
                nullptr, // bias
                N); // ncols per quant group

            fbgemmPacked(
                packA,
                *Bq_packed_[B_batch_idx],
                Y_data + p * Y_stride + i * M * N,
                reinterpret_cast<int32_t*>(Y_data) + p * Y_stride + i * M * N,
                N,
                outputProcObj,
                0, // thread_id
                1); // num_threads
          } // for each input in batch
        }
      }
    } // dequantize_output

#ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
    t_end = chrono::system_clock::now();
    double dt = chrono::duration<double>(t_end - t_begin).count();
    double gops =
        2. * num_outer_batches * num_sub_batches * M * N * K / dt / 1e9;
    LOG(INFO) << "batches " << num_outer_batches * num_sub_batches << " m " << M
              << " n " << N << " k " << K << " " << gops << " gops";
#endif

    MeasureQuantizationError_();
  } else {
    // slow path
    // Quantize inputs
    const T* A_quantized =
        QuantizeInputIfNeeded<T>(this, 0, in_qparams_[0], A_temp);
    const T* B_quantized =
        QuantizeInputIfNeeded<T>(this, 1, B_qparams_[0], B_temp);

    T* Y_quantized = GetQuantizedOutputData_();
    Y_int32_.resize(Y->numel());
#ifdef _OPENMP
#ifdef _MSC_VER
#pragma omp parallel for
#else
#pragma omp parallel for collapse(2)
#endif
#endif
    for (int p = 0; p < num_outer_batches; ++p) {
      for (int i = 0; i < num_sub_batches; ++i) {
        // Y_q = (scale_A * scale_B) / scale_Y * Y_int32
        // Y_int32 = (A_q - zero_point_A * 1_A) * (B_q - zero_point_B * 1_B),
        //           where 1_A is a matrix with all 1s and same size as A
        // Y_int32 = A_q * B_q
        //           - zero_point_A * 1_A * B - zero_point_B * A * 1_B
        //           + zero_point_A * zero_point_B * 1_A * 1_B
        // zero_point_A * 1_A * B : a matrix with (i, j) is the sum of jth
        //                          column of B. This is computed by
        //                          column_offsets in the code.
        // zero_point_B * A * 1_B : a matrix with (i, j) is the sum of ith row
        //                          of A. This is computed by row_offset in the
        //                          code.
        // zero_point_A * zero_point_B * 1_A * 1_B : a matrix with all elements
        //                          are zero_point_A * zero_point_B *
        //                          num_of_cols_of_A. This is computed by
        //                          const_offset in the code.
        const T* A_quantized_i = A_quantized + p * A_stride + i * M * K;
        const T* B_quantized_i = B_quantized + p * B_stride + i * K * N;

        int32_t const_offset =
            in_qparams_[0].zero_point * B_qparams_[0].zero_point * K;
        vector<int32_t> column_offsets(N);
        // NOLINTNEXTLINE(clang-diagnostic-sign-compare)
        for (int n = 0; n < N; ++n) {
          int32_t sum = 0;
          if (trans_b_) {
            // NOLINTNEXTLINE(clang-diagnostic-sign-compare)
            for (int k = 0; k < K; ++k) {
              sum += B_quantized_i[k + n * K];
            }
          } else {
            // NOLINTNEXTLINE(clang-diagnostic-sign-compare)
            for (int k = 0; k < K; ++k) {
              sum += B_quantized_i[k * N + n];
            }
          }
          column_offsets[n] = sum * in_qparams_[0].zero_point;
        }

        // NOLINTNEXTLINE(clang-diagnostic-sign-compare)
        for (int m = 0; m < M; ++m) {
          int32_t row_offset = 0;
          if (trans_a_) {
            // NOLINTNEXTLINE(clang-diagnostic-sign-compare)
            for (int k = 0; k < K; ++k) {
              row_offset += A_quantized_i[m + k * M];
            }
          } else {
            // NOLINTNEXTLINE(clang-diagnostic-sign-compare)
            for (int k = 0; k < K; ++k) {
              row_offset += A_quantized_i[m * K + k];
            }
          }
          row_offset *= B_qparams_[0].zero_point;

          // NOLINTNEXTLINE(clang-diagnostic-sign-compare)
          for (int n = 0; n < N; ++n) {
            int32_t sum = 0;
            if (!trans_a_ && !trans_b_) {
              // NOLINTNEXTLINE(clang-diagnostic-sign-compare)
              for (int k = 0; k < K; ++k) {
                sum += static_cast<int32_t>(A_quantized_i[m * K + k]) *
                    static_cast<int32_t>(B_quantized_i[k * N + n]);
              }
            } else if (!trans_a_ && trans_b_) {
              // NOLINTNEXTLINE(clang-diagnostic-sign-compare)
              for (int k = 0; k < K; ++k) {
                sum += static_cast<int32_t>(A_quantized_i[m * K + k]) *
                    static_cast<int32_t>(B_quantized_i[k + n * K]);
              }
            } else if (trans_a_ && !trans_b_) {
              // NOLINTNEXTLINE(clang-diagnostic-sign-compare)
              for (int k = 0; k < K; ++k) {
                sum += static_cast<int32_t>(A_quantized_i[m + k * M]) *
                    static_cast<int32_t>(B_quantized_i[k * N + n]);
              }
            } else if (trans_a_ && trans_b_) {
              // NOLINTNEXTLINE(clang-diagnostic-sign-compare)
              for (int k = 0; k < K; ++k) {
                sum += static_cast<int32_t>(A_quantized_i[m + k * M]) *
                    static_cast<int32_t>(B_quantized_i[k + n * K]);
              }
            }

            Y_int32_[p * Y_stride + i * M * N + m * N + n] =
                sum - row_offset - column_offsets[n] + const_offset;
          } // for each output col
        } // for each output row

        // Requantization
        // NOLINTNEXTLINE(clang-diagnostic-sign-compare)
        for (int j = 0; j < M * N; ++j) {
          Y_quantized[p * Y_stride + i * M * N + j] = fbgemm::Requantize<T>(
              Y_int32_[p * Y_stride + i * M * N + j],
              requantization_params_[0]);
        }
      } // for each batch
    }

    RunOnDeviceEpilogue_();
  }

  return true;
}