in src/FbgemmI8Depthwise3DAvx2.cc [132:780]
static ALWAYS_INLINE void depthwise_3d_same_pad_(
    const conv_param_t<3>& conv_p,
    int32_t A_zero_point,
    const uint8_t* A,
    const int32_t* B_zero_point,
    const PackedDepthWiseConvMatrix& B,
    const float* C_multiplier,
    int32_t C_zero_point,
    int32_t* C_int32,
    uint8_t* C_uint8,
    const int32_t* col_offsets,
    const BIAS_TYPE* bias,
    const float* act_times_w_scale,
    int thread_id,
    int num_threads) {
  // Drives the 3D depthwise convolution over this thread's share of the
  // output, using a padding of (K - 1) / 2 on both sides of each spatial
  // dimension.
  //
  // The output volume is partitioned across threads along (N, T_OUT, H_OUT);
  // every thread walks the full W_OUT extent of the rows it owns. Each output
  // position is handed to depthwise_3d_kernel_. Positions in the interior of
  // the H/W plane additionally receive a kernel obtained up front from
  // GenI8Depthwise::getOrCreate (passed as the trailing argument), while
  // positions in the padded border use the call without that argument.
  //
  // Parameters:
  //   conv_p            convolution shape (MB, IN_DIM, IC, OC, K, stride).
  //   A / A_zero_point  input activations and their zero point.
  //   B / B_zero_point  packed weights and their zero point(s).
  //   C_multiplier, C_zero_point, col_offsets, bias, act_times_w_scale
  //                     forwarded unchanged to depthwise_3d_kernel_.
  //   C_int32           forwarded unchanged; presumably an int32 scratch
  //                     accumulator -- confirm against depthwise_3d_kernel_.
  //   C_uint8           output buffer, indexed per image by
  //                     T_OUT * H_OUT * W_OUT * OC.
  //   thread_id, num_threads
  //                     select this thread's partition of the work.
  int N = conv_p.MB;
  int T = conv_p.IN_DIM[0];
  int H = conv_p.IN_DIM[1];
  int W = conv_p.IN_DIM[2];
  int IC = conv_p.IC;
  int OC = conv_p.OC;
  array<int, 3> F = conv_p.K;
  int stride_t = conv_p.stride[0];
  int stride_h = conv_p.stride[1];
  int stride_w = conv_p.stride[2];

  assert(IC % 8 == 0);

  int K_T = F[0], K_H = F[1], K_W = F[2];

  // Padding before/after each dimension: P/N for T, T/B for H, L/R for W.
  int PAD_P = (F[0] - 1) / 2;
  int PAD_N = PAD_P;
  int PAD_T = (F[1] - 1) / 2;
  int PAD_B = PAD_T;
  int PAD_L = (F[2] - 1) / 2;
  int PAD_R = PAD_L;

  int64_t T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1;
  int64_t H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1;
  int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1;

  const int8_t* Bp = B.PackedMat();

  // Scratch buffer handed to depthwise_3d_kernel_; sized to IC rounded up to
  // a multiple of 32 int32 elements.
  int32_t* row_offsets = static_cast<int32_t*>(
      fbgemmAlignedAlloc(64, (IC + 31) / 32 * 32 * sizeof(int32_t)));

  int64_t n_begin, n_end, t_begin, t_end, h_begin, h_end;
  // Reuse the 3-dim partition scheme for parallelization in matrix
  // multiplication.
  thread_type_t th_info =
      fbgemmGetThreadPartition(N, T_OUT, H_OUT, thread_id, num_threads);
  // Calculate the begin and end index along the batch (N) dimension
  fbgemmPartition1D(
      th_info.g_thread_id, th_info.g_num_threads, N, n_begin, n_end);
  // Calculate the begin and end index along the T dimension
  fbgemmPartition1D(
      th_info.m_thread_id, th_info.m_num_threads, T_OUT, t_begin, t_end);
  // Calculate the begin and end index along the H dimension
  fbgemmPartition1D(
      th_info.n_thread_id, th_info.n_num_threads, H_OUT, h_begin, h_end);

  // Region boundaries. Every dimension is split into a leading border,
  // an interior, and a trailing border. The leading borders are clamped to
  // the end of this thread's range (or to W_OUT) so that a thread never
  // touches output positions outside its partition or outside the output
  // extent, even when the range ends inside the padded border.
  const int64_t t_front_end = std::min<int64_t>(PAD_P, t_end);
  const int64_t t_mid_end = std::min(T_OUT - PAD_N - stride_t + 1, t_end);
  const int64_t h_front_end = std::min<int64_t>(PAD_T, h_end);
  const int64_t h_mid_end = std::min(H_OUT - PAD_B - stride_h + 1, h_end);
  const int w_front_end = std::min(PAD_L, W_OUT);
  const int w_mid_end = W_OUT - PAD_R - stride_w + 1;

  // One output position without a pre-generated kernel (border positions).
  auto run_border = [&](const uint8_t* A_base,
                        uint8_t* C_uint8_base,
                        int t,
                        int h,
                        int w) {
    depthwise_3d_kernel_<
        FUSE_RELU,
        HAS_BIAS,
        A_SYMMETRIC,
        B_SYMMETRIC,
        Q_GRAN>(
        T,
        H,
        W,
        IC,
        OC,
        t,
        h,
        w,
        F,
        stride_t,
        stride_h,
        stride_w,
        A_zero_point,
        A_base,
        B_zero_point,
        Bp,
        C_multiplier,
        C_zero_point,
        C_int32,
        C_uint8_base,
        row_offsets,
        col_offsets,
        bias,
        act_times_w_scale);
  };

  // One output position with a pre-generated kernel (interior positions).
  auto run_interior = [&](const uint8_t* A_base,
                          uint8_t* C_uint8_base,
                          int t,
                          int h,
                          int w,
                          GenI8Depthwise::jit_kernel_signature* kernel) {
    depthwise_3d_kernel_<
        FUSE_RELU,
        HAS_BIAS,
        A_SYMMETRIC,
        B_SYMMETRIC,
        Q_GRAN>(
        T,
        H,
        W,
        IC,
        OC,
        t,
        h,
        w,
        F,
        stride_t,
        stride_h,
        stride_w,
        A_zero_point,
        A_base,
        B_zero_point,
        Bp,
        C_multiplier,
        C_zero_point,
        C_int32,
        C_uint8_base,
        row_offsets,
        col_offsets,
        bias,
        act_times_w_scale,
        kernel);
  };

  // Fetches the kernel for the given number of filter planes skipped before
  // and after along T; the four remaining skip arguments are always zero
  // because the kernel is only used in the interior of the H/W plane.
  auto make_kernel = [&](int prev_skip, int next_skip) {
    int remainder = OC % 32;
    if (remainder == 0) {
      remainder = 32;
    }
    return GenI8Depthwise().getOrCreate(
        /*D=*/3,
        F,
        OC / IC,
        /*compute_a_sum=*/!B_SYMMETRIC,
        remainder,
        /*prev_skip=*/prev_skip,
        /*next_skip=*/next_skip,
        0,
        0,
        0,
        0);
  };

  // Processes all rows h in [h_begin, h_end) of output plane t.
  // When refresh_kernel is set, `kernel` is (re)fetched at the first interior
  // column of every interior row; otherwise the caller guarantees it already
  // holds the kernel for (prev_skip, next_skip).
  auto process_plane = [&](const uint8_t* A_base,
                           uint8_t* C_uint8_base,
                           int t,
                           bool refresh_kernel,
                           int prev_skip,
                           int next_skip,
                           GenI8Depthwise::jit_kernel_signature& kernel) {
    int h = static_cast<int>(h_begin);
    // Leading H border: whole row goes through the border path.
    for (; h < h_front_end; ++h) {
      for (int w = 0; w < W_OUT; ++w) {
        run_border(A_base, C_uint8_base, t, h, w);
      }
    }
    // Interior rows: only the W borders need the border path.
    for (; h < h_mid_end; ++h) {
      int w = 0;
      for (; w < w_front_end; ++w) {
        run_border(A_base, C_uint8_base, t, h, w);
      }
      for (; w < w_mid_end; ++w) {
        // w == PAD_L is the first iteration of this loop whenever the loop
        // runs at all, so the kernel is always set before its first use.
        if (refresh_kernel && w == PAD_L) {
          kernel = make_kernel(prev_skip, next_skip);
        }
        run_interior(A_base, C_uint8_base, t, h, w, &kernel);
      }
      for (; w < W_OUT; ++w) {
        run_border(A_base, C_uint8_base, t, h, w);
      }
    }
    // Trailing H border.
    for (; h < h_end; ++h) {
      for (int w = 0; w < W_OUT; ++w) {
        run_border(A_base, C_uint8_base, t, h, w);
      }
    }
  };

  // Kernel for planes in the T border; depends on t, so it is refreshed for
  // every interior row of such a plane.
  GenI8Depthwise::jit_kernel_signature boundary_kernel{};
  // Kernel for planes in the T interior (no skipped filter planes). It does
  // not depend on t, so it is fetched while processing the first image and
  // reused for the remaining ones.
  GenI8Depthwise::jit_kernel_signature middle_kernel{};

  auto process_boundary_plane =
      [&](const uint8_t* A_base, uint8_t* C_uint8_base, int t) {
        // First input plane touched by this output plane; negative values
        // and values past T - K_T mean part of the filter falls into the
        // padding.
        const int t_in = -PAD_P + t * stride_t;
        process_plane(
            A_base,
            C_uint8_base,
            t,
            /*refresh_kernel=*/true,
            std::max(-t_in, 0),
            std::max(t_in + F[0] - T, 0),
            boundary_kernel);
      };

  for (int n = n_begin; n < n_end; ++n) {
    // Widen before multiplying: the element offset of image n can exceed the
    // range of int for large activations.
    const uint8_t* A_base = A + static_cast<int64_t>(n) * T * H * W * IC;
    uint8_t* C_uint8_base = C_uint8 + n * T_OUT * H_OUT * W_OUT * OC;
    const bool first_image = (n == n_begin);

    int t = static_cast<int>(t_begin);
    // Leading T border.
    for (; t < t_front_end; ++t) {
      process_boundary_plane(A_base, C_uint8_base, t);
    }
    // T interior.
    for (; t < t_mid_end; ++t) {
      process_plane(
          A_base,
          C_uint8_base,
          t,
          /*refresh_kernel=*/first_image,
          /*prev_skip=*/0,
          /*next_skip=*/0,
          middle_kernel);
    }
    // Trailing T border.
    for (; t < t_end; ++t) {
      process_boundary_plane(A_base, C_uint8_base, t);
    }
  } // for each n

  fbgemmAlignedFree(row_offsets);
}