in tensorflow_lite_support/cc/task/vision/image_classifier.cc [151:287]
// Validates the model output tensors and sets `num_outputs_`,
// `classification_heads_` and `has_uint8_outputs_` accordingly.
//
// Steps performed:
//  1. If output tensor metadata is present, checks that it has exactly one
//     entry per output tensor, then builds one ClassificationHead per entry
//     via BuildClassificationHead (using `options_->display_names_locale()`).
//  2. If no heads were built (no output tensor metadata), creates one
//     default-constructed ClassificationHead per output tensor.
//  3. For every output tensor:
//     - requires a 2D (BxN) or 4D (BxHxWxN with H=W=1) shape with B=1;
//     - fills an empty label map with N default LabelMapItems, then requires
//       the label map to have exactly N items;
//     - requires the tensor type to be kTfLiteUInt8 or kTfLiteFloat32.
//  4. Requires that either all outputs are kTfLiteUInt8 or none are.
//
// Returns an InvalidArgument status carrying a TfLiteSupportStatus payload on
// the first failed check, or OkStatus() on success.
absl::Status ImageClassifier::CheckAndSetOutputs() {
  num_outputs_ = TfLiteEngine::OutputCount(GetTfLiteEngine()->interpreter());

  // Perform sanity checks and extract metadata.
  const ModelMetadataExtractor* metadata_extractor =
      GetTfLiteEngine()->metadata_extractor();
  const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>*
      output_tensor_metadata = metadata_extractor->GetOutputTensorMetadata();

  // Loop over output tensors metadata, if any.
  // Note: models with no output tensor metadata at all are supported.
  if (output_tensor_metadata != nullptr) {
    const int num_output_tensors = output_tensor_metadata->size();
    if (num_outputs_ != num_output_tensors) {
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument,
          absl::StrFormat("Mismatch between number of output tensors (%d) and "
                          "output tensors "
                          "metadata (%d).",
                          num_outputs_, num_output_tensors),
          TfLiteSupportStatus::kMetadataInconsistencyError);
    }
    classification_heads_.reserve(
        classification_heads_.size() +
        static_cast<std::size_t>(num_output_tensors));
    for (int i = 0; i < num_output_tensors; ++i) {
      const tflite::TensorMetadata* output_tensor =
          output_tensor_metadata->Get(i);
      ASSIGN_OR_RETURN(
          ClassificationHead head,
          BuildClassificationHead(*metadata_extractor, *output_tensor,
                                  options_->display_names_locale()));
      classification_heads_.emplace_back(std::move(head));
    }
  }

  // If classifier heads are not set, build default ones based on model
  // introspection. This happens if a model with partial or no metadata was
  // provided through the `model_file_with_metadata` options field.
  if (classification_heads_.empty() && num_outputs_ > 0) {
    // resize() value-initializes each head, i.e. the same as
    // ClassificationHead{}.
    classification_heads_.resize(static_cast<std::size_t>(num_outputs_));
  }

  // Explicit cast: num_outputs_ is signed while size() is unsigned.
  if (num_outputs_ != static_cast<int>(classification_heads_.size())) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrFormat("Got %d classifier head(s), expected %d according to "
                        "the label map.",
                        num_outputs_, classification_heads_.size()),
        TfLiteSupportStatus::kMetadataInconsistencyError);
  }

  int num_quantized_outputs = 0;
  for (int i = 0; i < num_outputs_; ++i) {
    const TfLiteTensor* output_tensor =
        TfLiteEngine::GetOutput(GetTfLiteEngine()->interpreter(), i);
    const int num_dimensions = output_tensor->dims->size;

    // Shape check: either BxHxWxN with H=W=1, or BxN.
    if (num_dimensions == 4) {
      if (output_tensor->dims->data[1] != 1 ||
          output_tensor->dims->data[2] != 1) {
        return CreateStatusWithPayload(
            StatusCode::kInvalidArgument,
            absl::StrFormat("Unexpected WxH sizes for output index %d: got "
                            "%dx%d, expected 1x1.",
                            i, output_tensor->dims->data[2],
                            output_tensor->dims->data[1]),
            TfLiteSupportStatus::kInvalidOutputTensorDimensionsError);
      }
    } else if (num_dimensions != 2) {
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument,
          absl::StrFormat(
              "Unexpected number of dimensions for output index %d: got %dD, "
              "expected either 2D (BxN with B=1) or 4D (BxHxWxN with B=1, W=1, "
              "H=1).",
              i, num_dimensions),
          TfLiteSupportStatus::kInvalidOutputTensorDimensionsError);
    }
    if (output_tensor->dims->data[0] != 1) {
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument,
          absl::StrFormat("The output array is expected to have a batch size "
                          "of 1. Got %d for output index %d.",
                          output_tensor->dims->data[0], i),
          TfLiteSupportStatus::kInvalidOutputTensorDimensionsError);
    }

    // The last dimension holds the per-class scores.
    const int num_classes = output_tensor->dims->data[num_dimensions - 1];

    // If label map is not set, build a default one based on model
    // introspection. This happens if a model with partial or no metadata was
    // provided through the `model_file_with_metadata` options field.
    auto& label_map_items = classification_heads_[i].label_map_items;
    if (label_map_items.empty() && num_classes > 0) {
      // resize() value-initializes each item, i.e. the same as LabelMapItem{}.
      label_map_items.resize(static_cast<std::size_t>(num_classes));
    }
    const int num_label_map_items = label_map_items.size();
    if (num_classes != num_label_map_items) {
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument,
          absl::StrFormat("Got %d class(es) for output index %d, expected %d "
                          "according to the label map.",
                          num_classes, i, num_label_map_items),
          TfLiteSupportStatus::kMetadataInconsistencyError);
    }

    // Only uint8 (quantized) and float32 outputs are supported.
    if (output_tensor->type == kTfLiteUInt8) {
      ++num_quantized_outputs;
    } else if (output_tensor->type != kTfLiteFloat32) {
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument,
          absl::StrFormat("Type mismatch for output tensor %s. Requested one "
                          "of these types: "
                          "kTfLiteUint8/kTfLiteFloat32, got %s.",
                          output_tensor->name,
                          TfLiteTypeGetName(output_tensor->type)),
          TfLiteSupportStatus::kInvalidOutputTensorTypeError);
    }
  }

  // Mixed quantized / float outputs are rejected: it must be all or nothing.
  if (num_quantized_outputs > 0 && num_quantized_outputs != num_outputs_) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrFormat("Got %d quantized output(s), expected %d (i.e. all "
                        "provided outputs must be quantized).",
                        num_quantized_outputs, num_outputs_),
        TfLiteSupportStatus::kInvalidOutputTensorTypeError);
  }
  has_uint8_outputs_ = (num_quantized_outputs > 0);
  return absl::OkStatus();
}