in src/FbgemmI8Depthwise2DAvx2-inl.h [124:489]
static ALWAYS_INLINE void depthwise_2d_(
    int N,
    int H,
    int W,
    int IC,
    int OC,
    int stride_h,
    int stride_w,
    std::int32_t A_zero_point,
    const std::uint8_t* A,
    const std::int32_t* B_zero_point,
    const PackedDepthWiseConvMatrix& B,
    const float* C_multiplier,
    std::int32_t C_zero_point,
    std::int32_t* C_int32,
    std::uint8_t* C_uint8,
    const std::int32_t* col_offsets,
    const BIAS_TYPE* bias,
    const float* act_times_w_scale,
    int thread_id,
    int num_threads) {
  // Depthwise 2D convolution over a quantized uint8 activation tensor.
  // The indexing (A + n * H * W * IC, C_uint8 + n * H_OUT * W_OUT * OC)
  // implies one H*W*IC slab per batch image and an H_OUT*W_OUT*OC output
  // slab per image (i.e., channels-last layout — NOTE(review): inferred
  // from the pointer arithmetic; confirm against callers).
  //
  // Work is partitioned across threads in 3 dimensions (N, H_OUT, W_OUT);
  // each thread processes its [n_begin,n_end) x [h_begin,h_end) x
  // [w_begin,w_end) box. Rows/columns are split into "edge" regions (which
  // can read into zero padding) and the interior "middle" region, for which
  // a specialized JIT kernel is generated once and reused.
  assert(IC % 8 == 0);
  constexpr int R = S; // square filter: R (height) == S (width)
  // Same-ish padding derived from the filter size.
  constexpr int64_t PAD_T = (R - 1) / 2, PAD_B = PAD_T, PAD_L = (S - 1) / 2,
                    PAD_R = PAD_L;
  int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
  int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
  const std::int8_t* Bp = B.PackedMat();

  // Scratch for per-position activation row sums; rounded up to a multiple
  // of 32 lanes and 64-byte aligned for the vectorized kernels.
  int32_t* row_offsets = static_cast<int32_t*>(
      fbgemmAlignedAlloc(64, (IC + 31) / 32 * 32 * sizeof(int32_t)));

  int64_t n_begin, n_end, h_begin, h_end, w_begin, w_end;
  // Reuse the 3-dim partition scheme for parallelization in matrix
  // multiplication.
  thread_type_t th_info =
      fbgemmGetThreadPartition(N, H_OUT, W_OUT, thread_id, num_threads);
  // Calculate the begin and end index along the batch (N) dimension
  fbgemmPartition1D(
      th_info.g_thread_id, th_info.g_num_threads, N, n_begin, n_end);
  // Calculate the begin and end index along the H dimension
  fbgemmPartition1D(
      th_info.m_thread_id, th_info.m_num_threads, H_OUT, h_begin, h_end);
  // Calculate the begin and end index along the W dimension
  fbgemmPartition1D(
      th_info.n_thread_id, th_info.n_num_threads, W_OUT, w_begin, w_end);

  GenI8Depthwise::jit_kernel_signature middle_kernel;

  for (int n = n_begin; n < n_end; ++n) {
    const std::uint8_t* A_base = A + n * H * W * IC;
    std::uint8_t* C_uint8_base = C_uint8 + n * H_OUT * W_OUT * OC;

    // Every edge/corner invocation below passes identical arguments except
    // the output coordinates (oh, ow); factor the call into one helper to
    // avoid repeating the 20-argument expansion nine times. The single
    // interior call that supplies a pregenerated kernel stays explicit.
    auto run_kernel = [&](int oh, int ow) {
      depthwise_2d_kernel_<
          S,
          FUSE_RELU,
          HAS_BIAS,
          A_SYMMETRIC,
          B_SYMMETRIC,
          Q_GRAN>(
          H,
          W,
          IC,
          OC,
          oh,
          ow,
          stride_h,
          stride_w,
          A_zero_point,
          A_base,
          B_zero_point,
          Bp,
          C_multiplier,
          C_zero_point,
          C_int32,
          C_uint8_base,
          row_offsets,
          col_offsets,
          bias,
          act_times_w_scale);
    };

    int h = 0;
    int w = 0;
    // --- Top rows: the filter window reads into the top padding. ---
    for (h = h_begin; h < std::min(PAD_T, h_end); ++h) {
      // Left edge columns.
      for (w = w_begin; w < std::min(PAD_L, w_end); ++w) {
        run_kernel(h, w);
      }
      // Middle columns (no horizontal padding reached).
      for (; w < std::min(W_OUT - PAD_R - stride_w + 1, w_end); ++w) {
        run_kernel(h, w);
      }
      // Right edge columns.
      for (; w < w_end; ++w) {
        run_kernel(h, w);
      }
    }

    // --- Middle rows: the window stays fully inside the input vertically.
    // Proof that the row bound below keeps h_in + S within H:
    // h <= H_OUT - PAD_B - stride_h
    // h <= (H + PAD_T + PAD_B - S) / stride_h + 1 - PAD_B - stride_h
    // h_in <= -PAD_T +
    //     ((H + PAD_T + PAD_B - S) / stride_h + 1 - PAD_B - stride_h) * stride_h
    // Case 1) For stride_h == 1,
    //   h_in <= -PAD_T + H + PAD_T + PAD_B - S + 1 - PAD_B - 1
    //   h_in + S - H <= 0
    // Case 2) For stride_h == 2,
    //   h_in <= -PAD_L +
    //     H + PAD_T + PAD_B - S + 1 + (1 - PAD_B - stride_h) * stride_h
    //   h_in + S - H <= PAD_B * (1 - stride_h) + 1 + (1 - stride_h) * stride_h
    //              <= -PAD_B + 1 - stride_h <= 0
    for (; h < std::min(H_OUT - PAD_B - stride_h + 1, h_end); ++h) {
      for (w = w_begin; w < std::min(PAD_L, w_end); ++w) {
        run_kernel(h, w);
      }
      for (; w < std::min(W_OUT - PAD_R - stride_w + 1, w_end); ++w) {
        // Lazily JIT the interior kernel on the first interior column of
        // the first image handled by this thread; getOrCreate caches, so
        // re-entering here on later rows is cheap.
        if (n == n_begin && w == std::max(PAD_L, w_begin)) {
          int remainder = OC % 32;
          if (remainder == 0) {
            remainder = 32;
          }
          middle_kernel = GenI8Depthwise().getOrCreate(
              /*D=*/2,
              {1, S, S},
              OC / IC,
              /*compute_a_sum=*/!B_SYMMETRIC,
              remainder,
              0,
              0,
              0,
              0,
              0,
              0);
        }
        // Interior call: same arguments as run_kernel plus the
        // pregenerated kernel.
        depthwise_2d_kernel_<
            S,
            FUSE_RELU,
            HAS_BIAS,
            A_SYMMETRIC,
            B_SYMMETRIC,
            Q_GRAN>(
            H,
            W,
            IC,
            OC,
            h,
            w,
            stride_h,
            stride_w,
            A_zero_point,
            A_base,
            B_zero_point,
            Bp,
            C_multiplier,
            C_zero_point,
            C_int32,
            C_uint8_base,
            row_offsets,
            col_offsets,
            bias,
            act_times_w_scale,
            &middle_kernel);
      }
      for (; w < w_end; ++w) {
        run_kernel(h, w);
      }
    }

    // --- Bottom rows: the filter window reads into the bottom padding. ---
    for (; h < h_end; ++h) {
      for (w = w_begin; w < std::min(PAD_L, w_end); ++w) {
        run_kernel(h, w);
      }
      for (; w < std::min(W_OUT - PAD_R - stride_w + 1, w_end); ++w) {
        run_kernel(h, w);
      }
      for (; w < w_end; ++w) {
        run_kernel(h, w);
      }
    }
  } // for each n

  fbgemmAlignedFree(row_offsets);
}