TfLiteStatus Prepare()

in tensorflow/lite/micro/kernels/arc_mli/conv.cc [174:355]


TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  TFLITE_DCHECK(node->builtin_data != nullptr);

  OpData* data = static_cast<OpData*>(node->user_data);
  const auto params = static_cast<const TfLiteConvParams*>(node->builtin_data);

  MicroContext* micro_context = GetMicroContext(context);

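  // Temporary tensor views; every AllocateTemp* call below must be paired
  // with a DeallocateTempTfLiteTensor before returning (done at the bottom of
  // this function).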
  TfLiteTensor* output =
      micro_context->AllocateTempOutputTensor(node, kOutputTensor);
  TfLiteTensor* input =
      micro_context->AllocateTempInputTensor(node, kInputTensor);
  TfLiteTensor* filter =
      micro_context->AllocateTempInputTensor(node, kFilterTensor);
  TfLiteTensor* bias =
      micro_context->AllocateTempInputTensor(node, kBiasTensor);

  int input_width = input->dims->data[2];
  int input_height = input->dims->data[1];
#if defined(MLI_2_0) && !defined(MLI_2_0_KRNL_TEST)
  int filter_width = filter->dims->data[1];
  int filter_height = filter->dims->data[0];
#else
  int filter_width = filter->dims->data[2];
  int filter_height = filter->dims->data[1];
#endif
  int output_width = output->dims->data[2];
  int output_height = output->dims->data[1];

  // Dynamically allocate per-channel quantization parameters.
  const int num_channels = filter->dims->data[kConvQuantizedDimension];
  data->per_channel_output_multiplier =
      reinterpret_cast<int32_t*>(context->AllocatePersistentBuffer(
          context, num_channels * sizeof(int32_t)));
  data->per_channel_output_shift =
      reinterpret_cast<int32_t*>(context->AllocatePersistentBuffer(
          context, num_channels * sizeof(int32_t)));

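  // Decide at Prepare time whether this node can be dispatched to the MLI
  // backend; the MLI-specific setup below only runs when it can.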
  data->is_mli_applicable =
      IsMliApplicable(context, input, filter, bias, params);

  // All per-channel quantized tensors need valid zero point and scale arrays.
  if (input->type == kTfLiteInt8) {
    TF_LITE_ENSURE_EQ(context, filter->quantization.type,
                      kTfLiteAffineQuantization);

    const auto* affine_quantization =
        static_cast<TfLiteAffineQuantization*>(filter->quantization.params);
    TF_LITE_ENSURE(context, affine_quantization);
    TF_LITE_ENSURE(context, affine_quantization->scale);
    TF_LITE_ENSURE(context, affine_quantization->zero_point);

    TF_LITE_ENSURE(context,
                   affine_quantization->scale->size == 1 ||
                       affine_quantization->scale->size ==
                           filter->dims->data[kConvQuantizedDimension]);
    TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size,
                      affine_quantization->zero_point->size);
  }

  TF_LITE_ENSURE_STATUS(CalculateOpData(
      context, node, params, input_width, input_height, filter_width,
      filter_height, output_width, output_height, input->type, data));

  data->input_zero_point = input->params.zero_point;
  data->filter_zero_point = filter->params.zero_point;
  data->output_zero_point = output->params.zero_point;

  if (data->is_mli_applicable) {
    data->mli_in = ops::micro::MliTensorInterface(static_cast<mli_tensor*>(
        context->AllocatePersistentBuffer(context, sizeof(mli_tensor))));
    data->mli_weights = ops::micro::MliTensorInterface(static_cast<mli_tensor*>(
        context->AllocatePersistentBuffer(context, sizeof(mli_tensor))));
    data->mli_bias = ops::micro::MliTensorInterface(static_cast<mli_tensor*>(
        context->AllocatePersistentBuffer(context, sizeof(mli_tensor))));
    data->mli_out = ops::micro::MliTensorInterface(static_cast<mli_tensor*>(
        context->AllocatePersistentBuffer(context, sizeof(mli_tensor))));
    data->cfg = static_cast<mli_conv2d_cfg*>(
        context->AllocatePersistentBuffer(context, sizeof(mli_conv2d_cfg)));

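    // Frac-bits storage for MLI 2.0: the weights' per-channel entries come
    // first, followed by the bias entries (see the ScaleFracBits wiring
    // below).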
#ifdef MLI_2_0
    data->per_channel_scale_frac_bits =
        static_cast<int8_t*>(context->AllocatePersistentBuffer(
            context, 2 * num_channels * sizeof(int16_t)));
#endif

    // Reuse the space allocated above for the per-channel OpData parameters:
    // the MLI path does not consume per_channel_output_multiplier/shift
    // directly, so their buffers are repurposed for the MLI scale and
    // zero-point arrays.
#ifdef MLI_2_0
    *data->mli_weights.Scale<int16_t**>() =
        reinterpret_cast<int16_t*>(data->per_channel_output_multiplier);
    *data->mli_bias.Scale<int16_t**>() =
        reinterpret_cast<int16_t*>(data->per_channel_output_multiplier) +
        num_channels;
#else
    *data->mli_weights.Scale<int32_t**>() =
        static_cast<int32_t*>(data->per_channel_output_multiplier);
    *data->mli_bias.Scale<int32_t**>() =
        static_cast<int32_t*>(data->per_channel_output_shift);
#endif

#ifdef MLI_2_0
    *data->mli_weights.ZeroPoint<int16_t**>() =
        reinterpret_cast<int16_t*>(data->per_channel_output_shift);
    *data->mli_bias.ZeroPoint<int16_t**>() =
        reinterpret_cast<int16_t*>(data->per_channel_output_shift) +
        num_channels;
#else
    *data->mli_weights.ZeroPoint<int16_t**>() =
        reinterpret_cast<int16_t*>(&data->filter_zero_point);
    *data->mli_bias.ZeroPoint<int16_t**>() =
        reinterpret_cast<int16_t*>(&data->filter_zero_point) + sizeof(int16_t);
#endif

#ifdef MLI_2_0
    *data->mli_weights.ScaleFracBits<int8_t**>() =
        reinterpret_cast<int8_t*>(data->per_channel_scale_frac_bits);
    *data->mli_bias.ScaleFracBits<int8_t**>() =
        reinterpret_cast<int8_t*>(data->per_channel_scale_frac_bits) +
        num_channels;
#endif

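    // Wire the TfLite tensors into the MLI descriptors; weights and bias use
    // the per-channel variants.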
    ops::micro::ConvertToMliTensor(input, &data->mli_in);
    ops::micro::ConvertToMliTensorPerChannel(filter, &data->mli_weights,
                                             /* is_bias_tensor = */ false);
    ops::micro::ConvertToMliTensorPerChannel(bias, &data->mli_bias,
                                             /* is_bias_tensor = */ true);
#ifdef MLI_2_0
    ops::micro::AdjustBiasTensor(&data->mli_bias, &data->mli_in,
                                 &data->mli_weights);
#endif
    ops::micro::ConvertToMliTensor(output, &data->mli_out);

    // Choose the specialized MLI convolution function; MLI 2.0 selects on the
    // weights tensor alone, earlier versions also take the config.
#ifdef MLI_2_0
    data->p_mli_krn_conv2d_sa8_sa8_sa32 =
        mli_krn_conv2d_hwcn(data->mli_weights.MliTensor());
#else
    data->p_mli_krn_conv2d_sa8_sa8_sa32 =
        mli_krn_conv2d_hwcn(data->mli_weights.MliTensor(), data->cfg);
#endif

#ifdef MLI_2_0
    data->cfg->dilation_width = 1;
    data->cfg->dilation_height = 1;
#endif

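    // Map the fused TFLite activation onto the MLI relu config. An output
    // activation range of [-128, 127] already saturates at the int8 limits,
    // so no explicit relu is needed in that case.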
    if (data->output_activation_min == -128 &&
        data->output_activation_max == 127) {
      data->cfg->relu.type = MLI_RELU_NONE;
    } else if (params->activation == kTfLiteActRelu) {
      data->cfg->relu.type = MLI_RELU_GEN;
    } else if (params->activation == kTfLiteActRelu6) {
      data->cfg->relu.type = MLI_RELU_6;
    } else if (params->activation == kTfLiteActReluN1To1) {
      data->cfg->relu.type = MLI_RELU_1;
    } else {
      data->cfg->relu.type = MLI_RELU_NONE;
    }
    data->cfg->stride_width = params->stride_width;
    data->cfg->stride_height = params->stride_height;
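    // For SAME padding, the totals computed in CalculateOpData are split into
    // leading padding plus a trailing remainder carried in the *_offset
    // fields; VALID padding needs none.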
    if (params->padding == kTfLitePaddingValid) {
      data->cfg->padding_left = 0;
      data->cfg->padding_right = 0;
      data->cfg->padding_top = 0;
      data->cfg->padding_bottom = 0;
    } else {
      data->cfg->padding_left = data->padding.width;
      data->cfg->padding_right =
          data->padding.width + data->padding.width_offset;
      data->cfg->padding_top = data->padding.height;
      data->cfg->padding_bottom =
          data->padding.height + data->padding.height_offset;
    }
  }

  micro_context->DeallocateTempTfLiteTensor(output);
  micro_context->DeallocateTempTfLiteTensor(input);
  micro_context->DeallocateTempTfLiteTensor(filter);
  micro_context->DeallocateTempTfLiteTensor(bias);
  return kTfLiteOk;
}
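
For context, a minimal sketch of how a Prepare function like this one is
typically exposed from a TFLM kernel file. The Init and Eval functions and the
tflite::micro::RegisterOp helper are assumptions based on the common TFLM
kernel layout, not part of the excerpt above; older trees instead return a
TfLiteRegistration initialized field by field.

// Sketch only -- assumes conv.cc also defines Init() and Eval() with the
// standard TFLM signatures. RegisterOp packs the three callbacks into the
// registration struct that the op resolver hands to the interpreter.
TFLMRegistration Register_CONV_2D() {
  return tflite::micro::RegisterOp(Init, Prepare, Eval);
}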