TfLiteStatus EvalQuantizeReference()

in src/tensorflow/lite/micro/kernels/quantize_common.cpp [83:205]


// Runs the reference implementation of the quantize op for one node.
//
// The work performed depends on the (input type, output type) pair:
//   float32 -> int8 / int16          via reference_ops::AffineQuantize
//   int32   -> int8 / int16          via reference_ops::Requantize
//   int16   -> int8 / int16 / int32  via reference_ops::Requantize
//   int8    -> int8 / int16 / int32  via reference_ops::Requantize
//
// Parameters:
//   context - used to fetch the eval tensors and to report errors.
//   node    - node being evaluated; node->user_data must be non-null and is
//             interpreted as an OpDataQuantizeReference.
//
// Returns kTfLiteOk once a supported pair has been processed. Every other
// pair is reported through TF_LITE_KERNEL_LOG and yields kTfLiteError.
TfLiteStatus EvalQuantizeReference(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  auto* op_data = static_cast<OpDataQuantizeReference*>(node->user_data);

  const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
  TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);

  // Each supported combination returns kTfLiteOk directly. Anything that
  // leaves the outer switch is unsupported and is reported once at the end.
  switch (input->type) {
    case kTfLiteFloat32: {
      // Float input: quantize using the stored quantization parameters.
      switch (output->type) {
        case kTfLiteInt8:
          reference_ops::AffineQuantize(
              op_data->quantization_params,
              tflite::micro::GetTensorShape(input),
              tflite::micro::GetTensorData<float>(input),
              tflite::micro::GetTensorShape(output),
              tflite::micro::GetTensorData<int8_t>(output));
          return kTfLiteOk;
        case kTfLiteInt16:
          reference_ops::AffineQuantize(
              op_data->quantization_params,
              tflite::micro::GetTensorShape(input),
              tflite::micro::GetTensorData<float>(input),
              tflite::micro::GetTensorShape(output),
              tflite::micro::GetTensorData<int16_t>(output));
          return kTfLiteOk;
        default:
          break;
      }
      break;
    }

    case kTfLiteInt32: {
      // The element count is computed before the output type is inspected.
      size_t num_elements = ElementCount(*input->dims);
      switch (output->type) {
        case kTfLiteInt8:
          reference_ops::Requantize(
              tflite::micro::GetTensorData<int32_t>(input), num_elements,
              op_data->requantize_output_multiplier,
              op_data->requantize_output_shift, op_data->input_zero_point,
              op_data->quantization_params.zero_point,
              tflite::micro::GetTensorData<int8_t>(output));
          return kTfLiteOk;
        case kTfLiteInt16:
          reference_ops::Requantize(
              tflite::micro::GetTensorData<int32_t>(input), num_elements,
              op_data->requantize_output_multiplier,
              op_data->requantize_output_shift, op_data->input_zero_point,
              op_data->quantization_params.zero_point,
              tflite::micro::GetTensorData<int16_t>(output));
          return kTfLiteOk;
        default:
          break;
      }
      break;
    }

    case kTfLiteInt16: {
      size_t num_elements = ElementCount(*input->dims);
      switch (output->type) {
        case kTfLiteInt8:
          reference_ops::Requantize(
              tflite::micro::GetTensorData<int16_t>(input), num_elements,
              op_data->requantize_output_multiplier,
              op_data->requantize_output_shift, op_data->input_zero_point,
              op_data->quantization_params.zero_point,
              tflite::micro::GetTensorData<int8_t>(output));
          return kTfLiteOk;
        case kTfLiteInt16:
          reference_ops::Requantize(
              tflite::micro::GetTensorData<int16_t>(input), num_elements,
              op_data->requantize_output_multiplier,
              op_data->requantize_output_shift, op_data->input_zero_point,
              op_data->quantization_params.zero_point,
              tflite::micro::GetTensorData<int16_t>(output));
          return kTfLiteOk;
        case kTfLiteInt32:
          reference_ops::Requantize(
              tflite::micro::GetTensorData<int16_t>(input), num_elements,
              op_data->requantize_output_multiplier,
              op_data->requantize_output_shift, op_data->input_zero_point,
              op_data->quantization_params.zero_point,
              tflite::micro::GetTensorData<int32_t>(output));
          return kTfLiteOk;
        default:
          break;
      }
      break;
    }

    case kTfLiteInt8: {
      // Requantization from int8; the int8 -> int8 case is required if the
      // input and output tensors have different scales and/or zero points.
      size_t num_elements = ElementCount(*input->dims);
      switch (output->type) {
        case kTfLiteInt8:
          reference_ops::Requantize(
              tflite::micro::GetTensorData<int8_t>(input), num_elements,
              op_data->requantize_output_multiplier,
              op_data->requantize_output_shift, op_data->input_zero_point,
              op_data->quantization_params.zero_point,
              tflite::micro::GetTensorData<int8_t>(output));
          return kTfLiteOk;
        case kTfLiteInt16:
          reference_ops::Requantize(
              tflite::micro::GetTensorData<int8_t>(input), num_elements,
              op_data->requantize_output_multiplier,
              op_data->requantize_output_shift, op_data->input_zero_point,
              op_data->quantization_params.zero_point,
              tflite::micro::GetTensorData<int16_t>(output));
          return kTfLiteOk;
        case kTfLiteInt32:
          reference_ops::Requantize(
              tflite::micro::GetTensorData<int8_t>(input), num_elements,
              op_data->requantize_output_multiplier,
              op_data->requantize_output_shift, op_data->input_zero_point,
              op_data->quantization_params.zero_point,
              tflite::micro::GetTensorData<int32_t>(output));
          return kTfLiteOk;
        default:
          break;
      }
      break;
    }

    default:
      break;
  }

  // Shared failure path for every unsupported input/output type pair.
  TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
                     TfLiteTypeGetName(input->type),
                     TfLiteTypeGetName(output->type));
  return kTfLiteError;
}