in tensorflow_lite_support/cc/task/vision/object_detector.cc [444:542]
// Validates the output tensors of the underlying Mobile SSD model against its
// TFLite Model Metadata, and caches what post-processing needs:
//   - output_indices_: computed by GetOutputIndices() from the output tensor
//     metadata. Slot 0 is used for the bounding box properties, slot 1 for the
//     label map, slot 2 for the score threshold; all 4 slots are shape-checked.
//   - bounding_box_corners_order_: order of the 4 bounding box coordinates.
//   - label_map_: built from metadata (if available).
//   - score_threshold_: taken from options_ if set, otherwise from metadata.
//
// Returns absl::OkStatus() on success, an InvalidArgument status carrying a
// TfLiteSupportStatus payload for any inconsistency detected here, or the
// error propagated (via ASSIGN_OR_RETURN) from one of the metadata helpers.
absl::Status ObjectDetector::CheckAndSetOutputs() {
  // First, sanity checks on the model itself.
  const TfLiteEngine::Interpreter* interpreter =
      GetTfLiteEngine()->interpreter();
  // Check the number of output tensors.
  if (TfLiteEngine::OutputCount(interpreter) != 4) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrFormat("Mobile SSD models are expected to have exactly 4 "
                        "outputs, found %d",
                        TfLiteEngine::OutputCount(interpreter)),
        TfLiteSupportStatus::kInvalidNumOutputTensorsError);
  }
  // Now, perform sanity checks and extract metadata.
  const ModelMetadataExtractor* metadata_extractor =
      GetTfLiteEngine()->metadata_extractor();
  // Check that metadata is available.
  if (metadata_extractor->GetModelMetadata() == nullptr ||
      metadata_extractor->GetModelMetadata()->subgraph_metadata() == nullptr) {
    return CreateStatusWithPayload(StatusCode::kInvalidArgument,
                                   "Object detection models require TFLite "
                                   "Model Metadata but none was found",
                                   TfLiteSupportStatus::kMetadataNotFoundError);
  }
  // Check output tensor metadata is present and consistent with model.
  auto output_tensors_metadata = metadata_extractor->GetOutputTensorMetadata();
  if (output_tensors_metadata == nullptr ||
      output_tensors_metadata->size() != 4) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrFormat(
            "Mismatch between number of output tensors (4) and output tensors "
            "metadata (%d).",
            output_tensors_metadata == nullptr
                ? 0
                : output_tensors_metadata->size()),
        TfLiteSupportStatus::kMetadataInconsistencyError);
  }
  output_indices_ = GetOutputIndices(output_tensors_metadata);
  // Extract mandatory BoundingBoxProperties for easier access at
  // post-processing time, performing sanity checks on the fly.
  ASSIGN_OR_RETURN(const BoundingBoxProperties* bounding_box_properties,
                   GetBoundingBoxProperties(
                       *output_tensors_metadata->Get(output_indices_[0])));
  const auto* bounding_box_index = bounding_box_properties->index();
  if (bounding_box_index == nullptr || bounding_box_index->size() == 0) {
    // Per the metadata schema, a missing or empty index array denotes the
    // default element order.
    bounding_box_corners_order_ = {0, 1, 2, 3};
  } else if (bounding_box_index->size() != 4) {
    // The index comes from the model file and flatbuffers' Vector::Get() is
    // only assert-checked, so reject anything but exactly 4 entries before
    // reading them below (otherwise this is an out-of-bounds read).
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrFormat("Expected BoundingBoxProperties index to contain "
                        "exactly 4 values, found %d.",
                        bounding_box_index->size()),
        TfLiteSupportStatus::kMetadataInconsistencyError);
  } else {
    bounding_box_corners_order_ = {
        bounding_box_index->Get(0),
        bounding_box_index->Get(1),
        bounding_box_index->Get(2),
        bounding_box_index->Get(3),
    };
  }
  // Build label map (if available) from metadata.
  ASSIGN_OR_RETURN(
      label_map_,
      GetLabelMapIfAny(*metadata_extractor,
                       *output_tensors_metadata->Get(output_indices_[1]),
                       options_->display_names_locale()));
  // Set score threshold: an explicit option overrides the metadata value.
  if (options_->has_score_threshold()) {
    score_threshold_ = options_->score_threshold();
  } else {
    ASSIGN_OR_RETURN(
        score_threshold_,
        GetScoreThreshold(*metadata_extractor,
                          *output_tensors_metadata->Get(output_indices_[2])));
  }
  // Check tensor dimensions and batch size.
  for (int i = 0; i < 4; ++i) {
    std::size_t j = output_indices_[i];
    const TfLiteTensor* tensor = TfLiteEngine::GetOutput(interpreter, j);
    if (tensor->dims->size != kOutputTensorsExpectedDims[i]) {
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument,
          absl::StrFormat("Output tensor at index %d is expected to "
                          "have %d dimensions, found %d.",
                          j, kOutputTensorsExpectedDims[i], tensor->dims->size),
          TfLiteSupportStatus::kInvalidOutputTensorDimensionsError);
    }
    // Only a batch size of 1 is supported.
    if (tensor->dims->data[0] != 1) {
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument,
          absl::StrFormat("Expected batch size of 1, found %d.",
                          tensor->dims->data[0]),
          TfLiteSupportStatus::kInvalidOutputTensorDimensionsError);
    }
  }
  return absl::OkStatus();
}