TfLiteStatus EqualEval()

in src/tensorflow/lite/micro/kernels/comparisons.cpp [37:126]


TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  const OpData* data = static_cast<const OpData*>(node->user_data);

  const TfLiteEvalTensor* input1 =
      tflite::micro::GetEvalInput(context, node, kInputTensor1);
  const TfLiteEvalTensor* input2 =
      tflite::micro::GetEvalInput(context, node, kInputTensor2);
  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kOutputTensor);

  RuntimeShape input1_shape = tflite::micro::GetTensorShape(input1);
  RuntimeShape input2_shape = tflite::micro::GetTensorShape(input2);
  RuntimeShape output_shape = tflite::micro::GetTensorShape(output);
  bool* output_data = tflite::micro::GetTensorData<bool>(output);

  bool requires_broadcast = !tflite::micro::HaveSameShapes(input1, input2);
  switch (input1->type) {
    case kTfLiteBool:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<bool>(input1), input2_shape,
                tflite::micro::GetTensorData<bool>(input2), output_shape,
                output_data)
          : reference_ops::EqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<bool>(input1), input2_shape,
                tflite::micro::GetTensorData<bool>(input2), output_shape,
                output_data);
      break;
    case kTfLiteFloat32:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<float>(input1), input2_shape,
                tflite::micro::GetTensorData<float>(input2), output_shape,
                output_data)
          : reference_ops::EqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<float>(input1), input2_shape,
                tflite::micro::GetTensorData<float>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt32:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int32_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int32_t>(input2), output_shape,
                output_data)
          : reference_ops::EqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int32_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int32_t>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt64:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowEqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int64_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int64_t>(input2), output_shape,
                output_data)
          : reference_ops::EqualNoScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int64_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int64_t>(input2), output_shape,
                output_data);
      break;
    case kTfLiteInt8:
      requires_broadcast
          ? reference_ops::Broadcast4DSlowEqualWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int8_t>(input2), output_shape,
                output_data)
          : reference_ops::EqualWithScaling(
                data->params, input1_shape,
                tflite::micro::GetTensorData<int8_t>(input1), input2_shape,
                tflite::micro::GetTensorData<int8_t>(input2), output_shape,
                output_data);
      break;
    default:
      TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
                         TfLiteTypeGetName(input1->type), input1->type);
      return kTfLiteError;
  }
  return kTfLiteOk;
}