static ALWAYS_INLINE void depthwise_2d_()

in src/FbgemmI8Depthwise2DAvx2-inl.h [124:489]


static ALWAYS_INLINE void depthwise_2d_(
    int N,
    int H,
    int W,
    int IC,
    int OC,
    int stride_h,
    int stride_w,
    std::int32_t A_zero_point,
    const std::uint8_t* A,
    const std::int32_t* B_zero_point,
    const PackedDepthWiseConvMatrix& B,
    const float* C_multiplier,
    std::int32_t C_zero_point,
    std::int32_t* C_int32,
    std::uint8_t* C_uint8,
    const std::int32_t* col_offsets,
    const BIAS_TYPE* bias,
    const float* act_times_w_scale,
    int thread_id,
    int num_threads) {
  // Drives depthwise_2d_kernel_ over the slice of the output assigned to this
  // thread. The filter is square (R == S) and the padding on every side is
  // (S - 1) / 2.
  //
  // N, H, W        : batch size and input height/width; the input of image n
  //                  starts at A + n * H * W * IC.
  // IC, OC         : channel counts; IC must be a multiple of 8 (asserted).
  // stride_h/w     : strides used to derive H_OUT / W_OUT.
  // B              : packed weights; only B.PackedMat() is used here.
  // C_uint8        : output; image n starts at C_uint8 + n * H_OUT * W_OUT * OC.
  // C_int32        : forwarded as-is to the kernel for every pixel
  //                  (presumably scratch space -- TODO confirm).
  // thread_id,
  // num_threads    : select this thread's sub-range of (N, H_OUT, W_OUT).
  // All remaining arguments are forwarded unchanged to depthwise_2d_kernel_.
  assert(IC % 8 == 0);
  constexpr int R = S;
  constexpr int64_t PAD_T = (R - 1) / 2, PAD_B = PAD_T, PAD_L = (S - 1) / 2,
                    PAD_R = PAD_L;
  int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
  int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
  const std::int8_t* Bp = B.PackedMat();

  // Per-call scratch buffer handed to the kernel: IC rounded up to a multiple
  // of 32 int32 entries, 64-byte aligned. Released at the end of the function.
  int32_t* row_offsets = static_cast<int32_t*>(
      fbgemmAlignedAlloc(64, (IC + 31) / 32 * 32 * sizeof(int32_t)));

  int64_t n_begin, n_end, h_begin, h_end, w_begin, w_end;
  // Reuse the 3-dim partition scheme for parallelization in matrix
  // multiplication.
  thread_type_t th_info =
      fbgemmGetThreadPartition(N, H_OUT, W_OUT, thread_id, num_threads);
  // Calculate the begin and end index along the batch (N) dimension
  fbgemmPartition1D(
      th_info.g_thread_id, th_info.g_num_threads, N, n_begin, n_end);
  // Calculate the begin and end index along the H dimension
  fbgemmPartition1D(
      th_info.m_thread_id, th_info.m_num_threads, H_OUT, h_begin, h_end);
  // Calculate the begin and end index along the W dimension
  fbgemmPartition1D(
      th_info.n_thread_id, th_info.n_num_threads, W_OUT, w_begin, w_end);

  // Loop-invariant boundaries that split this thread's rows/columns into a
  // leading border band, an interior band and a trailing border band.
  const int64_t h_top_end = std::min(PAD_T, h_end);
  const int64_t h_mid_end = std::min(H_OUT - PAD_B - stride_h + 1, h_end);
  const int64_t w_left_end = std::min(PAD_L, w_end);
  const int64_t w_mid_end = std::min(W_OUT - PAD_R - stride_w + 1, w_end);

  // Kernel used for interior pixels; assigned on the first interior row of the
  // first image this thread handles, before its first use.
  GenI8Depthwise::jit_kernel_signature middle_kernel;

  for (int n = n_begin; n < n_end; ++n) {
    // Widen to 64 bits before multiplying: the product of four ints can exceed
    // INT_MAX for large tensors, and signed overflow is undefined behaviour.
    const std::uint8_t* A_base = A + static_cast<int64_t>(n) * H * W * IC;
    std::uint8_t* C_uint8_base =
        C_uint8 + static_cast<int64_t>(n) * H_OUT * W_OUT * OC;

    // Computes one output pixel without supplying a pre-generated kernel, so
    // depthwise_2d_kernel_ falls back to its default for the last argument.
    const auto run_border_pixel = [&](int h_out, int w_out) {
      depthwise_2d_kernel_<
          S,
          FUSE_RELU,
          HAS_BIAS,
          A_SYMMETRIC,
          B_SYMMETRIC,
          Q_GRAN>(
          H,
          W,
          IC,
          OC,
          h_out,
          w_out,
          stride_h,
          stride_w,
          A_zero_point,
          A_base,
          B_zero_point,
          Bp,
          C_multiplier,
          C_zero_point,
          C_int32,
          C_uint8_base,
          row_offsets,
          col_offsets,
          bias,
          act_times_w_scale);
    };

    int h = h_begin;

    // Top border rows: every column takes the border path.
    for (; h < h_top_end; ++h) {
      for (int w = w_begin; w < w_end; ++w) {
        run_border_pixel(h, w);
      }
    }

    // Interior rows. Original derivation of the row bound, arguing that the
    // filter window stays within the input height (h_in + S - H <= 0):
    // h <= H_OUT - PAD_B - stride_h
    // h <= (H + PAD_T + PAD_B - S) / stride_h + 1 - PAD_B - stride_h
    // h_in <= -PAD_T +
    // ((H + PAD_T + PAD_B - S) / stride_h + 1 - PAD_B - stride_h) * stride_h
    // Case 1) For stride_h == 1,
    // h_in <= -PAD_T + H + PAD_T + PAD_B - S + 1 - PAD_B - 1
    // h_in + S - H <= 0
    // Case 2) For stride_h == 2,
    // h_in <= -PAD_T +
    // H + PAD_T + PAD_B - S + 1 + (1 - PAD_B - stride_h) * stride_h
    // h_in + S - H <= PAD_B * (1 - stride_h) + 1 + (1 - stride_h) * stride_h
    //              <= -PAD_B + 1 - stride_h <= 0
    for (; h < h_mid_end; ++h) {
      int w = w_begin;

      // Left border columns.
      for (; w < w_left_end; ++w) {
        run_border_pixel(h, w);
      }

      // Fetch the interior kernel while processing the first image only, and
      // only if this row actually has interior columns. At this point w is the
      // first interior column, i.e. max(PAD_L, w_begin), so this fires exactly
      // when the first interior pixel of the row is about to be computed.
      if (n == n_begin && w < w_mid_end) {
        int remainder = OC % 32;
        if (remainder == 0) {
          remainder = 32;
        }
        middle_kernel = GenI8Depthwise().getOrCreate(
            /*D=*/2,
            {1, S, S},
            OC / IC,
            /*compute_a_sum=*/!B_SYMMETRIC,
            remainder,
            0,
            0,
            0,
            0,
            0,
            0);
      }

      // Interior columns: kept as a direct call (this is the hot path) and
      // supplied with the pre-fetched kernel.
      for (; w < w_mid_end; ++w) {
        depthwise_2d_kernel_<
            S,
            FUSE_RELU,
            HAS_BIAS,
            A_SYMMETRIC,
            B_SYMMETRIC,
            Q_GRAN>(
            H,
            W,
            IC,
            OC,
            h,
            w,
            stride_h,
            stride_w,
            A_zero_point,
            A_base,
            B_zero_point,
            Bp,
            C_multiplier,
            C_zero_point,
            C_int32,
            C_uint8_base,
            row_offsets,
            col_offsets,
            bias,
            act_times_w_scale,
            &middle_kernel);
      }

      // Right border columns.
      for (; w < w_end; ++w) {
        run_border_pixel(h, w);
      }
    }

    // Bottom border rows: every column takes the border path.
    for (; h < h_end; ++h) {
      for (int w = w_begin; w < w_end; ++w) {
        run_border_pixel(h, w);
      }
    }
  } // for each n

  fbgemmAlignedFree(row_offsets);
}