static inline void KernelMacroBlockIntrinsics()

in tensorflow/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_transitional.h [3731:4143]


  // Computes one macro block of a stride-2 depthwise convolution with uint8
  // output, reading int8 input and filter data that were staged in workspace
  // buffers beforehand.
  //
  // Parameters:
  //   scratch_block_data: int8 input workspace. Successive input rows are
  //     `workspace_height_stride` bytes apart and successive width micro
  //     blocks are 32 bytes apart (two 16-byte halves, addressed below with
  //     `s * 16` / `+ 16`).
  //   filter_workspace: int8 shuffled filter data, 2 * 3 * 4 * 4 = 96 bytes
  //     per depth micro block. Three 16-byte registers are read per half, at
  //     byte offsets 0 / 32 / 64 (plus 16 for the second half).
  //   bias_data: int32 bias values; four are consumed per half, i.e. eight
  //     per depth micro block. (The locals are named "adjusted" -- presumably
  //     the caller has already folded offsets into them; confirm at caller.)
  //   output_block_data: uint8 destination. Each depth micro block writes 8
  //     consecutive bytes per output pixel; horizontally adjacent pixels are
  //     `input_depth` bytes apart and output rows are `output_height_stride`
  //     bytes apart.
  //   function_params: geometry and requantization parameters. Debug checks
  //     require stride == 2, four_over_stride == 2, outbound_block_height <= 2,
  //     activation bounds within [0, 255] and output_offset within the int16
  //     range.
  //
  // Every accumulator goes through the same requantization sequence:
  // vqrdmulhq_n_s32 by `output_multiplier`, DivideByPOT by `-output_shift`,
  // saturating narrow to int16, saturating add of `output_offset`, saturating
  // narrow to uint8, then clamp to the activation range.
  static inline void KernelMacroBlockIntrinsics(
      const int8* scratch_block_data, const int8* filter_workspace,
      const int32* bias_data, uint8* output_block_data,
      const DepthwiseConvDotProdParams* function_params) {
    const int workspace_height_stride =
        function_params->workspace_height_stride;
    const int input_width_overall_micro_repeats =
        function_params->input_width_overall_micro_repeats;
    const int output_width_micro_repeats =
        function_params->output_width_micro_repeats;
    const int depth_micro_repeats = function_params->depth_micro_repeats;
    const int depth = function_params->input_depth;
    // This kernel is specialized for stride 2; each width micro block
    // therefore yields at most kFourOverStride == 2 output pixels.
    constexpr int kStrideVal = 2;
    constexpr int kFourOverStride = 2;
    TFLITE_DCHECK_EQ(function_params->stride, kStrideVal);
    TFLITE_DCHECK_EQ(function_params->four_over_stride, kFourOverStride);

    const int workspace_width_micro_repeats =
        function_params->workspace_width_micro_repeats;
    const int output_width_overall_micro_repeats =
        function_params->output_width_overall_micro_repeats;
    const int block_height = function_params->outbound_block_height;
    const int residual_width = function_params->output_residual_width;
    const int output_height_stride = function_params->output_height_stride;
    // Number of int32 bias values consumed per 128-bit bias load.
    constexpr int kBiasIncrement = 4;

    TFLITE_DCHECK(depth_micro_repeats > 0);
    // Byte distance between adjacent width micro blocks in the workspace.
    const int width_micro_stride = 4 * 8;
    // Byte distance between adjacent depth micro blocks in the workspace.
    const int depth_micro_stride =
        width_micro_stride * input_width_overall_micro_repeats;

    const int32 output_activation_min =
        function_params->quantized_activation_min;
    const int32 output_activation_max =
        function_params->quantized_activation_max;
    const int32 output_multiplier = function_params->output_multiplier;
    const int32 output_shift = function_params->output_shift;
    const int32 output_offset = function_params->output_offset;
    // The activation bounds are narrowed to uint8 below, so they must fit.
    TFLITE_DCHECK_GE(output_activation_min, 0);
    TFLITE_DCHECK_LT(output_activation_min, 256);
    TFLITE_DCHECK_GE(output_activation_max, 0);
    TFLITE_DCHECK_LT(output_activation_max, 256);
    // The output offset is narrowed to int16 below, so it must lie in
    // [-32768, 32767]; anything outside would wrap silently in the cast.
    TFLITE_DCHECK_GE(output_offset, -32768);
    TFLITE_DCHECK_LT(output_offset, 32768);

    // This version only does min/max on 64 bits.
    const int16x8_t output_offset_vec =
        vdupq_n_s16(static_cast<int16>(output_offset));
    const uint8x8_t output_activation_min_vec =
        vdup_n_u8(static_cast<uint8>(output_activation_min));
    const uint8x8_t output_activation_max_vec =
        vdup_n_u8(static_cast<uint8>(output_activation_max));

    // Bytes of shuffled filter data per depth micro block.
    constexpr int shuffled_filter_increment = 2 * 3 * 4 * 4;

    TFLITE_DCHECK_LE(block_height, 2);

    for (int j_depth = 0; j_depth < depth_micro_repeats; ++j_depth) {
      const int8* filter_block =
          filter_workspace + shuffled_filter_increment * j_depth;

      if (block_height == 2) {
        // Two output rows are produced at once. The 8 output bytes of this
        // depth micro block are handled as two halves of 4 (index s), each
        // with its own filter registers, input half and bias vector.
        for (int s = 0; s < 2; ++s) {
          // Simulate NEON-register transposition of subset of filter.
          int8x16_t filter_reg_0_a;
          int8x16_t filter_reg_1_a;
          int8x16_t filter_reg_2_a;

          filter_reg_0_a = vld1q_s8(filter_block + s * 16);
          filter_reg_1_a = vld1q_s8(filter_block + s * 16 + 32);
          filter_reg_2_a = vld1q_s8(filter_block + s * 16 + 64);

          const int8* scratch_data =
              scratch_block_data + depth_micro_stride * j_depth;
          uint8* output_data = output_block_data + 8 * j_depth;
          const int8* input_data_0 = scratch_data + s * 2 * 8;

          const int32x4_t adjusted_bias_data = vld1q_s32(bias_data);

          // Load first sub-micro block of data into operational banks.
          // Five input rows are needed: output row 0 reads rows 0..2 and
          // output row 1 reads rows 2..4 (vertical stride of 2).
          int8x16_t left_bank_0_reg = vld1q_s8(input_data_0);
          int8x16_t left_bank_1_reg =
              vld1q_s8(input_data_0 + workspace_height_stride);
          int8x16_t left_bank_2_reg =
              vld1q_s8(input_data_0 + 2 * workspace_height_stride);
          int8x16_t left_bank_3_reg =
              vld1q_s8(input_data_0 + 3 * workspace_height_stride);
          int8x16_t left_bank_4_reg =
              vld1q_s8(input_data_0 + 4 * workspace_height_stride);

          int8x16_t right_bank_0_reg;
          int8x16_t right_bank_1_reg;
          int8x16_t right_bank_2_reg;
          int8x16_t right_bank_3_reg;
          int8x16_t right_bank_4_reg;

          // acc0 accumulates output row 0, acc1 accumulates output row 1.
          int32x4_t acc0;
          int32x4_t acc1;
          int16x8_t acc_s16_0_1;
          uint8x8_t acc_u8;

          int i_width = 0;

          // When output_width_micro_repeats <
          // output_width_overall_micro_repeats, 0 < residual_width <= 2, and so
          // residual_width == 1 is then true iff residual_width < 2.
          const int adjusted_width_micro_repeats =
              (output_width_micro_repeats <
               output_width_overall_micro_repeats) &&
                      (residual_width == 1)
                  ? output_width_micro_repeats
                  : output_width_overall_micro_repeats;

          // Main loop: every iteration emits two output pixels for both
          // output rows, with the loads of the next ("right") micro block
          // interleaved between the dot products.
          for (; i_width < adjusted_width_micro_repeats; ++i_width) {
            const int output_width = kFourOverStride;
            TFLITE_DCHECK_LE(output_width * kStrideVal, 4);
            const int8* input_data =
                input_data_0 + width_micro_stride * i_width;
            acc0 = adjusted_bias_data;
            acc1 = adjusted_bias_data;
            right_bank_0_reg = vld1q_s8(input_data + width_micro_stride);
            right_bank_1_reg = vld1q_s8(input_data + width_micro_stride +
                                        workspace_height_stride);

            acc0 = vdotq_s32(acc0, filter_reg_0_a, left_bank_0_reg);
            acc1 = vdotq_s32(acc1, filter_reg_0_a, left_bank_2_reg);
            uint8* output_data_base = output_data + depth * 2 * i_width + 4 * s;

            right_bank_2_reg = vld1q_s8(input_data + width_micro_stride +
                                        2 * workspace_height_stride);
            right_bank_3_reg = vld1q_s8(input_data + width_micro_stride +
                                        3 * workspace_height_stride);
            acc0 = vdotq_s32(acc0, filter_reg_1_a, left_bank_1_reg);
            acc0 = vdotq_s32(acc0, filter_reg_2_a, left_bank_2_reg);
            acc1 = vdotq_s32(acc1, filter_reg_1_a, left_bank_3_reg);
            acc1 = vdotq_s32(acc1, filter_reg_2_a, left_bank_4_reg);
            right_bank_4_reg = vld1q_s8(input_data + width_micro_stride +
                                        4 * workspace_height_stride);

            // Fixed-point multiplication.
            acc0 = vqrdmulhq_n_s32(acc0, output_multiplier);
            acc0 = DivideByPOT<DepthwiseConvOutputRounding::kUpward>::Run(
                acc0, -output_shift);
            acc1 = vqrdmulhq_n_s32(acc1, output_multiplier);
            acc1 = DivideByPOT<DepthwiseConvOutputRounding::kUpward>::Run(
                acc1, -output_shift);
            // Add the output offset.
            acc_s16_0_1 = vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1));
            acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec);
            // Apply the activation function.
            acc_u8 = vqmovun_s16(acc_s16_0_1);
            acc_u8 = vmax_u8(acc_u8, output_activation_min_vec);
            acc_u8 = vmin_u8(acc_u8, output_activation_max_vec);

            // Re-arrange the left banks, merging in data from the right
            // banks, to form the input window of the second output pixel.
            // NOTE(review): the exact lane movement depends on
            // vtrn1_s8x2_in_place, which is defined elsewhere -- confirm.
            left_bank_0_reg = vrev32q_u16(left_bank_0_reg);
            left_bank_1_reg = vrev32q_u16(left_bank_1_reg);
            left_bank_2_reg = vrev32q_u16(left_bank_2_reg);
            left_bank_3_reg = vrev32q_u16(left_bank_3_reg);
            left_bank_4_reg = vrev32q_u16(left_bank_4_reg);
            acc0 = adjusted_bias_data;
            acc1 = adjusted_bias_data;
            vtrn1_s8x2_in_place(&left_bank_0_reg, &right_bank_0_reg);
            vtrn1_s8x2_in_place(&left_bank_1_reg, &right_bank_1_reg);
            vtrn1_s8x2_in_place(&left_bank_2_reg, &right_bank_2_reg);
            // First output pixel: lane 0 goes to output row 0, lane 1 to
            // output row 1 (presumably 4 bytes each -- see vst1_lane_8x4).
            vst1_lane_8x4(output_data_base, acc_u8, 0);
            vst1_lane_8x4(output_data_base + output_height_stride, acc_u8, 1);

            vtrn1_s8x2_in_place(&left_bank_3_reg, &right_bank_3_reg);
            vtrn1_s8x2_in_place(&left_bank_4_reg, &right_bank_4_reg);

            acc0 = vdotq_s32(acc0, filter_reg_0_a, left_bank_0_reg);
            acc1 = vdotq_s32(acc1, filter_reg_0_a, left_bank_2_reg);
            acc0 = vdotq_s32(acc0, filter_reg_1_a, left_bank_1_reg);
            acc1 = vdotq_s32(acc1, filter_reg_1_a, left_bank_3_reg);
            acc0 = vdotq_s32(acc0, filter_reg_2_a, left_bank_2_reg);
            acc1 = vdotq_s32(acc1, filter_reg_2_a, left_bank_4_reg);

            // Fixed-point multiplication.
            acc0 = vqrdmulhq_n_s32(acc0, output_multiplier);
            acc0 = DivideByPOT<DepthwiseConvOutputRounding::kUpward>::Run(
                acc0, -output_shift);
            acc1 = vqrdmulhq_n_s32(acc1, output_multiplier);
            acc1 = DivideByPOT<DepthwiseConvOutputRounding::kUpward>::Run(
                acc1, -output_shift);
            // Add the output offset.
            acc_s16_0_1 = vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1));
            acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec);
            // Apply the activation function.
            acc_u8 = vqmovun_s16(acc_s16_0_1);
            acc_u8 = vmax_u8(acc_u8, output_activation_min_vec);
            acc_u8 = vmin_u8(acc_u8, output_activation_max_vec);

            // Second output pixel, `depth` bytes after the first.
            vst1_lane_8x4(output_data_base + depth, acc_u8, 0);
            vst1_lane_8x4(output_data_base + depth + output_height_stride,
                          acc_u8, 1);

            // The right banks become the left banks of the next iteration.
            left_bank_0_reg = right_bank_0_reg;
            left_bank_1_reg = right_bank_1_reg;
            left_bank_2_reg = right_bank_2_reg;
            left_bank_3_reg = right_bank_3_reg;
            left_bank_4_reg = right_bank_4_reg;
          }
          // Residual loop: only entered when the final micro block holds a
          // single output pixel (residual_width == 1); otherwise the loop
          // above already covered every micro block.
          for (; i_width < output_width_overall_micro_repeats; ++i_width) {
            TFLITE_DCHECK_NE(residual_width, kFourOverStride);

            // No need to load next ("right") block of data.

            uint8* output_data_base = output_data + depth * 2 * i_width + 4 * s;

            // Iterate over input width shifts within 4x4 blocks.
            {
              acc0 = adjusted_bias_data;
              acc1 = adjusted_bias_data;

              acc0 = vdotq_s32(acc0, filter_reg_0_a, left_bank_0_reg);
              acc0 = vdotq_s32(acc0, filter_reg_1_a, left_bank_1_reg);
              acc0 = vdotq_s32(acc0, filter_reg_2_a, left_bank_2_reg);
              acc1 = vdotq_s32(acc1, filter_reg_0_a, left_bank_2_reg);
              acc1 = vdotq_s32(acc1, filter_reg_1_a, left_bank_3_reg);
              acc1 = vdotq_s32(acc1, filter_reg_2_a, left_bank_4_reg);

              // Fixed-point multiplication.
              acc0 = vqrdmulhq_n_s32(acc0, output_multiplier);
              acc0 = DivideByPOT<DepthwiseConvOutputRounding::kUpward>::Run(
                  acc0, -output_shift);
              acc1 = vqrdmulhq_n_s32(acc1, output_multiplier);
              acc1 = DivideByPOT<DepthwiseConvOutputRounding::kUpward>::Run(
                  acc1, -output_shift);
              // Add the output offset.
              int16x8_t acc_s16_0_1 =
                  vcombine_s16(vqmovn_s32(acc0), vqmovn_s32(acc1));
              acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec);
              // Apply the activation function.
              uint8x8_t acc_u8 = vqmovun_s16(acc_s16_0_1);
              acc_u8 = vmax_u8(acc_u8, output_activation_min_vec);
              acc_u8 = vmin_u8(acc_u8, output_activation_max_vec);

              vst1_lane_8x4(output_data_base, acc_u8, 0);
              vst1_lane_8x4(output_data_base + output_height_stride, acc_u8, 1);

              // NOTE(review): no right banks are loaded in this loop, so the
              // right_bank_* values merged in here are whatever the main loop
              // left behind (or unset if it never ran). The rearranged left
              // banks are only read again if this loop iterates once more.
              left_bank_0_reg = vrev32q_u16(left_bank_0_reg);
              left_bank_1_reg = vrev32q_u16(left_bank_1_reg);
              left_bank_2_reg = vrev32q_u16(left_bank_2_reg);
              left_bank_3_reg = vrev32q_u16(left_bank_3_reg);
              left_bank_4_reg = vrev32q_u16(left_bank_4_reg);
              vtrn1_s8x2_in_place(&left_bank_0_reg, &right_bank_0_reg);
              vtrn1_s8x2_in_place(&left_bank_1_reg, &right_bank_1_reg);
              vtrn1_s8x2_in_place(&left_bank_2_reg, &right_bank_2_reg);
              vtrn1_s8x2_in_place(&left_bank_3_reg, &right_bank_3_reg);
              vtrn1_s8x2_in_place(&left_bank_4_reg, &right_bank_4_reg);
            }
          }
          // Advance to the bias values of the next half.
          bias_data += kBiasIncrement;
        }
      } else {
        // block_height == 1.
        // A single output row is produced; both halves (_a and _b) are
        // processed together so that all 8 output bytes of a pixel are
        // written with one 64-bit store.
        int8x16_t filter_reg_0_a;
        int8x16_t filter_reg_1_a;
        int8x16_t filter_reg_2_a;
        int8x16_t filter_reg_0_b;
        int8x16_t filter_reg_1_b;
        int8x16_t filter_reg_2_b;

        filter_reg_0_a = vld1q_s8(filter_block);
        filter_reg_1_a = vld1q_s8(filter_block + 32);
        filter_reg_2_a = vld1q_s8(filter_block + 64);
        filter_reg_0_b = vld1q_s8(filter_block + 16);
        filter_reg_1_b = vld1q_s8(filter_block + 16 + 32);
        filter_reg_2_b = vld1q_s8(filter_block + 16 + 64);

        const int8* scratch_data =
            scratch_block_data + depth_micro_stride * j_depth;
        uint8* output_data = output_block_data + 8 * j_depth;
        const int8* input_data_0 = scratch_data;

        // One bias vector per half; bias_data advances by 8 values in total.
        const int32x4_t adjusted_bias_data_a = vld1q_s32(bias_data);
        bias_data += kBiasIncrement;
        const int32x4_t adjusted_bias_data_b = vld1q_s32(bias_data);
        bias_data += kBiasIncrement;

        // Load first sub-micro block of data into operational banks.
        // Only three input rows are needed for a single output row.
        int8x16_t left_bank_0_reg_a = vld1q_s8(input_data_0);
        int8x16_t left_bank_1_reg_a =
            vld1q_s8(input_data_0 + workspace_height_stride);
        int8x16_t left_bank_2_reg_a =
            vld1q_s8(input_data_0 + 2 * workspace_height_stride);
        int8x16_t left_bank_0_reg_b = vld1q_s8(input_data_0 + 16);
        int8x16_t left_bank_1_reg_b =
            vld1q_s8(input_data_0 + workspace_height_stride + 16);
        int8x16_t left_bank_2_reg_b =
            vld1q_s8(input_data_0 + 2 * workspace_height_stride + 16);

        int8x16_t right_bank_0_reg_a;
        int8x16_t right_bank_1_reg_a;
        int8x16_t right_bank_2_reg_a;
        int8x16_t right_bank_0_reg_b;
        int8x16_t right_bank_1_reg_b;
        int8x16_t right_bank_2_reg_b;

        int32x4_t acc0_a;
        int32x4_t acc0_b;

        for (int i_width = 0; i_width < output_width_overall_micro_repeats;
             ++i_width) {
          // Full micro blocks emit kFourOverStride pixels; the micro block at
          // index output_width_micro_repeats emits only residual_width.
          const int output_width = i_width == output_width_micro_repeats
                                       ? residual_width
                                       : kFourOverStride;
          TFLITE_DCHECK_LE(output_width * kStrideVal, 4);
          const int8* input_data = input_data_0 + width_micro_stride * i_width;
          // Skip loading the next micro block on the residual iteration when
          // the workspace has no more micro blocks than the output does.
          const bool no_right_block = i_width == output_width_micro_repeats &&
                                      output_width_overall_micro_repeats ==
                                          workspace_width_micro_repeats;

          if (!no_right_block) {
            // Load next sub-micro block of data.
            right_bank_0_reg_a = vld1q_s8(input_data + width_micro_stride);
            right_bank_1_reg_a = vld1q_s8(input_data + width_micro_stride +
                                          workspace_height_stride);
            right_bank_2_reg_a = vld1q_s8(input_data + width_micro_stride +
                                          2 * workspace_height_stride);
            right_bank_0_reg_b = vld1q_s8(input_data + width_micro_stride + 16);
            right_bank_1_reg_b = vld1q_s8(input_data + width_micro_stride +
                                          workspace_height_stride + 16);
            right_bank_2_reg_b = vld1q_s8(input_data + width_micro_stride +
                                          2 * workspace_height_stride + 16);
          }

          uint8* output_data_base = output_data + depth * 2 * i_width;

          // Iterate over input width shifts within 4x4 blocks.
          {
            acc0_a = adjusted_bias_data_a;
            acc0_b = adjusted_bias_data_b;

            acc0_a = vdotq_s32(acc0_a, filter_reg_0_a, left_bank_0_reg_a);
            acc0_a = vdotq_s32(acc0_a, filter_reg_1_a, left_bank_1_reg_a);
            acc0_a = vdotq_s32(acc0_a, filter_reg_2_a, left_bank_2_reg_a);
            acc0_b = vdotq_s32(acc0_b, filter_reg_0_b, left_bank_0_reg_b);
            acc0_b = vdotq_s32(acc0_b, filter_reg_1_b, left_bank_1_reg_b);
            acc0_b = vdotq_s32(acc0_b, filter_reg_2_b, left_bank_2_reg_b);

            // Fixed-point multiplication.
            acc0_a = vqrdmulhq_n_s32(acc0_a, output_multiplier);
            acc0_b = vqrdmulhq_n_s32(acc0_b, output_multiplier);
            acc0_a = DivideByPOT<DepthwiseConvOutputRounding::kUpward>::Run(
                acc0_a, -output_shift);
            acc0_b = DivideByPOT<DepthwiseConvOutputRounding::kUpward>::Run(
                acc0_b, -output_shift);
            // Add the output offset.
            int16x8_t acc_s16_0_1 =
                vcombine_s16(vqmovn_s32(acc0_a), vqmovn_s32(acc0_b));
            acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec);
            // Apply the activation function.
            uint8x8_t acc_u8 = vqmovun_s16(acc_s16_0_1);
            acc_u8 = vmax_u8(acc_u8, output_activation_min_vec);
            acc_u8 = vmin_u8(acc_u8, output_activation_max_vec);

            // First output pixel of this micro block: all 8 bytes at once.
            vst1_u8(output_data_base, acc_u8);

            // Re-arrange the left banks, merging in right-bank data, to form
            // the input window of the second output pixel.
            left_bank_0_reg_a = vrev32q_u16(left_bank_0_reg_a);
            left_bank_1_reg_a = vrev32q_u16(left_bank_1_reg_a);
            left_bank_2_reg_a = vrev32q_u16(left_bank_2_reg_a);
            left_bank_0_reg_b = vrev32q_u16(left_bank_0_reg_b);
            left_bank_1_reg_b = vrev32q_u16(left_bank_1_reg_b);
            left_bank_2_reg_b = vrev32q_u16(left_bank_2_reg_b);
            vtrn1_s8x2_in_place(&left_bank_0_reg_a, &right_bank_0_reg_a);
            vtrn1_s8x2_in_place(&left_bank_1_reg_a, &right_bank_1_reg_a);
            vtrn1_s8x2_in_place(&left_bank_2_reg_a, &right_bank_2_reg_a);
            vtrn1_s8x2_in_place(&left_bank_0_reg_b, &right_bank_0_reg_b);
            vtrn1_s8x2_in_place(&left_bank_1_reg_b, &right_bank_1_reg_b);
            vtrn1_s8x2_in_place(&left_bank_2_reg_b, &right_bank_2_reg_b);
          }

          // Second output pixel, skipped when only one pixel remains.
          if (output_width > 1) {
            acc0_a = adjusted_bias_data_a;
            acc0_b = adjusted_bias_data_b;

            acc0_a = vdotq_s32(acc0_a, filter_reg_0_a, left_bank_0_reg_a);
            acc0_a = vdotq_s32(acc0_a, filter_reg_1_a, left_bank_1_reg_a);
            acc0_a = vdotq_s32(acc0_a, filter_reg_2_a, left_bank_2_reg_a);
            acc0_b = vdotq_s32(acc0_b, filter_reg_0_b, left_bank_0_reg_b);
            acc0_b = vdotq_s32(acc0_b, filter_reg_1_b, left_bank_1_reg_b);
            acc0_b = vdotq_s32(acc0_b, filter_reg_2_b, left_bank_2_reg_b);

            // Fixed-point multiplication.
            acc0_a = vqrdmulhq_n_s32(acc0_a, output_multiplier);
            acc0_b = vqrdmulhq_n_s32(acc0_b, output_multiplier);
            acc0_a = DivideByPOT<DepthwiseConvOutputRounding::kUpward>::Run(
                acc0_a, -output_shift);
            acc0_b = DivideByPOT<DepthwiseConvOutputRounding::kUpward>::Run(
                acc0_b, -output_shift);
            // Add the output offset.
            int16x8_t acc_s16_0_1 =
                vcombine_s16(vqmovn_s32(acc0_a), vqmovn_s32(acc0_b));
            acc_s16_0_1 = vqaddq_s16(acc_s16_0_1, output_offset_vec);
            // Apply the activation function.
            uint8x8_t acc_u8 = vqmovun_s16(acc_s16_0_1);
            acc_u8 = vmax_u8(acc_u8, output_activation_min_vec);
            acc_u8 = vmin_u8(acc_u8, output_activation_max_vec);

            vst1_u8(output_data_base + depth, acc_u8);

            // The right banks become the left banks of the next iteration.
            left_bank_0_reg_a = right_bank_0_reg_a;
            left_bank_1_reg_a = right_bank_1_reg_a;
            left_bank_2_reg_a = right_bank_2_reg_a;
            left_bank_0_reg_b = right_bank_0_reg_b;
            left_bank_1_reg_b = right_bank_1_reg_b;
            left_bank_2_reg_b = right_bank_2_reg_b;
          }
        }
      }
    }
  }  // NOLINT(readability/fn_size) Manually unrolled.