absl::Status ObjectDetector::CheckAndSetOutputs()

in tensorflow_lite_support/cc/task/vision/object_detector.cc [444:542]


// Validates the model's output tensors and their metadata, and caches the
// values that are read from them on this object.
//
// Members written on the success path:
//   - output_indices_: result of GetOutputIndices() on the output tensors
//     metadata.
//   - bounding_box_corners_order_: the four entries of the
//     BoundingBoxProperties index when one is present, {0, 1, 2, 3} otherwise.
//   - label_map_: result of GetLabelMapIfAny() for the tensor metadata at
//     output_indices_[1].
//   - score_threshold_: the value from options_ when it is set there,
//     otherwise the result of GetScoreThreshold() for the tensor metadata at
//     output_indices_[2].
//
// Returns absl::OkStatus() when every check passes. Returns a
// kInvalidArgument status carrying a TfLiteSupportStatus payload when:
//   - the interpreter does not report exactly 4 outputs,
//   - model metadata or its subgraph metadata is missing,
//   - output tensor metadata is missing or does not have exactly 4 entries,
//   - an output tensor's number of dimensions differs from
//     kOutputTensorsExpectedDims, or its first dimension is not 1.
// Statuses produced by the helpers invoked through ASSIGN_OR_RETURN are
// handed back to the caller by that macro.
absl::Status ObjectDetector::CheckAndSetOutputs() {
  // Step 1: the model itself must expose exactly four output tensors.
  const TfLiteEngine::Interpreter* interpreter =
      GetTfLiteEngine()->interpreter();
  const auto num_outputs = TfLiteEngine::OutputCount(interpreter);
  if (num_outputs != 4) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrFormat("Mobile SSD models are expected to have exactly 4 "
                        "outputs, found %d",
                        num_outputs),
        TfLiteSupportStatus::kInvalidNumOutputTensorsError);
  }

  // Step 2: model-level metadata (including subgraph metadata) is mandatory.
  const ModelMetadataExtractor* metadata_extractor =
      GetTfLiteEngine()->metadata_extractor();
  const auto model_metadata = metadata_extractor->GetModelMetadata();
  const bool has_model_metadata =
      model_metadata != nullptr &&
      model_metadata->subgraph_metadata() != nullptr;
  if (!has_model_metadata) {
    return CreateStatusWithPayload(StatusCode::kInvalidArgument,
                                   "Object detection models require TFLite "
                                   "Model Metadata but none was found",
                                   TfLiteSupportStatus::kMetadataNotFoundError);
  }

  // Step 3: there must be one metadata entry per output tensor. A missing
  // metadata vector is reported as having 0 entries.
  const auto output_tensors_metadata =
      metadata_extractor->GetOutputTensorMetadata();
  const auto metadata_count = output_tensors_metadata == nullptr
                                  ? 0
                                  : output_tensors_metadata->size();
  if (metadata_count != 4) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrFormat(
            "Mismatch between number of output tensors (4) and output tensors "
            "metadata (%d).",
            metadata_count),
        TfLiteSupportStatus::kMetadataInconsistencyError);
  }

  output_indices_ = GetOutputIndices(output_tensors_metadata);

  // Step 4: read the mandatory BoundingBoxProperties from the metadata of the
  // tensor at output_indices_[0] and remember the corner ordering.
  ASSIGN_OR_RETURN(const BoundingBoxProperties* box_properties,
                   GetBoundingBoxProperties(
                       *output_tensors_metadata->Get(output_indices_[0])));
  const auto corner_index = box_properties->index();
  if (corner_index != nullptr) {
    bounding_box_corners_order_ = {
        corner_index->Get(0),
        corner_index->Get(1),
        corner_index->Get(2),
        corner_index->Get(3),
    };
  } else {
    // No explicit index in the metadata: fall back to the identity ordering.
    bounding_box_corners_order_ = {0, 1, 2, 3};
  }

  // Step 5: build the label map, if any, from the metadata of the tensor at
  // output_indices_[1].
  ASSIGN_OR_RETURN(
      label_map_,
      GetLabelMapIfAny(*metadata_extractor,
                       *output_tensors_metadata->Get(output_indices_[1]),
                       options_->display_names_locale()));

  // Step 6: a score threshold set in the options takes precedence over the
  // one obtained from the metadata of the tensor at output_indices_[2].
  if (!options_->has_score_threshold()) {
    ASSIGN_OR_RETURN(
        score_threshold_,
        GetScoreThreshold(*metadata_extractor,
                          *output_tensors_metadata->Get(output_indices_[2])));
  } else {
    score_threshold_ = options_->score_threshold();
  }

  // Step 7: every output tensor must have the expected number of dimensions
  // and a leading (batch) dimension of 1.
  for (int slot = 0; slot < 4; ++slot) {
    const std::size_t tensor_index = output_indices_[slot];
    const TfLiteTensor* tensor =
        TfLiteEngine::GetOutput(interpreter, tensor_index);

    const auto expected_rank = kOutputTensorsExpectedDims[slot];
    const auto actual_rank = tensor->dims->size;
    if (actual_rank != expected_rank) {
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument,
          absl::StrFormat("Output tensor at index %d is expected to "
                          "have %d dimensions, found %d.",
                          tensor_index, expected_rank, actual_rank),
          TfLiteSupportStatus::kInvalidOutputTensorDimensionsError);
    }

    // Only read the first dimension once the rank has been validated.
    const auto batch_size = tensor->dims->data[0];
    if (batch_size != 1) {
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument,
          absl::StrFormat("Expected batch size of 1, found %d.", batch_size),
          TfLiteSupportStatus::kInvalidOutputTensorDimensionsError);
    }
  }

  return absl::OkStatus();
}