static ALWAYS_INLINE void depthwise_3d_same_pad_()

in src/FbgemmI8Depthwise3DAvx2.cc [132:780]


// Quantized (uint8 activations x int8 packed weights) depthwise 3D convolution
// with symmetric padding PAD = (K - 1) / 2 in every spatial dimension, run by
// worker `thread_id` out of `num_threads`.
//
// The output index space (n, t, h) is split across threads with
// fbgemmGetThreadPartition / fbgemmPartition1D; each thread computes every
// w of the rows it owns, and never touches rows outside
// [n_begin, n_end) x [t_begin, t_end) x [h_begin, h_end).
//
// Every output pixel is produced by depthwise_3d_kernel_. Pixels that are
// interior in both H and W (h in [PAD_T, h_mid_end), w in [PAD_L, w_mid_end))
// are additionally handed a JIT kernel from GenI8Depthwise::getOrCreate:
//   * t interior as well -> one kernel with zero T skips, generated once and
//     shared by all rows of this call;
//   * t in the front/back padding -> a kernel whose prev_skip / next_skip
//     drop the taps falling outside the input in T, generated once per t.
// All remaining (border) pixels call depthwise_3d_kernel_ with its default
// trailing argument.
//
// Per batch element, the input advances by T * H * W * IC elements and the
// output by T_OUT * H_OUT * W_OUT * OC elements (presumably channels-last
// NTHWC layout — confirm against callers). `row_offsets` is a scratch buffer
// allocated and freed here; C_int32 is forwarded unchanged to every
// depthwise_3d_kernel_ call.
static ALWAYS_INLINE void depthwise_3d_same_pad_(
    const conv_param_t<3>& conv_p,
    int32_t A_zero_point,
    const uint8_t* A,
    const int32_t* B_zero_point,
    const PackedDepthWiseConvMatrix& B,
    const float* C_multiplier,
    int32_t C_zero_point,
    int32_t* C_int32,
    uint8_t* C_uint8,
    const int32_t* col_offsets,
    const BIAS_TYPE* bias,
    const float* act_times_w_scale,
    int thread_id,
    int num_threads) {
  int N = conv_p.MB;
  int T = conv_p.IN_DIM[0];
  int H = conv_p.IN_DIM[1];
  int W = conv_p.IN_DIM[2];
  int IC = conv_p.IC;
  int OC = conv_p.OC;
  array<int, 3> F = conv_p.K;
  int stride_t = conv_p.stride[0];
  int stride_h = conv_p.stride[1];
  int stride_w = conv_p.stride[2];

  assert(IC % 8 == 0);

  int K_T = F[0], K_H = F[1], K_W = F[2];
  int PAD_P = (F[0] - 1) / 2, PAD_N = PAD_P, PAD_T = (F[1] - 1) / 2,
      PAD_B = PAD_T, PAD_L = (F[2] - 1) / 2, PAD_R = PAD_L;
  int64_t T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1;
  int64_t H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1;
  int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1;
  const int8_t* Bp = B.PackedMat();

  // Exclusive upper bounds of the interior region in each output dimension;
  // the inclusive lower bounds are PAD_P, PAD_T and PAD_L respectively.
  const int64_t t_mid_end = T_OUT - PAD_N - stride_t + 1;
  const int64_t h_mid_end = H_OUT - PAD_B - stride_h + 1;
  const int w_mid_end = W_OUT - PAD_R - stride_w + 1;
  // End of the left W border. Clamped so that outputs narrower than the
  // padding never produce a pixel with w >= W_OUT.
  const int w_left_end = std::min(PAD_L, W_OUT);

  int32_t* row_offsets = static_cast<int32_t*>(
      fbgemmAlignedAlloc(64, (IC + 31) / 32 * 32 * sizeof(int32_t)));

  int64_t n_begin, n_end, t_begin, t_end, h_begin, h_end;
  // Reuse the 3-dim partition scheme for parallelization in matrix
  // multiplication.
  thread_type_t th_info =
      fbgemmGetThreadPartition(N, T_OUT, H_OUT, thread_id, num_threads);
  // Calculate the begin and end index along the batch (N) dimension
  fbgemmPartition1D(
      th_info.g_thread_id, th_info.g_num_threads, N, n_begin, n_end);
  // Calculate the begin and end index along the T dimension
  fbgemmPartition1D(
      th_info.m_thread_id, th_info.m_num_threads, T_OUT, t_begin, t_end);
  // Calculate the begin and end index along the H dimension
  fbgemmPartition1D(
      th_info.n_thread_id, th_info.n_num_threads, H_OUT, h_begin, h_end);

  // Channel count passed to getOrCreate for the last 32-channel chunk:
  // OC % 32, or a full 32 when OC is a multiple of 32.
  int remainder = OC % 32;
  if (remainder == 0) {
    remainder = 32;
  }

  // Looks up (or JIT-generates) the kernel for H/W-interior pixels. Only the
  // T-direction skips vary; the trailing four arguments stay 0 (presumably
  // the H/W border skips, which interior pixels never need — confirm).
  auto make_kernel = [&](int prev_skip, int next_skip) {
    return GenI8Depthwise().getOrCreate(
        /*D=*/3,
        F,
        OC / IC,
        /*compute_a_sum=*/!B_SYMMETRIC,
        remainder,
        /*prev_skip=*/prev_skip,
        /*next_skip=*/next_skip,
        0,
        0,
        0,
        0);
  };

  // Kernel for pixels interior in T, H and W. Generated on first use only, so
  // threads that own no such pixel never request it.
  GenI8Depthwise::jit_kernel_signature middle_kernel;
  bool middle_kernel_ready = false;

  // int64_t batch index: keeps the per-batch input/output offsets below from
  // overflowing `int` for large batches or volumes.
  for (int64_t n = n_begin; n < n_end; ++n) {
    const uint8_t* A_base = A + n * T * H * W * IC;
    uint8_t* C_uint8_base = C_uint8 + n * T_OUT * H_OUT * W_OUT * OC;

    // Computes output pixel (t, h, w) of batch element n.
    // `pregenerated_kernel` is either empty, so depthwise_3d_kernel_ uses its
    // default trailing argument, or a single jit_kernel_signature pointer.
    auto compute_pixel = [&](int t, int h, int w, auto... pregenerated_kernel) {
      depthwise_3d_kernel_<
          FUSE_RELU,
          HAS_BIAS,
          A_SYMMETRIC,
          B_SYMMETRIC,
          Q_GRAN>(
          T,
          H,
          W,
          IC,
          OC,
          t,
          h,
          w,
          F,
          stride_t,
          stride_h,
          stride_w,
          A_zero_point,
          A_base,
          B_zero_point,
          Bp,
          C_multiplier,
          C_zero_point,
          C_int32,
          C_uint8_base,
          row_offsets,
          col_offsets,
          bias,
          act_times_w_scale,
          pregenerated_kernel...);
    };

    // Iterate strictly within this thread's partition. The padded front
    // region is [t_begin, min(PAD_P, t_end)) and is never allowed to spill
    // past t_end (or T_OUT).
    for (int t = static_cast<int>(t_begin); t < t_end; ++t) {
      const bool t_interior = t >= PAD_P && t < t_mid_end;
      // Kernel used for the H/W-interior pixels of this t. It is resolved
      // lazily on the first such pixel.
      GenI8Depthwise::jit_kernel_signature t_border_kernel;
      GenI8Depthwise::jit_kernel_signature* interior_kernel = nullptr;

      for (int h = static_cast<int>(h_begin); h < h_end; ++h) {
        int w = 0;
        if (h >= PAD_T && h < h_mid_end) {
          // Left W border.
          for (; w < w_left_end; ++w) {
            compute_pixel(t, h, w);
          }
          if (w < w_mid_end && interior_kernel == nullptr) {
            if (t_interior) {
              if (!middle_kernel_ready) {
                middle_kernel = make_kernel(0, 0);
                middle_kernel_ready = true;
              }
              interior_kernel = &middle_kernel;
            } else {
              // Skip the T taps that fall into the front/back padding.
              const int t_in = -PAD_P + t * stride_t;
              t_border_kernel = make_kernel(
                  std::max(-t_in, 0), std::max(t_in + F[0] - T, 0));
              interior_kernel = &t_border_kernel;
            }
          }
          // H/W interior: use the pregenerated kernel.
          for (; w < w_mid_end; ++w) {
            compute_pixel(t, h, w, interior_kernel);
          }
        }
        // Right W border, or the whole row when h lies in the H padding.
        for (; w < W_OUT; ++w) {
          compute_pixel(t, h, w);
        }
      } // h
    } // t
  } // for each n
  fbgemmAlignedFree(row_offsets);
}