in src/tensorflow/lite/micro/kernels/cmsis_nn/depthwise_conv.cpp [46:153]
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  TFLITE_DCHECK(node->builtin_data != nullptr);

  OpData* data = static_cast<OpData*>(node->user_data);
  const auto& params =
      *(reinterpret_cast<TfLiteDepthwiseConvParams*>(node->builtin_data));
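
  // Prepare() runs once while the graph is being planned, so all validation
  // and persistent allocation happens here rather than on every invocation.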
  const TfLiteTensor* input =
      GetInput(context, node, kDepthwiseConvInputTensor);
  TF_LITE_ENSURE(context, input != nullptr);
  const TfLiteTensor* filter =
      GetInput(context, node, kDepthwiseConvWeightsTensor);
  TF_LITE_ENSURE(context, filter != nullptr);
  TfLiteTensor* output = GetOutput(context, node, kDepthwiseConvOutputTensor);
  TF_LITE_ENSURE(context, output != nullptr);

  const TfLiteType data_type = input->type;
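
  // Tensors use NHWC layout, so dimension 1 is height and dimension 2 is
  // width.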
  int input_width = SizeOfDimension(input, 2);
  int input_height = SizeOfDimension(input, 1);
  int filter_width = SizeOfDimension(filter, 2);
  int filter_height = SizeOfDimension(filter, 1);
  int output_width = SizeOfDimension(output, 2);
  int output_height = SizeOfDimension(output, 1);
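
  // Int8 inference requires affine quantization metadata on the filter:
  // either a single scale (per-tensor) or one scale per output channel.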
  if (input->type == kTfLiteInt8) {
    TF_LITE_ENSURE_EQ(context, filter->quantization.type,
                      kTfLiteAffineQuantization);

    // All per-channel quantized tensors need valid zero point and scale
    // arrays.
    const auto* affine_quantization =
        reinterpret_cast<TfLiteAffineQuantization*>(
            filter->quantization.params);
    TF_LITE_ENSURE(context, affine_quantization);
    TF_LITE_ENSURE(context, affine_quantization->scale);
    TF_LITE_ENSURE(context, affine_quantization->zero_point);
    TF_LITE_ENSURE(
        context, affine_quantization->scale->size == 1 ||
                     affine_quantization->scale->size ==
                         filter->dims->data[kDepthwiseConvQuantizedDimension]);
    TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size,
                      affine_quantization->zero_point->size);
    // Allocate persistent buffers for the per-channel output multiplier and
    // shift; these must outlive Prepare() because Eval() reads them on every
    // invocation.
    const int num_channels =
        filter->dims->data[kDepthwiseConvQuantizedDimension];
    data->reference_op_data.per_channel_output_multiplier =
        reinterpret_cast<int32_t*>(context->AllocatePersistentBuffer(
            context, num_channels * sizeof(int32_t)));
    data->reference_op_data.per_channel_output_shift =
        reinterpret_cast<int32_t*>(context->AllocatePersistentBuffer(
            context, num_channels * sizeof(int32_t)));
  }
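
  // Derives padding and the quantization parameters for the op, filling the
  // per-channel arrays allocated above.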
  TF_LITE_ENSURE_STATUS(CalculateOpDataDepthwiseConv(
      context, node, params, input_width, input_height, filter_width,
      filter_height, output_width, output_height, data_type,
      &data->reference_op_data));
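
  // For int8, ask CMSIS-NN how much scratch memory its optimized kernel
  // needs for these shapes and reserve it in the arena up front.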
  if (input->type == kTfLiteInt8) {
    RuntimeShape input_shape = GetTensorShape(input);
    RuntimeShape output_shape = GetTensorShape(output);
    RuntimeShape filter_shape = GetTensorShape(filter);

    TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
    TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
    TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);

    const int batch_size = MatchingDim(input_shape, 0, output_shape, 0);
    const int output_depth = MatchingDim(output_shape, 3, filter_shape, 3);
    TFLITE_DCHECK_EQ(batch_size, 1);  // Only batch size 1 is supported.
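
    // Describe the tensors in CMSIS-NN's (N, H, W, C) dims structures.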
    cmsis_nn_dims input_dims;
    input_dims.n = batch_size;
    input_dims.h = input_height;
    input_dims.w = input_width;
    input_dims.c = input_shape.Dims(3);
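
    // CMSIS-NN expects the depthwise filter as [1, H, W, output_channels].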
    cmsis_nn_dims filter_dims;
    filter_dims.n = 1;
    filter_dims.h = filter_height;
    filter_dims.w = filter_width;
    filter_dims.c = output_depth;

    cmsis_nn_dims output_dims;
    output_dims.n = batch_size;
    output_dims.h = output_height;
    output_dims.w = output_width;
    output_dims.c = output_depth;
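
    // Only padding and dilation are needed before the buffer-size query; the
    // remaining convolution parameters (stride, activation range, offsets)
    // are filled in later, at Eval time, outside this excerpt.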
    cmsis_nn_dw_conv_params dw_conv_params;
    dw_conv_params.padding.h = data->reference_op_data.padding.height;
    dw_conv_params.padding.w = data->reference_op_data.padding.width;
    dw_conv_params.dilation.h = params.dilation_height_factor;
    dw_conv_params.dilation.w = params.dilation_width_factor;

    const int32_t buf_size = arm_depthwise_conv_wrapper_s8_get_buffer_size(
        &dw_conv_params, &input_dims, &filter_dims, &output_dims);
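
    // A positive size means the optimized kernel needs scratch memory; a
    // buffer_idx of -1 marks "no scratch buffer" for Eval.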
    if (buf_size > 0) {
      TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
          context, buf_size, &data->buffer_idx));
    } else {
      data->buffer_idx = -1;
    }
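
    // Eval() would then typically fetch the reserved area with
    // context->GetScratchBuffer(context, data->buffer_idx); that code is
    // outside this excerpt.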
  }

  return kTfLiteOk;
}