// fbgemmDirectConv()
// Excerpted from src/PackWeightsForDirectConv.cc, lines [226:457].

// Direct convolution for quantized tensors (uint8 activations x int8 packed
// weights -> uint8 output). Currently only the 2D *transposed* convolution
// path is implemented: it accumulates into an int32 scratch buffer via a
// JIT-generated AVX2 micro-kernel, then requantizes into the uint8 output,
// selecting one of eight compile-time requantization kernels based on
// activation/weight symmetry and bias presence.
//
// Template parameters (declared outside this chunk): SPATIAL_DIM (only 2 is
// supported here), FUSE_RELU, Q_GRAN, and BIAS_TYPE, which are forwarded to
// the requantization kernel.
//
// Parameters:
//   conv_p     - convolution shape parameters (IN_DIM/OUT_DIM/K/stride/IC/OC,
//                and the `transposed` flag that selects the implemented path).
//   Aint8      - activation tensor (uint8), read row-by-row below.
//   Bint8_tr   - packed weight matrix; PackedMat() is indexed per group of 8
//                output channels.
//   C          - uint8 output buffer, OUT_DIM[0]*OUT_DIM[1]*OC elements.
//   C_buffer   - int32 accumulation scratch of the same logical extent.
//   outProcess - requantization parameters (zero points, multipliers,
//                column offsets, bias, activation-times-weight scale).
//   bias       - when nullptr, the HAS_BIAS=false kernel variants are used.
//   thread_id / num_threads - only thread 0 does any work (see guard below).
void fbgemmDirectConv(
    const conv_param_t<SPATIAL_DIM>& conv_p,
    const uint8_t* Aint8,
    PackedDirectConvMatrix& Bint8_tr,
    uint8_t* C,
    int32_t* C_buffer,
    const ReQuantizeOutput<FUSE_RELU, Q_GRAN, BIAS_TYPE>& outProcess,
    const BIAS_TYPE* bias,
    // const int32_t* bias,
    int thread_id,
    int num_threads) {
  // support for single thread now,
  // will enable multithread later
  // Only thread 0 proceeds; the second clause additionally bails out when
  // num_threads is non-positive (thread_id == 0 >= num_threads).
  if (thread_id > 0 || thread_id >= num_threads) {
    return;
  }

  if (SPATIAL_DIM != 2) {
    // NOTE(review): under NDEBUG this assert compiles out and the function
    // silently returns without doing any work for 1d/3d inputs.
    assert(false && "1d/3d direct conv not supported");
  } else {
    if (conv_p.transposed) {
      // Get-or-create the JIT-compiled AVX2 micro-kernel specialized for this
      // stride and kernel width (cached by the code-gen object per the
      // "getOrCreate" contract).
      DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>::
          jit_micro_kernel_fp_convT fn;
      DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj;
      /*
         fn = codeObj.getOrCreateDirectConvTrans<inst_set_t::avx2>(
         true, conv_p.stride[1]);
         */
      fn = codeObj.getOrCreateDirectConvTrans<inst_set_t::avx2>(
          true, conv_p.stride[1], conv_p.K[1]);

      // Scratch for zero-point offset correction: inSum (per input pixel) is
      // produced by directConvRowSum and consumed there to build rowSum (per
      // output pixel); only rowSum is used below, as the rowOffsetBuf of the
      // requantization parameters.
      int32_t* inSum = static_cast<int32_t*>(fbgemmAlignedAlloc(
          64, conv_p.IN_DIM[0] * conv_p.IN_DIM[1] * sizeof(int32_t)));
      int32_t* rowSum = static_cast<int32_t*>(fbgemmAlignedAlloc(
          64, conv_p.OUT_DIM[0] * conv_p.OUT_DIM[1] * sizeof(int32_t)));

      directConvRowSum(conv_p, Aint8, inSum, rowSum);
      int kernel_dim = conv_p.K[0] * conv_p.K[1];

      // Zero the int32 accumulator and the uint8 output before accumulation.
      // (sizeof(int8_t) is used for C, which is byte-identical to uint8_t.)
      std::memset(
          C_buffer,
          0,
          sizeof(int32_t) * conv_p.OUT_DIM[0] * conv_p.OUT_DIM[1] * conv_p.OC);
      std::memset(
          C,
          0,
          sizeof(int8_t) * conv_p.OUT_DIM[0] * conv_p.OUT_DIM[1] * conv_p.OC);
      // no-op output process objects
      // Main accumulation loop: for each group of 8 output channels and each
      // input row, run the JIT micro-kernel, which accumulates that row's
      // contribution into C_buffer. The 6th argument is a stride expressed in
      // bytes (the * 4 is presumably sizeof(int32_t) — TODO confirm against
      // the kernel's codegen).
      for (int i = 0; i < conv_p.OC; i += 8) {
        for (int j = 0; j < conv_p.IN_DIM[0]; j++) {
          fn(Aint8 + j * conv_p.IC * conv_p.IN_DIM[1],
             Bint8_tr.PackedMat() + i * kernel_dim * conv_p.IC,
             C_buffer + j * conv_p.OUT_DIM[1] * conv_p.OC + i,
             conv_p.IC,
             conv_p.OC,
             (conv_p.OC * conv_p.OUT_DIM[1] - conv_p.OC * conv_p.K[1]) * 4,
             conv_p.IN_DIM[1]);
        }
      }

      int32_t A_zero_point = outProcess.getAZeroPoint();
      const int32_t* B_zero_point = outProcess.getBZeroPoint();
      // const float* C_multiplier = outProcess.getCMultiplier();
      const int32_t* col_offsets = outProcess.getColOffsets();

      /*
      int groups = 1;
      if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
        groups = conv_p.OC;
      }
      */
      // Requantization parameters shared by all dispatch branches below.
      // rowSum serves as the row-offset buffer; groups is fixed at 1.
      requantizationParams_t<BIAS_TYPE> reqObj = {
          outProcess.getAZeroPoint(),
          outProcess.getBZeroPoint(),
          outProcess.getCZeroPoint(),
          outProcess.getCMultiplier(),
          rowSum, // rowOffsetBuf,
          outProcess.getColOffsets(),
          (outProcess.getBias()),
          static_cast<std::uint32_t>(conv_p.OC), // outProcess.getNCols(),
          1, // groups
          outProcess.getActWScale()};

      // Eight-way compile-time dispatch over
      // requantizeOutputProcessingAvx2<A_SYM, B_SYM, Q_GRAN, HAS_BIAS,
      // FUSE_RELU, BIAS_TYPE, DIRECT=true>:
      //   * A_SYM  (1st bool) — true when the activation zero point is 0 or
      //     no column offsets were provided;
      //   * B_SYM  (2nd bool) — true only for per-tensor granularity with a
      //     zero weight zero point;
      //   * HAS_BIAS (4th bool) — true when `bias` is non-null.
      // Every branch processes the entire output as one block:
      // rows [0, OUT_DIM[0]*OUT_DIM[1]) x cols [0, OC).
      // Dispatch HAS_BIAS
      if (bias == nullptr) {
        // Dispatch A_SYMMETRIC and B_SYMMETRIC
        if (A_zero_point == 0 || col_offsets == nullptr) {
          if (Q_GRAN == QuantizationGranularity::TENSOR &&
              B_zero_point[0] == 0) {
            requantizeOutputProcessingAvx2<
                true,
                true,
                QuantizationGranularity::TENSOR,
                false, // HAS_BIAS,
                FUSE_RELU,
                BIAS_TYPE,
                true>(
                C,
                C_buffer,
                {0, conv_p.OUT_DIM[1] * conv_p.OUT_DIM[0], 0, conv_p.OC},
                conv_p.OC,
                conv_p.OC,
                reqObj);
          } else {
            requantizeOutputProcessingAvx2<
                true,
                false,
                Q_GRAN,
                false, // HAS_BIAS,
                FUSE_RELU,
                BIAS_TYPE,
                true>(
                C,
                C_buffer,
                {0, conv_p.OUT_DIM[1] * conv_p.OUT_DIM[0], 0, conv_p.OC},
                conv_p.OC,
                conv_p.OC,
                reqObj);
          }
        } else {
          if (Q_GRAN == QuantizationGranularity::TENSOR &&
              B_zero_point[0] == 0) {
            requantizeOutputProcessingAvx2<
                false,
                true,
                QuantizationGranularity::TENSOR,
                false, // HAS_BIAS,
                FUSE_RELU,
                BIAS_TYPE,
                true>(
                C,
                C_buffer,
                {0, conv_p.OUT_DIM[1] * conv_p.OUT_DIM[0], 0, conv_p.OC},
                conv_p.OC,
                conv_p.OC,
                reqObj);
          } else {
            requantizeOutputProcessingAvx2<
                false,
                false,
                Q_GRAN,
                false, // HAS_BIAS,
                FUSE_RELU,
                BIAS_TYPE,
                true>(
                C,
                C_buffer,
                {0, conv_p.OUT_DIM[1] * conv_p.OUT_DIM[0], 0, conv_p.OC},
                conv_p.OC,
                conv_p.OC,
                reqObj);
          }
        }
      } else { // has_bias == true

        // dispatch A_SYMMETRIC and B_SYMMETRIC
        if (A_zero_point == 0 || col_offsets == nullptr) {
          if (Q_GRAN == QuantizationGranularity::TENSOR &&
              B_zero_point[0] == 0) {
            requantizeOutputProcessingAvx2<
                true,
                true,
                QuantizationGranularity::TENSOR,
                true, // HAS_BIAS,
                FUSE_RELU,
                BIAS_TYPE,
                true>(
                C,
                C_buffer,
                {0, conv_p.OUT_DIM[1] * conv_p.OUT_DIM[0], 0, conv_p.OC},
                conv_p.OC,
                conv_p.OC,
                reqObj);
          } else {
            requantizeOutputProcessingAvx2<
                true,
                false,
                Q_GRAN,
                true, // HAS_BIAS,
                FUSE_RELU,
                BIAS_TYPE,
                true>(
                C,
                C_buffer,
                {0, conv_p.OUT_DIM[1] * conv_p.OUT_DIM[0], 0, conv_p.OC},
                conv_p.OC,
                conv_p.OC,
                reqObj);
          }
        } else {
          if (Q_GRAN == QuantizationGranularity::TENSOR &&
              B_zero_point[0] == 0) {
            requantizeOutputProcessingAvx2<
                false,
                true,
                QuantizationGranularity::TENSOR,
                true, // HAS_BIAS,
                FUSE_RELU,
                BIAS_TYPE,
                true>(
                C,
                C_buffer,
                {0, conv_p.OUT_DIM[1] * conv_p.OUT_DIM[0], 0, conv_p.OC},
                conv_p.OC,
                conv_p.OC,
                reqObj);
          } else {
            requantizeOutputProcessingAvx2<
                false,
                false,
                Q_GRAN,
                true, // HAS_BIAS,
                FUSE_RELU,
                BIAS_TYPE,
                true>(
                C,
                C_buffer,
                {0, conv_p.OUT_DIM[1] * conv_p.OUT_DIM[0], 0, conv_p.OC},
                conv_p.OC,
                conv_p.OC,
                reqObj);
          }
        }
      }
      // Release the aligned scratch buffers allocated above.
      fbgemmAlignedFree(inSum);
      fbgemmAlignedFree(rowSum);
    } // transposed conv
    else { // non-transposed conv
      // NOTE(review): also compiled out under NDEBUG — silent no-op then.
      assert(false && "non-transposed direct conv not integrated yet.");
    }
  } // else SPATIAL_DIM
}