TfLiteStatus EvalMean()

in tensorflow/lite/micro/kernels/reduce.cc [150:270]


// Evaluates the MEAN reduction for `node`.
//
// Input 0 is the data tensor, input 1 holds the reduction axes, and output 0
// receives the result. The kernel is selected by input type (float32, int8 or
// int16) and, within a type, by the following cases:
//   1. `keep_dims` is set and a 4D input is reduced over axes {1, 2} (in
//      either order): the specialized reference_ops::Mean /
//      reference_integer_ops::Mean overload taking `MeanParams` is used.
//   2. Quantized types only, when input and output zero point and scale are
//      equal: the generic reference_ops::Mean is used with an int32 scratch
//      buffer as the accumulator.
//   3. Quantized types only, otherwise: reference_ops::QuantizedMeanOrSum is
//      used with the same scratch buffer.
//
// Returns kTfLiteOk on success. Fails through TF_LITE_ENSURE when a generic
// reference kernel reports failure, and through TF_LITE_ENSURE_MSG when the
// input type is not one of the three handled types.
TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
  const TfLiteEvalTensor* axis = tflite::micro::GetEvalInput(context, node, 1);
  TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
  TfLiteReducerParams* params =
      reinterpret_cast<TfLiteReducerParams*>(node->builtin_data);
  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);

  int num_axis = static_cast<int>(ElementCount(*axis->dims));

  // Work arrays handed to the generic reference kernels.
  // NOTE(review): neither `num_axis` nor `input->dims->size` is checked
  // against these array sizes in this function -- presumably that is
  // validated during Prepare; confirm.
  int temp_index[kMaxNumberOfAxis];
  int resolved_axis[kMaxNumberOfReducedAxis];

  tflite::MeanParams op_params;
  ResolveAxis(tflite::micro::GetTensorData<int>(axis), num_axis, &op_params);

  // Special case mean implementation exists for 4D mean across axes 1 and 2.
  const bool special_case_4d_axes_1_and_2 =
      input->dims->size == 4 && op_params.axis_count == 2 &&
      ((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
       (op_params.axis[0] == 2 && op_params.axis[1] == 1));

  // The specialized kernels are only used when the reduced dimensions are
  // kept in the output shape.
  const bool use_specialized_kernel =
      params->keep_dims && special_case_4d_axes_1_and_2;

  switch (input->type) {
    case kTfLiteFloat32: {
      if (use_specialized_kernel) {
        reference_ops::Mean(op_params, tflite::micro::GetTensorShape(input),
                            tflite::micro::GetTensorData<float>(input),
                            tflite::micro::GetTensorShape(output),
                            tflite::micro::GetTensorData<float>(output));
      } else {
        // The output buffer is passed twice: once as the destination and once
        // as the final (accumulator) argument, so no scratch buffer is
        // requested for float.
        TF_LITE_ENSURE(
            context,
            reference_ops::Mean(
                tflite::micro::GetTensorData<float>(input), input->dims->data,
                input->dims->size, tflite::micro::GetTensorData<float>(output),
                output->dims->data, output->dims->size,
                tflite::micro::GetTensorData<int>(axis), num_axis,
                params->keep_dims, temp_index, resolved_axis,
                tflite::micro::GetTensorData<float>(output)));
      }
    } break;
    case kTfLiteInt8: {
      if (use_specialized_kernel) {
        reference_integer_ops::Mean(
            op_params, op_data->multiplier, op_data->shift,
            tflite::micro::GetTensorShape(input),
            tflite::micro::GetTensorData<int8_t>(input), op_data->input_zp,
            tflite::micro::GetTensorShape(output),
            tflite::micro::GetTensorData<int8_t>(output), op_data->output_zp);
      } else {
        // Both generic paths accumulate into the int32 scratch buffer.
        int32_t* temp_buffer = static_cast<int32_t*>(
            context->GetScratchBuffer(context, op_data->temp_buffer_idx));
        const bool same_quantization =
            op_data->input_zp == op_data->output_zp &&
            op_data->input_scale == op_data->output_scale;
        if (same_quantization) {
          // Identical quantization parameters: the plain mean is used.
          TF_LITE_ENSURE(
              context,
              reference_ops::Mean(
                  tflite::micro::GetTensorData<int8_t>(input),
                  input->dims->data, input->dims->size,
                  tflite::micro::GetTensorData<int8_t>(output),
                  output->dims->data, output->dims->size,
                  tflite::micro::GetTensorData<int>(axis), num_axis,
                  params->keep_dims, temp_index, resolved_axis, temp_buffer));
        } else {
          // Final argument `false` presumably selects mean rather than sum --
          // confirm against QuantizedMeanOrSum's signature.
          TF_LITE_ENSURE(
              context,
              reference_ops::QuantizedMeanOrSum(
                  tflite::micro::GetTensorData<int8_t>(input),
                  op_data->input_zp, op_data->input_scale, input->dims->data,
                  input->dims->size,
                  tflite::micro::GetTensorData<int8_t>(output),
                  op_data->output_zp, op_data->output_scale,
                  output->dims->data, output->dims->size,
                  tflite::micro::GetTensorData<int>(axis), num_axis,
                  params->keep_dims, temp_index, resolved_axis, temp_buffer,
                  false));
        }
      }
    } break;
    case kTfLiteInt16: {
      if (use_specialized_kernel) {
        reference_integer_ops::Mean(
            op_params, op_data->multiplier, op_data->shift,
            tflite::micro::GetTensorShape(input),
            tflite::micro::GetTensorData<int16_t>(input), op_data->input_zp,
            tflite::micro::GetTensorShape(output),
            tflite::micro::GetTensorData<int16_t>(output), op_data->output_zp);
      } else {
        // Both generic paths accumulate into the int32 scratch buffer.
        int32_t* temp_buffer = static_cast<int32_t*>(
            context->GetScratchBuffer(context, op_data->temp_buffer_idx));
        const bool same_quantization =
            op_data->input_zp == op_data->output_zp &&
            op_data->input_scale == op_data->output_scale;
        if (same_quantization) {
          // Identical quantization parameters: the plain mean is used.
          TF_LITE_ENSURE(
              context,
              reference_ops::Mean(
                  tflite::micro::GetTensorData<int16_t>(input),
                  input->dims->data, input->dims->size,
                  tflite::micro::GetTensorData<int16_t>(output),
                  output->dims->data, output->dims->size,
                  tflite::micro::GetTensorData<int>(axis), num_axis,
                  params->keep_dims, temp_index, resolved_axis, temp_buffer));
        } else {
          // Final argument `false` presumably selects mean rather than sum --
          // confirm against QuantizedMeanOrSum's signature.
          TF_LITE_ENSURE(
              context,
              reference_ops::QuantizedMeanOrSum(
                  tflite::micro::GetTensorData<int16_t>(input),
                  op_data->input_zp, op_data->input_scale, input->dims->data,
                  input->dims->size,
                  tflite::micro::GetTensorData<int16_t>(output),
                  op_data->output_zp, op_data->output_scale,
                  output->dims->data, output->dims->size,
                  tflite::micro::GetTensorData<int>(axis), num_axis,
                  params->keep_dims, temp_index, resolved_axis, temp_buffer,
                  false));
        }
      }
    } break;
    default:
      // The message lists exactly the types handled by the cases above.
      TF_LITE_ENSURE_MSG(context, false,
                         "Currently, only float32, int8 or int16 input type "
                         "is supported.");
  }
  return kTfLiteOk;
}