TfLiteStatus Prepare()

in tensorflow/lite/micro/kernels/arc_mli/depthwise_conv.cc [170:354]


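// Prepare() validates the node's tensors and quantization metadata, computes
// padding and the per-channel output multipliers/shifts, and, when the ARC
// MLI backend is applicable, builds the persistent mli_tensor descriptors
// and mli_conv2d_cfg that Eval() uses later.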
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  TFLITE_DCHECK(node->builtin_data != nullptr);

  auto* params =
      reinterpret_cast<TfLiteDepthwiseConvParams*>(node->builtin_data);
  OpData* data = static_cast<OpData*>(node->user_data);

  MicroContext* micro_context = GetMicroContext(context);

  TfLiteTensor* output =
      micro_context->AllocateTempOutputTensor(node, kOutputTensor);
  TfLiteTensor* input =
      micro_context->AllocateTempInputTensor(node, kInputTensor);
  TfLiteTensor* filter =
      micro_context->AllocateTempInputTensor(node, kFilterTensor);
  TfLiteTensor* bias =
      micro_context->AllocateTempInputTensor(node, kBiasTensor);

  const TfLiteType data_type = input->type;
  int width = SizeOfDimension(input, 2);
  int height = SizeOfDimension(input, 1);

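  // The filter layout differs between MLI 2.0 and earlier builds, so the
  // kernel's spatial dimensions are read from different axes below.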
#if defined(MLI_2_0) && !defined(MLI_2_0_KRNL_TEST)
  int filter_width = SizeOfDimension(filter, 1);
  int filter_height = SizeOfDimension(filter, 0);
#else
  int filter_width = SizeOfDimension(filter, 2);
  int filter_height = SizeOfDimension(filter, 1);
#endif

  // Per channel quantization is only needed for int8 inference. For other
  // quantized types, only a single scale and zero point is needed.
  const int num_channels =
      filter->dims->data[kDepthwiseConvQuantizedDimension];
  // Dynamically allocate per-channel quantization parameters.
  data->per_channel_output_multiplier =
      reinterpret_cast<int32_t*>(context->AllocatePersistentBuffer(
          context, num_channels * sizeof(int32_t)));
  data->per_channel_output_shift =
      reinterpret_cast<int32_t*>(context->AllocatePersistentBuffer(
          context, num_channels * sizeof(int32_t)));

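  // IsMliApplicable() decides whether this node can be offloaded to the ARC
  // MLI library; when it returns false, Eval() falls back to the reference
  // implementation.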
  data->is_mli_applicable =
      IsMliApplicable(context, input, filter, bias, params);

  // All per-channel quantized tensors need valid zero point and scale arrays.
  if (input->type == kTfLiteInt8) {
    TF_LITE_ENSURE_EQ(context, filter->quantization.type,
                      kTfLiteAffineQuantization);

    const auto* affine_quantization =
        reinterpret_cast<TfLiteAffineQuantization*>(
            filter->quantization.params);
    TF_LITE_ENSURE(context, affine_quantization);
    TF_LITE_ENSURE(context, affine_quantization->scale);
    TF_LITE_ENSURE(context, affine_quantization->zero_point);
    TF_LITE_ENSURE(
        context, affine_quantization->scale->size == 1 ||
                     affine_quantization->scale->size ==
                         filter->dims->data[kDepthwiseConvQuantizedDimension]);
    TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size,
                      affine_quantization->zero_point->size);
  }

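  // CalculateOpData() fills in the padding, the per-channel output
  // multiplier/shift arrays allocated above, and the output activation range
  // used for the ReLU mapping below.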
  TF_LITE_ENSURE_STATUS(CalculateOpData(context, node, params, width, height,
                                        filter_width, filter_height, data_type,
                                        data));

  data->input_zero_point = input->params.zero_point;
  data->filter_zero_point = filter->params.zero_point;
  data->output_zero_point = output->params.zero_point;

  if (data->is_mli_applicable) {
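    // Allocate the MLI tensor descriptors and convolution config from the
    // persistent arena so they remain valid for every later Eval() call.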
    data->mli_in = ops::micro::MliTensorInterface(static_cast<mli_tensor*>(
        context->AllocatePersistentBuffer(context, sizeof(mli_tensor))));
    data->mli_weights = ops::micro::MliTensorInterface(static_cast<mli_tensor*>(
        context->AllocatePersistentBuffer(context, sizeof(mli_tensor))));
    data->mli_bias = ops::micro::MliTensorInterface(static_cast<mli_tensor*>(
        context->AllocatePersistentBuffer(context, sizeof(mli_tensor))));
    data->mli_out = ops::micro::MliTensorInterface(static_cast<mli_tensor*>(
        context->AllocatePersistentBuffer(context, sizeof(mli_tensor))));
    data->cfg = static_cast<mli_conv2d_cfg*>(
        context->AllocatePersistentBuffer(context, sizeof(mli_conv2d_cfg)));

#ifdef MLI_2_0
    const int num_buffers = 2;
    data->per_channel_scale_frac_bits =
        static_cast<int8_t*>(context->AllocatePersistentBuffer(
            context, num_buffers * num_channels * sizeof(int16_t)));
#endif

    // Reuse space allocated for OpData parameters.
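    // The MLI path performs requantization from its own per-channel scales,
    // so the TFLite multiplier/shift buffers are free to be repurposed as
    // storage for the MLI scales and zero points of the weights and bias.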
#ifdef MLI_2_0
    *data->mli_weights.Scale<int16_t**>() =
        reinterpret_cast<int16_t*>(data->per_channel_output_multiplier);
    *data->mli_bias.Scale<int16_t**>() =
        reinterpret_cast<int16_t*>(data->per_channel_output_multiplier) +
        num_channels;
#else
    *data->mli_weights.Scale<int32_t**>() =
        static_cast<int32_t*>(data->per_channel_output_multiplier);
    *data->mli_bias.Scale<int32_t**>() =
        static_cast<int32_t*>(data->per_channel_output_shift);
#endif

#ifdef MLI_2_0
    *data->mli_weights.ZeroPoint<int16_t**>() =
        reinterpret_cast<int16_t*>(data->per_channel_output_shift);
    *data->mli_bias.ZeroPoint<int16_t**>() =
        reinterpret_cast<int16_t*>(data->per_channel_output_shift) +
        num_channels;
#else
    *data->mli_weights.ZeroPoint<int16_t**>() =
        reinterpret_cast<int16_t*>(&data->filter_zero_point);
    *data->mli_bias.ZeroPoint<int16_t**>() =
        reinterpret_cast<int16_t*>(&data->filter_zero_point) + sizeof(int16_t);
#endif

#ifdef MLI_2_0
    *data->mli_weights.ScaleFracBits<int8_t**>() =
        reinterpret_cast<int8_t*>(data->per_channel_scale_frac_bits);
    *data->mli_bias.ScaleFracBits<int8_t**>() =
        reinterpret_cast<int8_t*>(data->per_channel_scale_frac_bits) +
        num_channels;
#endif

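    // Populate the MLI tensor descriptors (shape, element type, quantization
    // parameters) from the corresponding TfLite tensors.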
    ops::micro::ConvertToMliTensor(input, &data->mli_in);
    ops::micro::ConvertToMliTensorPerChannel(filter, &data->mli_weights,
                                             /* is_bias_tensor = */ false);
    ops::micro::ConvertToMliTensorPerChannel(bias, &data->mli_bias,
                                             /* is_bias_tensor = */ true);
#ifdef MLI_2_0
    ops::micro::AdjustBiasTensor(&data->mli_bias, &data->mli_in,
                                 &data->mli_weights);
#endif
    ops::micro::ConvertToMliTensor(output, &data->mli_out);

#ifdef MLI_2_0
    // Choose group convolution function for "channel multiplier"
    // functionality.
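    // When each input channel produces exactly one output channel (filter
    // dim 2 == 1 and one filter per input channel), the plain depthwise
    // kernel is used; otherwise the node is lowered to a group convolution.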
    const int in_ch = SizeOfDimension(input, 3);
    const int filters_num = SizeOfDimension(filter, 3);
    const int channels_num = SizeOfDimension(filter, 2);
    if (in_ch == filters_num && channels_num == 1) {
      data->p_mli_krn_depthwise_conv2d_sa8_sa8_sa32 =
          mli_krn_depthwise_conv2d(data->mli_weights.MliTensor());
    } else {
      data->p_mli_krn_depthwise_conv2d_sa8_sa8_sa32 =
          mli_krn_group_conv2d(data->mli_weights.MliTensor());
    }
#else
    data->p_mli_krn_depthwise_conv2d_sa8_sa8_sa32 =
        mli_krn_depthwise_conv2d(data->mli_weights.MliTensor(), data->cfg);
#endif

#ifdef MLI_2_0
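    // Dilation is fixed to 1 here; the MLI path is only selected for
    // dilation factors of 1, so dilated nodes use the reference kernels.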
    data->cfg->dilation_width = 1;
    data->cfg->dilation_height = 1;
#endif

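    // Map the fused TFLite activation onto MLI's built-in ReLU. When the
    // computed activation range already spans the full int8 range, no
    // clipping is needed and MLI_RELU_NONE is used.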
    if (data->output_activation_min == -128 &&
        data->output_activation_max == 127) {
      data->cfg->relu.type = MLI_RELU_NONE;
    } else if (params->activation == kTfLiteActRelu) {
      data->cfg->relu.type = MLI_RELU_GEN;
    } else if (params->activation == kTfLiteActRelu6) {
      data->cfg->relu.type = MLI_RELU_6;
    } else if (params->activation == kTfLiteActReluN1To1) {
      data->cfg->relu.type = MLI_RELU_1;
    } else {
      data->cfg->relu.type = MLI_RELU_NONE;
    }

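    // Pass stride and padding through to MLI. For SAME padding, the
    // precomputed left/top padding plus the width/height offsets give the
    // right/bottom amounts.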
    data->cfg->stride_width = params->stride_width;
    data->cfg->stride_height = params->stride_height;
    if (params->padding == kTfLitePaddingValid) {
      data->cfg->padding_left = 0;
      data->cfg->padding_right = 0;
      data->cfg->padding_top = 0;
      data->cfg->padding_bottom = 0;
    } else {
      data->cfg->padding_left = data->padding.width;
      data->cfg->padding_right =
          data->padding.width + data->padding.width_offset;
      data->cfg->padding_top = data->padding.height;
      data->cfg->padding_bottom =
          data->padding.height + data->padding.height_offset;
    }
  }
  micro_context->DeallocateTempTfLiteTensor(output);
  micro_context->DeallocateTempTfLiteTensor(input);
  micro_context->DeallocateTempTfLiteTensor(filter);
  micro_context->DeallocateTempTfLiteTensor(bias);
  return kTfLiteOk;
}
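
The DCHECK at the top of Prepare() relies on node->user_data having been
allocated before preparation. In TFLM kernels that normally happens in the
op's Init() callback; a minimal sketch of what that looks like for this
kernel (not part of the excerpt above, so treat it as an assumption) is:

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
  // Reserve space for OpData in the persistent arena; Prepare() fills it in.
  return context->AllocatePersistentBuffer(context, sizeof(OpData));
}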