static inline void PackMacroBlockNeon()

in tensorflow/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h [6133:6534]


  static inline void PackMacroBlockNeon(
      int32 height_block_number, int32 width_block_number,
      const uint8* input_block_data, int8* scratch_block_data,
      const DepthwiseConvDotProdParams* function_params) {
    constexpr uint8 kSignBit = 0x80;

    const int workspace_height_stride =
        function_params->workspace_height_stride;
    const int width_overall_micro_repeats =
        function_params->input_width_overall_micro_repeats;
    const int input_width_micro_repeats =
        function_params->input_width_micro_repeats;
    const int depth_micro_repeats = function_params->depth_micro_repeats;
    const int block_height = function_params->inbound_block_height;
    const int residual_width = function_params->residual_width;
    const int input_height_stride = function_params->input_height_stride;
    const int input_depth = function_params->input_depth;

    const int padding_left = function_params->padding_left;
    const int padding_right = function_params->padding_right;
    const int padding_top = function_params->padding_top;
    const int padding_bottom = function_params->padding_bottom;

    TFLITE_DCHECK_GT(depth_micro_repeats, 0);
    constexpr int kSymmetricZeroPoint = 128;

    const int micro_block_size = 4 * 8;
    const int depth_advance = width_overall_micro_repeats * micro_block_size;
    const int width_advance =
        micro_block_size *
        (1 - depth_micro_repeats * width_overall_micro_repeats);
    const int height_advance = workspace_height_stride -
                               width_overall_micro_repeats * micro_block_size;
    const int input_depth_skip = 4 * input_depth - 8 * depth_micro_repeats;
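    // Pointer bookkeeping for the scratch layout: depth_advance steps from a
    // micro block to the same width position in the next depth micro block;
    // width_advance then rewinds past all depth repeats and moves one micro
    // block to the right; height_advance finishes a row by jumping to the
    // next workspace row. input_depth_skip moves the input pointer to the
    // next group of four width positions once 8 * depth_micro_repeats
    // channels have been consumed.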

    const bool leading_width_padding =
        padding_left > 0 && width_block_number == 0;
    const bool trailing_width_padding =
        padding_right > 0 &&
        width_block_number == (function_params->width_macro_count - 1);
    const bool leading_height_padding =
        padding_top > 0 && height_block_number < 0;
    const bool trailing_height_padding =
        padding_bottom > 0 &&
        height_block_number == (function_params->height_macro_count - 1);

    const int32 input_offset = function_params->input_offset;
    const int32 input_offset_difference = input_offset + kSymmetricZeroPoint;
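    // Padding bytes are written as -(input_offset + 128): with TFLite's uint8
    // convention that input_offset is the negated zero point, a padded pixel
    // equals the zero point in uint8 space, which maps to -input_offset - 128
    // once the workspace's implicit subtraction of 128 is applied.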

    // Transpositions are 4x4, but doing 2 at a time is more efficient in NEON
    // code. Note the blocks of 4x4 are still interleaved down the depth.
    int8x16_t work_reg_a;
    int8x16_t work_reg_b;

    // Effect the subtraction of the zero point (= 128) by XOR-ing in the
    // sign bit.
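    // For example, uint8 200 (0xC8) ^ 0x80 = 0x48, which reads as +72 in
    // int8: exactly 200 - 128.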
    const int8x16_t sign_bit = vdupq_n_s8(kSignBit);

    // Work through one slice, by row, at a time.
    int8* scratch_data_0 = scratch_block_data;

    int copy_block_height = block_height;
    if (leading_height_padding) {
      copy_block_height -= 1;
      memset(scratch_data_0, -input_offset_difference, workspace_height_stride);
      scratch_data_0 += workspace_height_stride;
      input_block_data += input_height_stride;
    }
    if (trailing_height_padding) {
      copy_block_height -= 1;
    }

    for (int k_height = 0; k_height < copy_block_height; ++k_height) {
      const int8* input_data_0 =
          reinterpret_cast<const int8*>(input_block_data);
      int8x16_t input_data_a;
      int8x16_t input_data_b;
      int8x16_t input_data_c;
      int8x16_t input_data_d;

      // Traverse the width one point at a time, but the depth in (micro) blocks
      // of size 8.
      //
      // The depth and width margins, which are filled with "zeros", may be
      // larger than is strictly needed to calculate output. This is because the
      // conv calculation is performed across complete micro blocks.
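      // Each micro block packs 4 width positions x 8 depth channels, i.e.
      // micro_block_size = 32 bytes of scratch.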
      for (int j_width = 0; j_width < width_overall_micro_repeats; ++j_width) {
        // Figure out the division of work (available input vs. zeroed).
        int adjusted_residual_width =
            j_width == (input_width_micro_repeats) ? residual_width : 4;

        if (trailing_width_padding &&
            j_width == (width_overall_micro_repeats - 1)) {
          adjusted_residual_width -= 1;
        }
        int start_width = 0;
        if (leading_width_padding && j_width == 0) {
          start_width = 1;
        }
        if (start_width == 0) {
          if (adjusted_residual_width == 4) {
            int8x16_t work_reg_a_sp;
            int8x16_t work_reg_b_sp;

            int i_depth = 0;

            if (depth_micro_repeats >= 2) {
              i_depth += 2;

              // Prime the software pipeline: load the first four width
              // positions, 16 depth channels each.

              input_data_a = vld1q_s8(input_data_0);
              input_data_b = vld1q_s8(input_data_0 + 1 * input_depth);
              input_data_c = vld1q_s8(input_data_0 + 2 * input_depth);
              input_data_d = vld1q_s8(input_data_0 + 3 * input_depth);
              input_data_0 += 16;

              // Main depth loop, two micro blocks per pass; the loads for the
              // next pass are interleaved with the current transpose and
              // stores.

              for (; i_depth < depth_micro_repeats - 1; i_depth += 2) {
                work_reg_a = vzip1q_s8(input_data_a, input_data_b);
                work_reg_b = vzip1q_s8(input_data_c, input_data_d);
                vzipq_s8x2_in_place(&work_reg_a, &work_reg_b);
                work_reg_a = veorq_s8(work_reg_a, sign_bit);
                work_reg_b = veorq_s8(work_reg_b, sign_bit);

                work_reg_a_sp = vzip2q_s8(input_data_a, input_data_b);
                work_reg_b_sp = vzip2q_s8(input_data_c, input_data_d);
                vzipq_s8x2_in_place(&work_reg_a_sp, &work_reg_b_sp);

                input_data_a = vld1q_s8(input_data_0);
                input_data_b = vld1q_s8(input_data_0 + 1 * input_depth);
                optimized_ops_prefetch_write_l1_keep(scratch_data_0);
                optimized_ops_prefetch_write_l1_keep(scratch_data_0 + 16);
                vst1q_s8(scratch_data_0, work_reg_a);
                vst1q_s8(scratch_data_0 + 16, work_reg_b);

                scratch_data_0 += depth_advance;

                work_reg_a_sp = veorq_s8(work_reg_a_sp, sign_bit);
                work_reg_b_sp = veorq_s8(work_reg_b_sp, sign_bit);

                input_data_c = vld1q_s8(input_data_0 + 2 * input_depth);
                input_data_d = vld1q_s8(input_data_0 + 3 * input_depth);
                optimized_ops_prefetch_write_l1_keep(scratch_data_0);
                optimized_ops_prefetch_write_l1_keep(scratch_data_0 + 16);
                vst1q_s8(scratch_data_0, work_reg_a_sp);
                vst1q_s8(scratch_data_0 + 16, work_reg_b_sp);

                scratch_data_0 += depth_advance;

                // Step the input pointer to the next 16 depth channels.

                input_data_0 += 16;
              }

              work_reg_a = vzip1q_s8(input_data_a, input_data_b);
              work_reg_b = vzip1q_s8(input_data_c, input_data_d);
              vzipq_s8x2_in_place(&work_reg_a, &work_reg_b);
              work_reg_a = veorq_s8(work_reg_a, sign_bit);
              work_reg_b = veorq_s8(work_reg_b, sign_bit);
              optimized_ops_prefetch_write_l1_keep(scratch_data_0);
              optimized_ops_prefetch_write_l1_keep(scratch_data_0 + 16);
              vst1q_s8(scratch_data_0, work_reg_a);
              vst1q_s8(scratch_data_0 + 16, work_reg_b);

              scratch_data_0 += depth_advance;
              // Now the upper eight depth channels of the final loaded pair.

              work_reg_a_sp = vzip2q_s8(input_data_a, input_data_b);
              work_reg_b_sp = vzip2q_s8(input_data_c, input_data_d);
              vzipq_s8x2_in_place(&work_reg_a_sp, &work_reg_b_sp);
              work_reg_a_sp = veorq_s8(work_reg_a_sp, sign_bit);
              work_reg_b_sp = veorq_s8(work_reg_b_sp, sign_bit);

              optimized_ops_prefetch_write_l1_keep(scratch_data_0);
              optimized_ops_prefetch_write_l1_keep(scratch_data_0 + 16);
              vst1q_s8(scratch_data_0, work_reg_a_sp);
              vst1q_s8(scratch_data_0 + 16, work_reg_b_sp);

              scratch_data_0 += depth_advance;
            }
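            // Residual depth: finish any remaining micro blocks one at a
            // time, eight channels per pass.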
            for (; i_depth < depth_micro_repeats; ++i_depth) {
              input_data_a = vld1q_lane_s8x8(input_data_0, input_data_a, 0);
              input_data_b = vld1q_lane_s8x8(input_data_0 + 1 * input_depth,
                                             input_data_b, 0);
              input_data_c = vld1q_lane_s8x8(input_data_0 + 2 * input_depth,
                                             input_data_c, 0);
              input_data_d = vld1q_lane_s8x8(input_data_0 + 3 * input_depth,
                                             input_data_d, 0);
              work_reg_a = vzip1q_s8(input_data_a, input_data_b);
              work_reg_b = vzip1q_s8(input_data_c, input_data_d);

              input_data_0 += 8;

              vzipq_s8x2_in_place(&work_reg_a, &work_reg_b);
              work_reg_a = veorq_s8(work_reg_a, sign_bit);
              work_reg_b = veorq_s8(work_reg_b, sign_bit);

              optimized_ops_prefetch_write_l1_keep(scratch_data_0);
              optimized_ops_prefetch_write_l1_keep(scratch_data_0 + 16);
              vst1q_s8(scratch_data_0, work_reg_a);
              vst1q_s8(scratch_data_0 + 16, work_reg_b);

              scratch_data_0 += depth_advance;
            }
            scratch_data_0 += width_advance;
            input_data_0 += input_depth_skip;
          } else {
            TFLITE_DCHECK_LT(adjusted_residual_width, 4);
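            // Partial-width micro block: pre-fill all four registers with the
            // padding value, then overwrite only the valid columns.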
            for (int i_depth = 0; i_depth < depth_micro_repeats; ++i_depth) {
              input_data_a = vdupq_n_s8(-input_offset);
              input_data_b = vdupq_n_s8(-input_offset);
              input_data_c = vdupq_n_s8(-input_offset);
              input_data_d = vdupq_n_s8(-input_offset);
              if (adjusted_residual_width > 0) {
                input_data_a = vld1q_lane_s8x8(input_data_0, input_data_a, 0);
                if (adjusted_residual_width > 1) {
                  input_data_b = vld1q_lane_s8x8(input_data_0 + input_depth,
                                                 input_data_b, 0);
                  if (adjusted_residual_width == 3) {
                    input_data_c = vld1q_lane_s8x8(
                        input_data_0 + 2 * input_depth, input_data_c, 0);
                  }
                }
              }
              work_reg_a = vzip1q_s8(input_data_a, input_data_b);
              work_reg_b = vzip1q_s8(input_data_c, input_data_d);

              work_reg_a = veorq_s8(work_reg_a, sign_bit);
              work_reg_b = veorq_s8(work_reg_b, sign_bit);
              vzipq_s8x2_in_place(&work_reg_a, &work_reg_b);

              optimized_ops_prefetch_write_l1_keep(scratch_data_0);
              optimized_ops_prefetch_write_l1_keep(scratch_data_0 + 16);
              vst1q_s8(scratch_data_0, work_reg_a);
              vst1q_s8(scratch_data_0 + 16, work_reg_b);

              scratch_data_0 += depth_advance;
              input_data_0 += 8;
            }
            scratch_data_0 += width_advance;
            input_data_0 += input_depth_skip;
          }
        } else {
          if (adjusted_residual_width == 4) {
            int8x16_t work_reg_a_sp;
            int8x16_t work_reg_b_sp;

            int i_depth = 0;

            if (depth_micro_repeats >= 2) {
              i_depth += 2;

              // Prime the pipeline; width position 0 falls in the left
              // padding, so it takes the padding value instead of a load.

              input_data_a = vdupq_n_s8(-input_offset);
              input_data_b = vld1q_s8(input_data_0 + 1 * input_depth);
              input_data_c = vld1q_s8(input_data_0 + 2 * input_depth);
              input_data_d = vld1q_s8(input_data_0 + 3 * input_depth);
              input_data_0 += 16;

              // Main depth loop, as above, with the padding value substituted
              // for width position 0.

              for (; i_depth < depth_micro_repeats - 1; i_depth += 2) {
                work_reg_a = vzip1q_s8(input_data_a, input_data_b);
                work_reg_b = vzip1q_s8(input_data_c, input_data_d);
                vzipq_s8x2_in_place(&work_reg_a, &work_reg_b);
                work_reg_a = veorq_s8(work_reg_a, sign_bit);
                work_reg_b = veorq_s8(work_reg_b, sign_bit);

                work_reg_a_sp = vzip2q_s8(input_data_a, input_data_b);
                work_reg_b_sp = vzip2q_s8(input_data_c, input_data_d);
                vzipq_s8x2_in_place(&work_reg_a_sp, &work_reg_b_sp);

                input_data_a = vdupq_n_s8(-input_offset);
                input_data_b = vld1q_s8(input_data_0 + 1 * input_depth);
                optimized_ops_prefetch_write_l1_keep(scratch_data_0);
                optimized_ops_prefetch_write_l1_keep(scratch_data_0 + 16);
                vst1q_s8(scratch_data_0, work_reg_a);
                vst1q_s8(scratch_data_0 + 16, work_reg_b);

                scratch_data_0 += depth_advance;

                work_reg_a_sp = veorq_s8(work_reg_a_sp, sign_bit);
                work_reg_b_sp = veorq_s8(work_reg_b_sp, sign_bit);

                input_data_c = vld1q_s8(input_data_0 + 2 * input_depth);
                input_data_d = vld1q_s8(input_data_0 + 3 * input_depth);
                optimized_ops_prefetch_write_l1_keep(scratch_data_0);
                optimized_ops_prefetch_write_l1_keep(scratch_data_0 + 16);
                vst1q_s8(scratch_data_0, work_reg_a_sp);
                vst1q_s8(scratch_data_0 + 16, work_reg_b_sp);

                scratch_data_0 += depth_advance;

                // Step the input pointer to the next 16 depth channels.

                input_data_0 += 16;
              }

              work_reg_a = vzip1q_s8(input_data_a, input_data_b);
              work_reg_b = vzip1q_s8(input_data_c, input_data_d);
              vzipq_s8x2_in_place(&work_reg_a, &work_reg_b);
              work_reg_a = veorq_s8(work_reg_a, sign_bit);
              work_reg_b = veorq_s8(work_reg_b, sign_bit);
              optimized_ops_prefetch_write_l1_keep(scratch_data_0);
              optimized_ops_prefetch_write_l1_keep(scratch_data_0 + 16);
              vst1q_s8(scratch_data_0, work_reg_a);
              vst1q_s8(scratch_data_0 + 16, work_reg_b);

              scratch_data_0 += depth_advance;
              // Now the upper eight depth channels of the final loaded pair.

              work_reg_a_sp = vzip2q_s8(input_data_a, input_data_b);
              work_reg_b_sp = vzip2q_s8(input_data_c, input_data_d);
              vzipq_s8x2_in_place(&work_reg_a_sp, &work_reg_b_sp);
              work_reg_a_sp = veorq_s8(work_reg_a_sp, sign_bit);
              work_reg_b_sp = veorq_s8(work_reg_b_sp, sign_bit);

              optimized_ops_prefetch_write_l1_keep(scratch_data_0);
              optimized_ops_prefetch_write_l1_keep(scratch_data_0 + 16);
              vst1q_s8(scratch_data_0, work_reg_a_sp);
              vst1q_s8(scratch_data_0 + 16, work_reg_b_sp);

              scratch_data_0 += depth_advance;
            }
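            // Residual depth: finish any remaining micro blocks one at a
            // time, eight channels per pass.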
            for (; i_depth < depth_micro_repeats; ++i_depth) {
              input_data_a = vdupq_n_s8(-input_offset);
              input_data_b = vld1q_lane_s8x8(input_data_0 + 1 * input_depth,
                                             input_data_b, 0);
              input_data_c = vld1q_lane_s8x8(input_data_0 + 2 * input_depth,
                                             input_data_c, 0);
              input_data_d = vld1q_lane_s8x8(input_data_0 + 3 * input_depth,
                                             input_data_d, 0);
              work_reg_a = vzip1q_s8(input_data_a, input_data_b);
              work_reg_b = vzip1q_s8(input_data_c, input_data_d);

              input_data_0 += 8;

              vzipq_s8x2_in_place(&work_reg_a, &work_reg_b);
              work_reg_a = veorq_s8(work_reg_a, sign_bit);
              work_reg_b = veorq_s8(work_reg_b, sign_bit);

              optimized_ops_prefetch_write_l1_keep(scratch_data_0);
              optimized_ops_prefetch_write_l1_keep(scratch_data_0 + 16);
              vst1q_s8(scratch_data_0, work_reg_a);
              vst1q_s8(scratch_data_0 + 16, work_reg_b);

              scratch_data_0 += depth_advance;
            }
            scratch_data_0 += width_advance;
            input_data_0 += input_depth_skip;
          } else {
            TFLITE_DCHECK_LT(adjusted_residual_width, 4);

            for (int i_depth = 0; i_depth < depth_micro_repeats; ++i_depth) {
              input_data_a = vdupq_n_s8(-input_offset);
              input_data_b = vdupq_n_s8(-input_offset);
              input_data_c = vdupq_n_s8(-input_offset);
              input_data_d = vdupq_n_s8(-input_offset);
              // Skip loading the first column; it is covered by the left
              // padding.
              if (adjusted_residual_width > 1) {
                input_data_b = vld1q_lane_s8x8(input_data_0 + input_depth,
                                               input_data_b, 0);
                if (adjusted_residual_width == 3) {
                  input_data_c = vld1q_lane_s8x8(input_data_0 + 2 * input_depth,
                                                 input_data_c, 0);
                }
              }
              work_reg_a = vzip1q_s8(input_data_a, input_data_b);
              work_reg_b = vzip1q_s8(input_data_c, input_data_d);

              work_reg_a = veorq_s8(work_reg_a, sign_bit);
              work_reg_b = veorq_s8(work_reg_b, sign_bit);
              vzipq_s8x2_in_place(&work_reg_a, &work_reg_b);

              optimized_ops_prefetch_write_l1_keep(scratch_data_0);
              optimized_ops_prefetch_write_l1_keep(scratch_data_0 + 16);
              vst1q_s8(scratch_data_0, work_reg_a);
              vst1q_s8(scratch_data_0 + 16, work_reg_b);

              scratch_data_0 += depth_advance;
              input_data_0 += 8;
            }
            scratch_data_0 += width_advance;
            input_data_0 += input_depth_skip;
          }
        }
      }
      scratch_data_0 += height_advance;
      input_block_data += input_height_stride;
    }

    if (trailing_height_padding) {
      memset(scratch_data_0, -input_offset_difference, workspace_height_stride);
      scratch_data_0 += workspace_height_stride;
    }

    TFLITE_DCHECK_EQ(
        scratch_data_0,
        scratch_block_data + block_height * workspace_height_stride);
  }
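
For reference, the transform above can be written as a short scalar loop. The sketch below is illustrative only: it assumes no edge padding and fully populated micro blocks (width a multiple of 4, depth a multiple of 8), and the function name and parameter list are hypothetical rather than anything that appears in the TFLite sources.

#include <cstdint>

// Scalar sketch of the packing performed by PackMacroBlockNeon, assuming no
// edge padding and no residual width or depth. Hypothetical reference helper,
// not part of TFLite.
inline void PackMacroBlockScalarRef(
    const uint8_t* input_block_data, int8_t* scratch_block_data,
    int copy_block_height, int input_height_stride, int input_depth,
    int depth_micro_repeats, int width_micro_repeats,
    int workspace_height_stride) {
  for (int k_height = 0; k_height < copy_block_height; ++k_height) {
    const uint8_t* in_row = input_block_data + k_height * input_height_stride;
    int8_t* out_row = scratch_block_data + k_height * workspace_height_stride;
    for (int i_depth = 0; i_depth < depth_micro_repeats; ++i_depth) {
      for (int j_width = 0; j_width < width_micro_repeats; ++j_width) {
        // One 32-byte micro block: a 4 (width) x 8 (depth) tile, transposed
        // so that each channel's four width values land contiguously.
        int8_t* out = out_row + (i_depth * width_micro_repeats + j_width) * 32;
        for (int d = 0; d < 8; ++d) {
          for (int x = 0; x < 4; ++x) {
            const uint8_t v =
                in_row[(4 * j_width + x) * input_depth + 8 * i_depth + d];
            // XOR of the sign bit implements v - 128.
            out[4 * d + x] = static_cast<int8_t>(v ^ 0x80);
          }
        }
      }
    }
  }
}

Reading this against the NEON body makes the pointer increments concrete: depth_advance, width_advance, and height_advance are the increment form of the index expression (i_depth * width_micro_repeats + j_width) * 32, plus the per-row workspace stride.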