int fbgemmConv()

in src/FbgemmConv.cc [117:417]
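Top-level dispatcher for FBGEMM's quantized convolution. It first verifies that the pre-packed weights were packed with the same parameters as this invocation, then routes the call to the most specialized kernel available: depthwise, groupwise, pointwise, direct convolution, or the general im2col + GEMM fallback. A usage sketch follows the listing.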


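// out receives the final (typically requantized) result, while outBuffer
// provides scratch space for the 32-bit accumulation.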
template <typename processOutputType, int SPATIAL_DIM, typename ACC_T>
int fbgemmConv(
    const conv_param_t<SPATIAL_DIM>& conv_p,
    const std::uint8_t* activations,
    PackWeightsForConv<SPATIAL_DIM, std::int8_t, ACC_T>& packed_weights,
    typename processOutputType::outType* out,
    std::int32_t* outBuffer,
    processOutputType& outProcess,
    int thread_id,
    int num_threads,
    const BlockingFactors* blocking_params) {
  if (!packed_weights.isPackingCompliant(conv_p)) {
    std::string msg =
        "[FBGEMM_CONV_ERROR] Convolution parameter "
        "mismatch between the pre-packed weights and the conv invocation! ";
    msg += packed_weights.mismatchingParams(conv_p);
    msg += std::string(
        " Please pack the weights using the same parameters "
        "with which the convolution operation is invoked!");
    throw std::logic_error(msg);
  }

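  // ConvFastPath inspects conv_p (spatial rank, kernel shape, strides,
  // padding, and group structure) and picks the most specialized kernel;
  // im2col + GEMM is the general fallback.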
  switch (ConvFastPath<SPATIAL_DIM, ACC_T>(conv_p)) {
    case optimized_conv_t::depthwise: {
      // 2D and 3D depthwise fast path
      const std::int32_t* B_zero_point = outProcess.getBZeroPoint();
      const float* C_multiplier = outProcess.getCMultiplier();
      const float* act_times_w_scale = outProcess.getActWScale();
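      // Depending on the quantization granularity, these point to a single
      // value (TENSOR) or to per-group / per-output-channel arrays.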
      if (SPATIAL_DIM == 3) {
        static_assert(
            std::is_same<typename processOutputType::outType, std::uint8_t>::
                value,
            "For depthwise, only requantized output is supported");

        if (processOutputType::QGRANType == QuantizationGranularity::TENSOR) {
          depthwise_3d_same_pad<QuantizationGranularity::TENSOR>(
              *reinterpret_cast<const conv_param_t<3>*>(&conv_p),
              outProcess.getAZeroPoint(),
              activations,
              B_zero_point,
              *(packed_weights.getPackedWForDepthwise()),
              C_multiplier,
              outProcess.getCZeroPoint(),
              out,
              outProcess.getColOffsets(),
              outProcess.getBias(),
              outProcess.RELU_FUSED, // fuse_relu
              act_times_w_scale,
              thread_id,
              num_threads);
        } else if (
            processOutputType::QGRANType == QuantizationGranularity::GROUP) {
          depthwise_3d_same_pad<QuantizationGranularity::GROUP>(
              *reinterpret_cast<const conv_param_t<3>*>(&conv_p),
              outProcess.getAZeroPoint(),
              activations,
              B_zero_point,
              *(packed_weights.getPackedWForDepthwise()),
              C_multiplier,
              outProcess.getCZeroPoint(),
              out,
              outProcess.getColOffsets(),
              outProcess.getBias(),
              outProcess.RELU_FUSED, // fuse_relu
              act_times_w_scale, // act_scale * weight_scale
              thread_id,
              num_threads);
        } else if (
            processOutputType::QGRANType ==
            QuantizationGranularity::OUT_CHANNEL) {
          depthwise_3d_same_pad<QuantizationGranularity::OUT_CHANNEL>(
              *reinterpret_cast<const conv_param_t<3>*>(&conv_p),
              outProcess.getAZeroPoint(),
              activations,
              B_zero_point,
              *(packed_weights.getPackedWForDepthwise()),
              C_multiplier,
              outProcess.getCZeroPoint(),
              out,
              outProcess.getColOffsets(),
              outProcess.getBias(),
              outProcess.RELU_FUSED, // fuse_relu
              act_times_w_scale, // act_scale * weight_scale
              thread_id,
              num_threads);
        } else {
          std::string msg =
              "[FBGEMM_CONV_ERROR] This quantization granularity is "
              "not supported";
          throw std::runtime_error(msg);
        }
      } else if (SPATIAL_DIM == 2) {
        if (processOutputType::QGRANType == QuantizationGranularity::TENSOR) {
          depthwise_2d_same_pad<QuantizationGranularity::TENSOR>(
              conv_p.MB, // mini batch
              conv_p.IN_DIM[0], // H
              conv_p.IN_DIM[1], // W
              conv_p.IC, // input channels
              conv_p.OC, // output channels
              conv_p.stride[0], // stride_h
              conv_p.stride[1], // stride_w
              outProcess.getAZeroPoint(),
              activations,
              B_zero_point,
              *(packed_weights.getPackedWForDepthwise()),
              C_multiplier,
              outProcess.getCZeroPoint(),
              out,
              outProcess.getColOffsets(),
              outProcess.getBias(),
              outProcess.RELU_FUSED, // fuse_relu
              act_times_w_scale,
              thread_id,
              num_threads);
        } else if (
            processOutputType::QGRANType == QuantizationGranularity::GROUP) {
          depthwise_2d_same_pad<QuantizationGranularity::GROUP>(
              conv_p.MB, // mini batch
              conv_p.IN_DIM[0], // H
              conv_p.IN_DIM[1], // W
              conv_p.IC, // input channels
              conv_p.OC, // output channels
              conv_p.stride[0], // stride_h
              conv_p.stride[1], // stride_w
              outProcess.getAZeroPoint(),
              activations,
              B_zero_point,
              *(packed_weights.getPackedWForDepthwise()),
              C_multiplier,
              outProcess.getCZeroPoint(),
              out,
              outProcess.getColOffsets(),
              outProcess.getBias(),
              outProcess.RELU_FUSED, // fuse_relu
              act_times_w_scale, // act_scale * weight_scale
              thread_id,
              num_threads);
        } else if (
            processOutputType::QGRANType ==
            QuantizationGranularity::OUT_CHANNEL) {
          // The number of input channels == groups for depthwise convolutions
          depthwise_2d_same_pad<QuantizationGranularity::OUT_CHANNEL>(
              conv_p.MB, // mini batch
              conv_p.IN_DIM[0], // H
              conv_p.IN_DIM[1], // W
              conv_p.IC, // input channels
              conv_p.OC, // output channels
              conv_p.stride[0], // stride_h
              conv_p.stride[1], // stride_w
              outProcess.getAZeroPoint(),
              activations,
              B_zero_point,
              *(packed_weights.getPackedWForDepthwise()),
              C_multiplier,
              outProcess.getCZeroPoint(),
              out,
              outProcess.getColOffsets(),
              outProcess.getBias(),
              outProcess.RELU_FUSED, // fuse_relu
              act_times_w_scale, // act_scale * weight_scale
              thread_id,
              num_threads);
        } else {
          std::string msg =
              "[FBGEMM_CONV_ERROR] This quantization granularity is "
              "not supported";
          throw std::runtime_error(msg);
        }
      } else {
        std::string msg =
            "[FBGEMM_CONV_ERROR] This spatial dim is not supported";
        throw std::runtime_error(msg);
      }
      break;
    }
    case optimized_conv_t::groupwise: {
      // optimized groupwise convolution
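      // The row-offset buffer holds per-row sums of the activations; the
      // requantization step needs them to correct for nonzero weight zero
      // points.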
      std::vector<int32_t> row_offset_buf(
          rowOffsetBufferSizeGConv<SPATIAL_DIM>(conv_p));
      outProcess.setRowOffsets(row_offset_buf.data());
      fbgemmGroupwiseConv(
          conv_p,
          activations,
          outProcess.getAZeroPoint(),
          row_offset_buf.data(),
          *(packed_weights.getPackedWForGroupwise()),
          out,
          outBuffer,
          outProcess,
          thread_id,
          num_threads);
      break;
    }
    case optimized_conv_t::pointwise: {
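      // A 1x1 convolution is a plain GEMM: the activations are packed
      // directly as an (MB * H * W) x IC matrix, with no im2col step.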
      std::vector<int32_t> row_offset_buf(
          PackAWithRowOffset<uint8_t>::rowOffsetBufferSize(blocking_params));
      int image_dim = std::accumulate(
          conv_p.IN_DIM.begin(),
          conv_p.IN_DIM.end(),
          1,
          std::multiplies<int>());
      PackAWithRowOffset<uint8_t, ACC_T> packA(
          matrix_op_t::NoTranspose,
          conv_p.MB * image_dim,
          conv_p.IC,
          activations,
          conv_p.IC,
          nullptr,
          conv_p.G,
          row_offset_buf.data(),
          blocking_params);

      outProcess.setRowOffsets(row_offset_buf.data());
      fbgemmPacked(
          packA,
          *(packed_weights.getPackedWForPointwise()),
          out,
          outBuffer,
          conv_p.OC,
          outProcess,
          thread_id,
          num_threads,
          blocking_params);
      break;
    }
    case optimized_conv_t::directconv: {
      // specialized direct convolution path
      fbgemmDirectConv<SPATIAL_DIM, processOutputType::QGRANType>(
          conv_p,
          activations,
          *(packed_weights.getPackedWForDirectconv()),
          out,
          outBuffer,
          outProcess,
          outProcess.getBias(),
          thread_id,
          num_threads);
      break;
    }
    case optimized_conv_t::fastpath1d: {
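      // Nothing is done for the 1D fast path here.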
      break;
    }
    case optimized_conv_t::im2col: {
      // All other convolutions go through im2col-based implementation
      std::vector<int32_t> row_offset_buf(
          PackAWithIm2Col<uint8_t, ACC_T, SPATIAL_DIM>::rowOffsetBufferSize(
              blocking_params));

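      // If every weight zero point is zero (symmetric weights), the
      // activation row offsets are not needed and packing can skip them.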
      const std::int32_t* b_zero_point = outProcess.getBZeroPoint();
      bool b_symmetric = false;
      if (processOutputType::QGRANType == QuantizationGranularity::TENSOR) {
        b_symmetric = b_zero_point[0] == 0;
      } else if (
          processOutputType::QGRANType == QuantizationGranularity::GROUP) {
        b_symmetric =
            std::all_of(b_zero_point, b_zero_point + conv_p.G, [](int i) {
              return i == 0;
            });
      } else if (
          processOutputType::QGRANType ==
          QuantizationGranularity::OUT_CHANNEL) {
        b_symmetric =
            std::all_of(b_zero_point, b_zero_point + conv_p.OC, [](int i) {
              return i == 0;
            });
      } else {
        std::string msg =
            "[FBGEMM_CONV_ERROR] This quantization granularity is "
            "not supported";
        throw std::runtime_error(msg);
      }
      PackAWithIm2Col<uint8_t, ACC_T, SPATIAL_DIM> packA(
          conv_p,
          activations,
          nullptr, /* buffer for packed matrix */
          outProcess.getAZeroPoint(),
          row_offset_buf.data(),
          b_symmetric,
          blocking_params);

      outProcess.setRowOffsets(row_offset_buf.data());
      fbgemmPacked(
          packA,
          *(packed_weights.getPackedWForIm2col()),
          out,
          outBuffer,
          conv_p.OC,
          outProcess,
          thread_id,
          num_threads,
          blocking_params);
      break;
    }
  } // switch

  return 0;
}
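
For reference, a minimal sketch of driving this entry point, modeled on the
patterns in FBGEMM's own tests. The shapes, quantization parameters, and
buffer contents below are placeholders, and col_offsets would normally be
precomputed from the weights:

#include "fbgemm/Fbgemm.h"

#include <cstdint>
#include <vector>

using namespace fbgemm;

void runConvExample() {
  // 2D conv: N=1, IC=8, OC=16, 28x28 input, 1 group, 3x3 kernel,
  // stride 1, padding 1 on every side (output stays 28x28).
  conv_param_t<2> conv_p(
      1, 8, 16, {28, 28}, 1, {3, 3}, {1, 1}, {1, 1, 1, 1});

  std::vector<std::uint8_t> Aint8(1 * 28 * 28 * conv_p.IC, 1);
  std::vector<std::int8_t> Bint8(3 * 3 * conv_p.IC * conv_p.OC, 1);
  std::vector<std::uint8_t> Cint8(1 * 28 * 28 * conv_p.OC);
  std::vector<std::int32_t> Cint32(Cint8.size()); // 32-bit scratch

  // Pack the weights with the same conv_p used at run time; fbgemmConv
  // throws std::logic_error on a mismatch (see the check above).
  PackWeightsForConv<2> packedB(conv_p, Bint8.data());

  // Placeholder per-tensor quantization parameters.
  float C_multiplier = 0.01f;
  std::int32_t A_zero_point = 0, B_zero_point = 0, C_zero_point = 0;
  std::vector<std::int32_t> col_offsets(conv_p.OC, 0);

  DoNothing<> doNothingObj{};
  ReQuantizeOutput<false /* no fused ReLU */> outputProc(
      doNothingObj,
      &C_multiplier,
      C_zero_point,
      A_zero_point,
      &B_zero_point,
      nullptr, // row offsets are set internally by fbgemmConv
      col_offsets.data(),
      nullptr, // no bias
      conv_p.OC,
      conv_p.G);

  fbgemmConv(
      conv_p,
      Aint8.data(),
      packedB,
      Cint8.data(),
      Cint32.data(),
      outputProc,
      0 /* thread_id */,
      1 /* num_threads */);
}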