in tensorflow/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h [6133:6534]
static inline void PackMacroBlockNeon(
int32 height_block_number, int32 width_block_number,
const uint8* input_block_data, int8* scratch_block_data,
const DepthwiseConvDotProdParams* function_params) {
constexpr uint8 kSignBit = 0x80;
const int workspace_height_stride =
function_params->workspace_height_stride;
const int width_overall_micro_repeats =
function_params->input_width_overall_micro_repeats;
const int input_width_micro_repeats =
function_params->input_width_micro_repeats;
const int depth_micro_repeats = function_params->depth_micro_repeats;
const int block_height = function_params->inbound_block_height;
const int residual_width = function_params->residual_width;
const int input_height_stride = function_params->input_height_stride;
const int input_depth = function_params->input_depth;
const int padding_left = function_params->padding_left;
const int padding_right = function_params->padding_right;
const int padding_top = function_params->padding_top;
const int padding_bottom = function_params->padding_bottom;
TFLITE_DCHECK_GT(depth_micro_repeats, 0);
constexpr int kSymmetricZeroPoint = 128;
const int micro_block_size = 4 * 8;
const int depth_advance = width_overall_micro_repeats * micro_block_size;
const int width_advance =
micro_block_size *
(1 - depth_micro_repeats * width_overall_micro_repeats);
const int height_advance = workspace_height_stride -
width_overall_micro_repeats * micro_block_size;
const int input_depth_skip = 4 * input_depth - 8 * depth_micro_repeats;
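  // Scratch advance arithmetic: each 4x8 micro block occupies
  // micro_block_size = 32 bytes (4 width positions x 8 depth channels).
  // Within a row the scratch is laid out depth-major, so depth_advance steps
  // to the same width position of the next depth micro block, width_advance
  // (negative once more than one micro block exists) rewinds to the next
  // width position, and height_advance completes the row out to
  // workspace_height_stride. For example (illustrative numbers), with
  // width_overall_micro_repeats = 2 and depth_micro_repeats = 3:
  // depth_advance = 64 and width_advance = 32 * (1 - 6) = -160, so width-0
  // stores land at offsets 0, 64, 128 and width-1 stores at 32, 96, 160.
  // Similarly, input_depth_skip moves the input pointer from the end of a
  // depth traversal (8 * depth_micro_repeats bytes in) to the next group of
  // four width positions (4 * input_depth bytes in).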
const bool leading_width_padding =
padding_left > 0 && width_block_number == 0;
const bool trailing_width_padding =
padding_right > 0 &&
width_block_number == (function_params->width_macro_count - 1);
const bool leading_height_padding =
padding_top > 0 && height_block_number < 0;
const bool trailing_height_padding =
padding_bottom > 0 &&
height_block_number == (function_params->height_macro_count - 1);
const int32 input_offset = function_params->input_offset;
const int32 input_offset_difference = input_offset + kSymmetricZeroPoint;
  // Transpositions are 4x4, but processing two at a time is more efficient
  // in NEON code. Note that the 4x4 blocks are still interleaved down the
  // depth.
int8x16_t work_reg_a;
int8x16_t work_reg_b;
  // Effect the subtraction of the zero point (128) by XOR-ing the sign bit.
const int8x16_t sign_bit = vdupq_n_s8(kSignBit);
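  // For example, the uint8 value 200 (0xC8) maps to 0xC8 ^ 0x80 = 0x48,
  // which read as int8 is 72 = 200 - 128; the XOR is a one-instruction
  // recentering of uint8 data about the symmetric zero point.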
// Work through one slice, by row, at a time.
int8* scratch_data_0 = scratch_block_data;
int copy_block_height = block_height;
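  // Fully padded rows are filled with -input_offset_difference, the int8
  // image of the uint8 padding value -input_offset, so that re-adding the
  // input offset downstream reproduces zero.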
if (leading_height_padding) {
copy_block_height -= 1;
memset(scratch_data_0, -input_offset_difference, workspace_height_stride);
scratch_data_0 += workspace_height_stride;
input_block_data += input_height_stride;
}
if (trailing_height_padding) {
copy_block_height -= 1;
}
for (int k_height = 0; k_height < copy_block_height; ++k_height) {
const int8* input_data_0 =
reinterpret_cast<const int8*>(input_block_data);
int8x16_t input_data_a;
int8x16_t input_data_b;
int8x16_t input_data_c;
int8x16_t input_data_d;
// Traverse the width one point at a time, but the depth in (micro) blocks
// of size 8.
//
// The depth and width margins, which are filled with "zeros", may be
// larger than is strictly needed to calculate output. This is because the
// conv calculation is performed across complete micro blocks.
for (int j_width = 0; j_width < width_overall_micro_repeats; ++j_width) {
      // Figure out the division of work (available input vs. zeroed
      // padding): the final width micro block may hold only residual_width
      // valid columns, and blocks at the macro-block edges lose one column
      // to padding.
int adjusted_residual_width =
j_width == (input_width_micro_repeats) ? residual_width : 4;
if (trailing_width_padding &&
j_width == (width_overall_micro_repeats - 1)) {
adjusted_residual_width -= 1;
}
int start_width = 0;
if (leading_width_padding && j_width == 0) {
start_width = 1;
}
if (start_width == 0) {
if (adjusted_residual_width == 4) {
int8x16_t work_reg_a_sp;
int8x16_t work_reg_b_sp;
int i_depth = 0;
if (depth_micro_repeats >= 2) {
i_depth += 2;
//
input_data_a = vld1q_s8(input_data_0);
input_data_b = vld1q_s8(input_data_0 + 1 * input_depth);
input_data_c = vld1q_s8(input_data_0 + 2 * input_depth);
input_data_d = vld1q_s8(input_data_0 + 3 * input_depth);
input_data_0 += 16;
//
for (; i_depth < depth_micro_repeats - 1; i_depth += 2) {
work_reg_a = vzip1q_s8(input_data_a, input_data_b);
work_reg_b = vzip1q_s8(input_data_c, input_data_d);
vzipq_s8x2_in_place(&work_reg_a, &work_reg_b);
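              // After the two vzip1q_s8 ops and the 16-bit zip just above,
              // work_reg_a holds bytes a0 b0 c0 d0 ... a3 b3 c3 d3 and
              // work_reg_b holds a4 b4 c4 d4 ... a7 b7 c7 d7, where letters
              // are the four width positions and digits the depth channels:
              // the 4x4 transpose makes each channel's four width values
              // contiguous. The vzip2q_s8 pair below does the same for
              // channels 8-15 of this pair of depth micro blocks.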
work_reg_a = veorq_s8(work_reg_a, sign_bit);
work_reg_b = veorq_s8(work_reg_b, sign_bit);
work_reg_a_sp = vzip2q_s8(input_data_a, input_data_b);
work_reg_b_sp = vzip2q_s8(input_data_c, input_data_d);
vzipq_s8x2_in_place(&work_reg_a_sp, &work_reg_b_sp);
input_data_a = vld1q_s8(input_data_0);
input_data_b = vld1q_s8(input_data_0 + 1 * input_depth);
optimized_ops_prefetch_write_l1_keep(scratch_data_0);
optimized_ops_prefetch_write_l1_keep(scratch_data_0 + 16);
vst1q_s8(scratch_data_0, work_reg_a);
vst1q_s8(scratch_data_0 + 16, work_reg_b);
scratch_data_0 += depth_advance;
work_reg_a_sp = veorq_s8(work_reg_a_sp, sign_bit);
work_reg_b_sp = veorq_s8(work_reg_b_sp, sign_bit);
input_data_c = vld1q_s8(input_data_0 + 2 * input_depth);
input_data_d = vld1q_s8(input_data_0 + 3 * input_depth);
optimized_ops_prefetch_write_l1_keep(scratch_data_0);
optimized_ops_prefetch_write_l1_keep(scratch_data_0 + 16);
vst1q_s8(scratch_data_0, work_reg_a_sp);
vst1q_s8(scratch_data_0 + 16, work_reg_b_sp);
scratch_data_0 += depth_advance;
//
input_data_0 += 16;
}
work_reg_a = vzip1q_s8(input_data_a, input_data_b);
work_reg_b = vzip1q_s8(input_data_c, input_data_d);
vzipq_s8x2_in_place(&work_reg_a, &work_reg_b);
work_reg_a = veorq_s8(work_reg_a, sign_bit);
work_reg_b = veorq_s8(work_reg_b, sign_bit);
optimized_ops_prefetch_write_l1_keep(scratch_data_0);
optimized_ops_prefetch_write_l1_keep(scratch_data_0 + 16);
vst1q_s8(scratch_data_0, work_reg_a);
vst1q_s8(scratch_data_0 + 16, work_reg_b);
scratch_data_0 += depth_advance;
//
work_reg_a_sp = vzip2q_s8(input_data_a, input_data_b);
work_reg_b_sp = vzip2q_s8(input_data_c, input_data_d);
vzipq_s8x2_in_place(&work_reg_a_sp, &work_reg_b_sp);
work_reg_a_sp = veorq_s8(work_reg_a_sp, sign_bit);
work_reg_b_sp = veorq_s8(work_reg_b_sp, sign_bit);
optimized_ops_prefetch_write_l1_keep(scratch_data_0);
optimized_ops_prefetch_write_l1_keep(scratch_data_0 + 16);
vst1q_s8(scratch_data_0, work_reg_a_sp);
vst1q_s8(scratch_data_0 + 16, work_reg_b_sp);
scratch_data_0 += depth_advance;
}
for (; i_depth < depth_micro_repeats; ++i_depth) {
input_data_a = vld1q_lane_s8x8(input_data_0, input_data_a, 0);
input_data_b = vld1q_lane_s8x8(input_data_0 + 1 * input_depth,
input_data_b, 0);
input_data_c = vld1q_lane_s8x8(input_data_0 + 2 * input_depth,
input_data_c, 0);
input_data_d = vld1q_lane_s8x8(input_data_0 + 3 * input_depth,
input_data_d, 0);
work_reg_a = vzip1q_s8(input_data_a, input_data_b);
work_reg_b = vzip1q_s8(input_data_c, input_data_d);
input_data_0 += 8;
vzipq_s8x2_in_place(&work_reg_a, &work_reg_b);
work_reg_a = veorq_s8(work_reg_a, sign_bit);
work_reg_b = veorq_s8(work_reg_b, sign_bit);
optimized_ops_prefetch_write_l1_keep(scratch_data_0);
optimized_ops_prefetch_write_l1_keep(scratch_data_0 + 16);
vst1q_s8(scratch_data_0, work_reg_a);
vst1q_s8(scratch_data_0 + 16, work_reg_b);
scratch_data_0 += depth_advance;
}
scratch_data_0 += width_advance;
input_data_0 += input_depth_skip;
} else {
TFLITE_DCHECK_LT(adjusted_residual_width, 4);
for (int i_depth = 0; i_depth < depth_micro_repeats; ++i_depth) {
input_data_a = vdupq_n_s8(-input_offset);
input_data_b = vdupq_n_s8(-input_offset);
input_data_c = vdupq_n_s8(-input_offset);
input_data_d = vdupq_n_s8(-input_offset);
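            // Lanes not overwritten by the loads below keep -input_offset,
            // the uint8 padding value; the sign-bit XOR later maps them to
            // the int8 padding value -input_offset_difference.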
if (adjusted_residual_width > 0) {
input_data_a = vld1q_lane_s8x8(input_data_0, input_data_a, 0);
if (adjusted_residual_width > 1) {
input_data_b = vld1q_lane_s8x8(input_data_0 + input_depth,
input_data_b, 0);
if (adjusted_residual_width == 3) {
input_data_c = vld1q_lane_s8x8(
input_data_0 + 2 * input_depth, input_data_c, 0);
}
}
}
work_reg_a = vzip1q_s8(input_data_a, input_data_b);
work_reg_b = vzip1q_s8(input_data_c, input_data_d);
work_reg_a = veorq_s8(work_reg_a, sign_bit);
work_reg_b = veorq_s8(work_reg_b, sign_bit);
vzipq_s8x2_in_place(&work_reg_a, &work_reg_b);
optimized_ops_prefetch_write_l1_keep(scratch_data_0);
optimized_ops_prefetch_write_l1_keep(scratch_data_0 + 16);
vst1q_s8(scratch_data_0, work_reg_a);
vst1q_s8(scratch_data_0 + 16, work_reg_b);
scratch_data_0 += depth_advance;
input_data_0 += 8;
}
scratch_data_0 += width_advance;
input_data_0 += input_depth_skip;
}
} else {
if (adjusted_residual_width == 4) {
int8x16_t work_reg_a_sp;
int8x16_t work_reg_b_sp;
int i_depth = 0;
if (depth_micro_repeats >= 2) {
i_depth += 2;
//
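            // Width position 0 is the leading padding column, so
            // input_data_a is pinned to the uint8 padding value rather than
            // loaded from the input.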
input_data_a = vdupq_n_s8(-input_offset);
input_data_b = vld1q_s8(input_data_0 + 1 * input_depth);
input_data_c = vld1q_s8(input_data_0 + 2 * input_depth);
input_data_d = vld1q_s8(input_data_0 + 3 * input_depth);
input_data_0 += 16;
//
for (; i_depth < depth_micro_repeats - 1; i_depth += 2) {
work_reg_a = vzip1q_s8(input_data_a, input_data_b);
work_reg_b = vzip1q_s8(input_data_c, input_data_d);
vzipq_s8x2_in_place(&work_reg_a, &work_reg_b);
work_reg_a = veorq_s8(work_reg_a, sign_bit);
work_reg_b = veorq_s8(work_reg_b, sign_bit);
work_reg_a_sp = vzip2q_s8(input_data_a, input_data_b);
work_reg_b_sp = vzip2q_s8(input_data_c, input_data_d);
vzipq_s8x2_in_place(&work_reg_a_sp, &work_reg_b_sp);
input_data_a = vdupq_n_s8(-input_offset);
input_data_b = vld1q_s8(input_data_0 + 1 * input_depth);
optimized_ops_prefetch_write_l1_keep(scratch_data_0);
optimized_ops_prefetch_write_l1_keep(scratch_data_0 + 16);
vst1q_s8(scratch_data_0, work_reg_a);
vst1q_s8(scratch_data_0 + 16, work_reg_b);
scratch_data_0 += depth_advance;
work_reg_a_sp = veorq_s8(work_reg_a_sp, sign_bit);
work_reg_b_sp = veorq_s8(work_reg_b_sp, sign_bit);
input_data_c = vld1q_s8(input_data_0 + 2 * input_depth);
input_data_d = vld1q_s8(input_data_0 + 3 * input_depth);
optimized_ops_prefetch_write_l1_keep(scratch_data_0);
optimized_ops_prefetch_write_l1_keep(scratch_data_0 + 16);
vst1q_s8(scratch_data_0, work_reg_a_sp);
vst1q_s8(scratch_data_0 + 16, work_reg_b_sp);
scratch_data_0 += depth_advance;
//
input_data_0 += 16;
}
work_reg_a = vzip1q_s8(input_data_a, input_data_b);
work_reg_b = vzip1q_s8(input_data_c, input_data_d);
vzipq_s8x2_in_place(&work_reg_a, &work_reg_b);
work_reg_a = veorq_s8(work_reg_a, sign_bit);
work_reg_b = veorq_s8(work_reg_b, sign_bit);
optimized_ops_prefetch_write_l1_keep(scratch_data_0);
optimized_ops_prefetch_write_l1_keep(scratch_data_0 + 16);
vst1q_s8(scratch_data_0, work_reg_a);
vst1q_s8(scratch_data_0 + 16, work_reg_b);
scratch_data_0 += depth_advance;
//
work_reg_a_sp = vzip2q_s8(input_data_a, input_data_b);
work_reg_b_sp = vzip2q_s8(input_data_c, input_data_d);
vzipq_s8x2_in_place(&work_reg_a_sp, &work_reg_b_sp);
work_reg_a_sp = veorq_s8(work_reg_a_sp, sign_bit);
work_reg_b_sp = veorq_s8(work_reg_b_sp, sign_bit);
optimized_ops_prefetch_write_l1_keep(scratch_data_0);
optimized_ops_prefetch_write_l1_keep(scratch_data_0 + 16);
vst1q_s8(scratch_data_0, work_reg_a_sp);
vst1q_s8(scratch_data_0 + 16, work_reg_b_sp);
scratch_data_0 += depth_advance;
}
for (; i_depth < depth_micro_repeats; ++i_depth) {
input_data_a = vdupq_n_s8(-input_offset);
input_data_b = vld1q_lane_s8x8(input_data_0 + 1 * input_depth,
input_data_b, 0);
input_data_c = vld1q_lane_s8x8(input_data_0 + 2 * input_depth,
input_data_c, 0);
input_data_d = vld1q_lane_s8x8(input_data_0 + 3 * input_depth,
input_data_d, 0);
work_reg_a = vzip1q_s8(input_data_a, input_data_b);
work_reg_b = vzip1q_s8(input_data_c, input_data_d);
input_data_0 += 8;
vzipq_s8x2_in_place(&work_reg_a, &work_reg_b);
work_reg_a = veorq_s8(work_reg_a, sign_bit);
work_reg_b = veorq_s8(work_reg_b, sign_bit);
optimized_ops_prefetch_write_l1_keep(scratch_data_0);
optimized_ops_prefetch_write_l1_keep(scratch_data_0 + 16);
vst1q_s8(scratch_data_0, work_reg_a);
vst1q_s8(scratch_data_0 + 16, work_reg_b);
scratch_data_0 += depth_advance;
}
scratch_data_0 += width_advance;
input_data_0 += input_depth_skip;
} else {
TFLITE_DCHECK_LT(adjusted_residual_width, 4);
for (int i_depth = 0; i_depth < depth_micro_repeats; ++i_depth) {
input_data_a = vdupq_n_s8(-input_offset);
input_data_b = vdupq_n_s8(-input_offset);
input_data_c = vdupq_n_s8(-input_offset);
input_data_d = vdupq_n_s8(-input_offset);
            // Skip loading the first column: it is the leading padding, so
            // input_data_a keeps the padding value set above.
if (adjusted_residual_width > 1) {
input_data_b = vld1q_lane_s8x8(input_data_0 + input_depth,
input_data_b, 0);
if (adjusted_residual_width == 3) {
input_data_c = vld1q_lane_s8x8(input_data_0 + 2 * input_depth,
input_data_c, 0);
}
}
work_reg_a = vzip1q_s8(input_data_a, input_data_b);
work_reg_b = vzip1q_s8(input_data_c, input_data_d);
work_reg_a = veorq_s8(work_reg_a, sign_bit);
work_reg_b = veorq_s8(work_reg_b, sign_bit);
vzipq_s8x2_in_place(&work_reg_a, &work_reg_b);
optimized_ops_prefetch_write_l1_keep(scratch_data_0);
optimized_ops_prefetch_write_l1_keep(scratch_data_0 + 16);
vst1q_s8(scratch_data_0, work_reg_a);
vst1q_s8(scratch_data_0 + 16, work_reg_b);
scratch_data_0 += depth_advance;
input_data_0 += 8;
}
scratch_data_0 += width_advance;
input_data_0 += input_depth_skip;
}
}
}
scratch_data_0 += height_advance;
input_block_data += input_height_stride;
}
if (trailing_height_padding) {
memset(scratch_data_0, -input_offset_difference, workspace_height_stride);
scratch_data_0 += workspace_height_stride;
}
TFLITE_DCHECK_EQ(
scratch_data_0,
scratch_block_data + block_height * workspace_height_stride);
}
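
// Illustrative only, not part of the library: a minimal scalar sketch of the
// packing that PackMacroBlockNeon performs for one full 4x8 micro block,
// assuming the depth-major transposed layout described in the comments
// above. The function name and signature are hypothetical.
static inline void PackMicroBlockScalarSketch(
    const uint8* input,  // Points at 4 width positions, input_depth apart.
    int input_depth,     // Depth stride between width positions; >= 8.
    int8* scratch) {     // Receives one 32-byte micro block.
  for (int channel = 0; channel < 8; ++channel) {
    for (int width = 0; width < 4; ++width) {
      const uint8 v = input[width * input_depth + channel];
      // Same effect as the sign-bit XOR: (v ^ 0x80) read as int8 is v - 128.
      scratch[channel * 4 + width] = static_cast<int8>(v ^ 0x80);
    }
  }
}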