void fbgemmGroupwiseConv()

in src/GroupwiseConv.cc [976:1272]


template <
    typename packed_W,
    typename outType,
    bool FUSE_RELU,
    QuantizationGranularity Q_GRAN,
    int SPATIAL_DIM,
    typename BIAS_TYPE>
void fbgemmGroupwiseConv(
    const conv_param_t<SPATIAL_DIM>& conv_param,
    const uint8_t* activations,
    int32_t a_zero_point,
    int32_t* rowOffsetBuf,
    packed_W& packed_weights,
    outType* out,
    int32_t* outBuffer,
    const ReQuantizeOutput<FUSE_RELU, Q_GRAN, BIAS_TYPE>& outProcess,
    int thread_id,
    int num_threads) {
  using processOutputType = ReQuantizeOutput<FUSE_RELU, Q_GRAN, BIAS_TYPE>;

  if (!cpuinfo_initialize()) {
    throw runtime_error("Failed to initialize cpuinfo!");
  }

  int MB = conv_param.MB;
  int OT = SPATIAL_DIM <= 2 ? 1 : conv_param.OUT_DIM[SPATIAL_DIM - 3];
  int OH = SPATIAL_DIM == 1 ? 1 : conv_param.OUT_DIM[SPATIAL_DIM - 2];
  int OW = conv_param.OUT_DIM[SPATIAL_DIM - 1];
  int T = SPATIAL_DIM <= 2 ? 1 : conv_param.K[SPATIAL_DIM - 3];
  int R = SPATIAL_DIM == 1 ? 1 : conv_param.K[SPATIAL_DIM - 2];
  int S = conv_param.K[SPATIAL_DIM - 1];
  int G = conv_param.G;
  int OC = conv_param.OC;
  int IC = conv_param.IC;
  int K_per_G = conv_param.OC / G;
  int C_per_G = conv_param.IC / G;
  int OH_OW = OH * OW;
  int OT_OH_OW = OT * OH * OW;
  int IT = SPATIAL_DIM <= 2 ? 1 : conv_param.IN_DIM[SPATIAL_DIM - 3];
  int IH = SPATIAL_DIM == 1 ? 1 : conv_param.IN_DIM[SPATIAL_DIM - 2];
  int IW = conv_param.IN_DIM[SPATIAL_DIM - 1];
  int IH_IW = IH * IW;
  int IT_IH_IW = IT * IH * IW;
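  // Input channels per group, rounded up to a multiple of 4 to match the packed-weight layout.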
  int paddedCPerG = (C_per_G + 3) / 4 * 4;

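  // Row offsets are only needed when the weights are asymmetric (nonzero or per-channel zero
  // point) and the caller supplied a row-offset buffer. G_together is the number of groups a
  // single JIT kernel invocation processes.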
  bool b_symmetric = (Q_GRAN == QuantizationGranularity::TENSOR &&
                      outProcess.getBZeroPoint()[0] == 0) ||
      rowOffsetBuf == nullptr;
  int G_together = PackWeightMatrixForGConv<int8_t, int32_t, SPATIAL_DIM>::
      numOfGroupsTogether(conv_param);

  if (SPATIAL_DIM == 1) {
    throw std::runtime_error("Groupwise 1D not implemented!");
  }
  if (SPATIAL_DIM == 2) {
    // Parallelization:
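    // Split the batch across threads when there are enough images; otherwise split the output rows.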
    int64_t batch_start = 0;
    int64_t batch_end = MB;
    int64_t oh_start = 0;
    int64_t oh_end = OH;
    if (MB >= num_threads) {
      fbgemmPartition1D(thread_id, num_threads, MB, batch_start, batch_end);
    } else {
      fbgemmPartition1D(thread_id, num_threads, OH, oh_start, oh_end);
    }

    if (batch_start >= batch_end || oh_start >= oh_end) {
      // There is no work for this thread
      return;
    }

    // Generate the convolution + row-offset kernel.
    bool calculateRowOffset = !b_symmetric;
    bool isTopEdgeIncluded = oh_start == 0;
    bool isBottomEdgeIncluded = oh_end == OH;
    bool isTopBottomEdgeSame =
        isTopEdgeIncluded && isBottomEdgeIncluded && oh_end == oh_start + 1;
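    // The kernel is specialized on whether this thread's row range touches the top and/or bottom
    // image edge, i.e. on how the vertical padding has to be handled.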
    jit_conv_kernel_fp fpConv = getOrCreateConvKernel<SPATIAL_DIM>(
        conv_param,
        a_zero_point,
        calculateRowOffset,
        isTopEdgeIncluded,
        isBottomEdgeIncluded,
        isTopBottomEdgeSame,
        false);

    int ih_start = 0;
    if (oh_start > 0) {
      ih_start = -conv_param.pad[SPATIAL_DIM - 2] +
          oh_start * conv_param.stride[SPATIAL_DIM - 2];
    }
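    // Point the output, input, and row-offset buffers at the first output row owned by this thread.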
    int32_t* out_start = outBuffer + oh_start * OW * OC;
    const uint8_t* in_start = activations + ih_start * IW * IC;
    int32_t* rowOffsetBuf_start =
        rowOffsetBuf ? rowOffsetBuf + oh_start * OW * G_together : nullptr;
    for (int i = batch_start; i < batch_end; ++i) {
      const uint8_t* in_start_batch = in_start + i * IH_IW * conv_param.IC;
      int32_t* out_start_batch = out_start + i * OH_OW * OC;
      int32_t* rowOffsetBuf_start_batch =
          rowOffsetBuf ? rowOffsetBuf_start + i * OH_OW * G_together : nullptr;
      for (int g = 0; g < G; g += G_together) {
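        // Advance to this block of G_together groups within the activations and the packed weights.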
        const uint8_t* in_start_group = in_start_batch + g * C_per_G;
        int8_t* weight_start =
            packed_weights.getBuf() + g * R * S * K_per_G * paddedCPerG;
        int32_t* out_start_group = out_start_batch;
        int32_t* rowOffsetBuf_start_group = rowOffsetBuf_start_batch;
        // Uncomment the following two lines to disable reuse of the output and
        // row-offset buffers:
        // out_start_group = out_start_batch + g * K_per_G;
        // rowOffsetBuf_start_group = rowOffsetBuf_start_batch + g * MB * OH_OW;

        // Reference path: performs exactly the same computation as the JIT'ed kernel call below.
        // kernel_compute(
        //    conv_param,
        //    in_start_group,
        //    weight_start,
        //    out_start_group,
        //    a_zero_point,
        //    oh_start,
        //    oh_end,
        //    OW,
        //    rowOffsetBuf_start_group);

        fpConv(
            in_start_group,
            weight_start,
            out_start_group,
            a_zero_point,
            oh_start,
            oh_end,
            OW,
            rowOffsetBuf_start_group);

        const int32_t* inp = out_start_group;
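        // Post-process (requantize) this thread's rows for the current group block into the final output.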
        block_type_t block{
            static_cast<int>(i * OT_OH_OW + oh_start * OW),
            static_cast<int>((oh_end - oh_start) * OW),
            g * K_per_G,
            G_together * K_per_G};
        int ld_out = G * K_per_G;
        int ld_in = G * K_per_G;

        dispatchOutputProcessing(
            outProcess,
            rowOffsetBuf_start_group,
            out,
            inp,
            block,
            ld_out,
            ld_in,
            G,
            C_per_G,
            is_requantization<processOutputType>());
      } // for each g
    } // for each i
  } else {
    assert(SPATIAL_DIM == 3 && "Unsupported SPATIAL_DIM");
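    // 3D path: run T 2D convolutions over (H, W), one per temporal kernel tap, and accumulate
    // them into the same output buffer.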

    conv_param_t<> conv_p_2d(
        conv_param.MB,
        conv_param.IC,
        conv_param.OC,
        {conv_param.IN_DIM[SPATIAL_DIM - 2],
         conv_param.IN_DIM[SPATIAL_DIM - 1]},
        conv_param.G,
        {conv_param.K[SPATIAL_DIM - 2], conv_param.K[SPATIAL_DIM - 1]},
        {conv_param.stride[SPATIAL_DIM - 2],
         conv_param.stride[SPATIAL_DIM - 1]},
        {conv_param.pad[1],
         conv_param.pad[2],
         conv_param.pad[4],
         conv_param.pad[5]});

    // Parallelization:
    int64_t batch_start = 0;
    int64_t batch_end = MB;
    int64_t oh_start = 0;
    int64_t oh_end = OH;
    if (MB >= num_threads) {
      fbgemmPartition1D(thread_id, num_threads, MB, batch_start, batch_end);
    } else {
      fbgemmPartition1D(thread_id, num_threads, OH, oh_start, oh_end);
    }

    if (batch_start >= batch_end || oh_start >= oh_end) {
      // There is no work for this thread
      return;
    }

    // Generate the convolution + row-offset kernel.
    bool calculateRowOffset = !b_symmetric;
    bool isTopEdgeIncluded = oh_start == 0;
    bool isBottomEdgeIncluded = oh_end == OH;
    bool isTopBottomEdgeSame =
        isTopEdgeIncluded && isBottomEdgeIncluded && oh_end == oh_start + 1;
    jit_conv_kernel_fp fpConvNoAccum = getOrCreateConvKernel<2>(
        conv_p_2d,
        a_zero_point,
        calculateRowOffset,
        isTopEdgeIncluded,
        isBottomEdgeIncluded,
        isTopBottomEdgeSame,
        false);
    jit_conv_kernel_fp fpConvAccum = getOrCreateConvKernel<2>(
        conv_p_2d,
        a_zero_point,
        calculateRowOffset,
        isTopEdgeIncluded,
        isBottomEdgeIncluded,
        isTopBottomEdgeSame,
        true);
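    // Two kernel variants: the first temporal tap overwrites the output, later taps accumulate into it.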
    jit_conv_kernel_fp fpConv;

    int ih_start = 0;
    if (oh_start > 0) {
      ih_start = -conv_p_2d.pad[0] + oh_start * conv_p_2d.stride[0];
    }
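    // Dummy input plane filled with the activation zero point, substituted when a temporal tap
    // falls into the temporal padding region.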

    vector<uint8_t> zero_points(IH * IW * IC, a_zero_point);
    int32_t* out_start = outBuffer + oh_start * OW * OC;
    const uint8_t* in_start = activations + ih_start * IW * IC;
    int32_t* rowOffsetBuf_start =
        rowOffsetBuf ? rowOffsetBuf + oh_start * OW * G_together : nullptr;
    for (int i = batch_start; i < batch_end; ++i) {
      const uint8_t* in_start_batch = in_start + i * IT_IH_IW * IC;
      int32_t* out_start_batch = out_start + i * OT_OH_OW * OC;
      int32_t* rowOffsetBuf_start_batch = rowOffsetBuf
          ? rowOffsetBuf_start + i * OT_OH_OW * G_together
          : nullptr;
      for (int g = 0; g < G; g += G_together) {
        const uint8_t* in_start_group = in_start_batch + g * C_per_G;
        int8_t* weight_start =
            packed_weights.getBuf() + g * T * R * S * K_per_G * paddedCPerG;
        int32_t* out_start_group = out_start_batch;
        int32_t* rowOffsetBuf_start_group = rowOffsetBuf_start_batch;
        // Uncomment the following two lines to disable reuse of the output and
        // row-offset buffers:
        // out_start_group = out_start_batch + g * K_per_G;
        // rowOffsetBuf_start_group = rowOffsetBuf_start_batch + g * MB * OT_OH_OW;

        for (int ot = 0; ot < OT; ++ot) {
          int32_t* out_start_t = out_start_group + ot * OH_OW * OC;
          int32_t* rowOffsetBuf_start_t = rowOffsetBuf
              ? rowOffsetBuf_start_group + ot * OH_OW * G_together
              : nullptr;
          for (int t = 0; t < T; ++t) {
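            // Select the input plane and the packed-weight slice for the t-th temporal tap.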
            int t_in = -conv_param.pad[0] + ot * conv_param.stride[0] + t;
            const uint8_t* in_start_t = in_start_group + t_in * IH_IW * IC;
            int8_t* weight_start_t =
                weight_start + t * R * S * K_per_G * G_together * paddedCPerG;
            if (t_in < 0 || t_in >= IT) {
              in_start_t = zero_points.data();
            }
            // Reference path: performs exactly the same computation as the
            // JIT'ed kernel call below.
            // kernel_compute(
            //     conv_p_2d,
            //     in_start_t,
            //     weight_start_t,
            //     out_start_t,
            //     a_zero_point,
            //     oh_start,
            //     oh_end,
            //     OW,
            //     rowOffsetBuf_start_t,
            //     t > 0);

            fpConv = t > 0 ? fpConvAccum : fpConvNoAccum;
            fpConv(
                in_start_t,
                weight_start_t,
                out_start_t,
                a_zero_point,
                oh_start,
                oh_end,
                OW,
                rowOffsetBuf_start_t);
          }

          const int32_t* inp = out_start_t;
          block_type_t block{
              static_cast<int>(i * OT_OH_OW + oh_start * OW),
              static_cast<int>((oh_end - oh_start) * OW),
              g * K_per_G,
              G_together * K_per_G};
          int ld_out = G * K_per_G;
          int ld_in = G * K_per_G;

          dispatchOutputProcessing(
              outProcess,
              rowOffsetBuf_start_t,
              out + ot * OH_OW * OC,
              inp,
              block,
              ld_out,
              ld_in,
              G,
              C_per_G,
              is_requantization<processOutputType>());
        } // for each ot
      } // for each g
    } // for each i
  } // SPATIAL_DIM == 3
}