absl::Status ImageClassifier::CheckAndSetOutputs()

in tensorflow_lite_support/cc/task/vision/image_classifier.cc [151:287]


// Validates the model's output tensors and initializes the classification
// heads used at inference time.
//
// Responsibilities, in order:
//   1. Record the number of output tensors in `num_outputs_`.
//   2. If output tensor metadata is present, check it covers every output
//      tensor and build one ClassificationHead per tensor from it.
//   3. Otherwise, fall back to default (empty) heads built by introspection.
//   4. Check each output tensor's shape (2D BxN or 4D Bx1x1xN, batch 1),
//      label-map size consistency, and element type (uint8 or float32).
//   5. Require outputs to be either all quantized or all float, and record
//      the result in `has_uint8_outputs_`.
//
// Returns kInvalidArgument (with a TfLiteSupportStatus payload detailing the
// failure class) on any inconsistency; OkStatus otherwise.
absl::Status ImageClassifier::CheckAndSetOutputs() {
  num_outputs_ = TfLiteEngine::OutputCount(GetTfLiteEngine()->interpreter());

  // Perform sanity checks and extract metadata.
  const ModelMetadataExtractor* metadata_extractor =
      GetTfLiteEngine()->metadata_extractor();

  const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>*
      output_tensor_metadata = metadata_extractor->GetOutputTensorMetadata();

  // Loop over output tensors metadata, if any.
  // Note: models with no output tensor metadata at all are supported.
  if (output_tensor_metadata != nullptr) {
    int num_output_tensors = output_tensor_metadata->size();

    // Metadata, when present, must describe every output tensor: a partial
    // metadata vector is treated as an inconsistency, not a fallback case.
    if (num_outputs_ != num_output_tensors) {
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument,
          absl::StrFormat("Mismatch between number of output tensors (%d) and "
                          "output tensors "
                          "metadata (%d).",
                          num_outputs_, num_output_tensors),
          TfLiteSupportStatus::kMetadataInconsistencyError);
    }

    // Build one classification head per output tensor from its metadata
    // (label map, score calibration, etc.), honoring the requested locale.
    for (int i = 0; i < num_output_tensors; ++i) {
      const tflite::TensorMetadata* output_tensor =
          output_tensor_metadata->Get(i);

      ASSIGN_OR_RETURN(
          ClassificationHead head,
          BuildClassificationHead(*metadata_extractor, *output_tensor,
                                  options_->display_names_locale()));

      classification_heads_.emplace_back(std::move(head));
    }
  }

  // If classifier heads are not set, build default ones based on model
  // introspection. This happens if a model with partial or no metadata was
  // provided through the `model_file_with_metadata` options field.
  if (classification_heads_.empty()) {
    classification_heads_.reserve(num_outputs_);
    for (int output_index = 0; output_index < num_outputs_; ++output_index) {
      classification_heads_.emplace_back(ClassificationHead{});
    }
  }

  // Defensive re-check: by this point both code paths above should have
  // produced exactly one head per output tensor.
  if (num_outputs_ != classification_heads_.size()) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrFormat("Got %d classifier head(s), expected %d according to "
                        "the label map.",
                        num_outputs_, classification_heads_.size()),
        TfLiteSupportStatus::kMetadataInconsistencyError);
  }

  // Per-tensor validation: shape, label-map size, and element type. Counts
  // quantized (uint8) outputs so mixed quantized/float models can be rejected
  // after the loop.
  int num_quantized_outputs = 0;
  for (int i = 0; i < num_outputs_; ++i) {
    const TfLiteTensor* output_tensor =
        TfLiteEngine::GetOutput(GetTfLiteEngine()->interpreter(), i);
    const int num_dimensions = output_tensor->dims->size;
    // Accepted shapes are [B, N] or [B, H, W, N] with B = H = W = 1; any
    // other rank or any non-unit H/W is rejected here.
    if (num_dimensions == 4) {
      if (output_tensor->dims->data[1] != 1 ||
          output_tensor->dims->data[2] != 1) {
        return CreateStatusWithPayload(
            StatusCode::kInvalidArgument,
            absl::StrFormat("Unexpected WxH sizes for output index %d: got "
                            "%dx%d, expected 1x1.",
                            i, output_tensor->dims->data[2],
                            output_tensor->dims->data[1]),
            TfLiteSupportStatus::kInvalidOutputTensorDimensionsError);
      }
    } else if (num_dimensions != 2) {
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument,
          absl::StrFormat(
              "Unexpected number of dimensions for output index %d: got %dD, "
              "expected either 2D (BxN with B=1) or 4D (BxHxWxN with B=1, W=1, "
              "H=1).",
              i, num_dimensions),
          TfLiteSupportStatus::kInvalidOutputTensorDimensionsError);
    }
    // Only single-image batches are supported.
    if (output_tensor->dims->data[0] != 1) {
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument,
          absl::StrFormat("The output array is expected to have a batch size "
                          "of 1. Got %d for output index %d.",
                          output_tensor->dims->data[0], i),
          TfLiteSupportStatus::kInvalidOutputTensorDimensionsError);
    }
    // The innermost dimension holds one score per class, in both the 2D and
    // 4D layouts accepted above.
    int num_classes = output_tensor->dims->data[num_dimensions - 1];
    // If label map is not set, build a default one based on model
    // introspection. This happens if a model with partial or no metadata was
    // provided through the `model_file_with_metadata` options field.
    if (classification_heads_[i].label_map_items.empty()) {
      classification_heads_[i].label_map_items.reserve(num_classes);
      for (int class_index = 0; class_index < num_classes; ++class_index) {
        classification_heads_[i].label_map_items.emplace_back(LabelMapItem{});
      }
    }
    // A metadata-provided label map must have exactly one entry per class.
    int num_label_map_items = classification_heads_[i].label_map_items.size();
    if (num_classes != num_label_map_items) {
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument,
          absl::StrFormat("Got %d class(es) for output index %d, expected %d "
                          "according to the label map.",
                          output_tensor->dims->data[num_dimensions - 1], i,
                          num_label_map_items),
          TfLiteSupportStatus::kMetadataInconsistencyError);
    }
    // Only uint8 (quantized) and float32 output tensors are supported.
    if (output_tensor->type == kTfLiteUInt8) {
      num_quantized_outputs++;
    } else if (output_tensor->type != kTfLiteFloat32) {
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument,
          absl::StrFormat("Type mismatch for output tensor %s. Requested one "
                          "of these types: "
                          "kTfLiteUint8/kTfLiteFloat32, got %s.",
                          output_tensor->name,
                          TfLiteTypeGetName(output_tensor->type)),
          TfLiteSupportStatus::kInvalidOutputTensorTypeError);
    }
  }

  // Reject models mixing quantized and float outputs: downstream
  // dequantization logic assumes a single, uniform output type.
  if (num_quantized_outputs > 0 && num_quantized_outputs != num_outputs_) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrFormat("Got %d quantized output(s), expected %d (i.e. all "
                        "provided outputs must be quantized).",
                        num_quantized_outputs, num_outputs_),
        TfLiteSupportStatus::kInvalidOutputTensorTypeError);
  }
  has_uint8_outputs_ = (num_quantized_outputs > 0);

  return absl::OkStatus();
}