absl::Status ClassificationPostprocessor::Init()

in tensorflow_lite_support/cc/task/processor/classification_postprocessor.cc [55:217]


// Initializes the postprocessor from `options` and from the output tensor and
// tensor metadata this postprocessor is bound to.
//
// Steps, in order:
//   1. Validates `options`: `max_results` must be non-zero, and
//      `class_name_allowlist` / `class_name_denylist` are mutually exclusive.
//   2. Builds `classification_head_` via BuildClassificationHead() from the
//      metadata extractor and this tensor's metadata, using the requested
//      `display_names_locale`.
//   3. Validates the output tensor: it must be 2D (BxN) or 4D (BxHxWxN with
//      H=W=1), with batch size B=1. The class count N must match the label
//      map; when the label map is empty, N unnamed items are synthesized. The
//      element type must be uint8 or float32.
//   4. Resolves the optional class name allowlist/denylist against the class
//      names present in the label map. Unknown and duplicate names are ignored,
//      but at least one must match.
//   5. Creates and initializes `score_calibration_` when the classification
//      head carries calibration parameters.
//   6. Sets `num_results_` and `score_threshold_`. For `num_results_`, a
//      negative `max_results` means "all classes". For `score_threshold_`, an
//      explicit option overrides the head's value.
//
// Returns an InvalidArgument status carrying a TfLiteSupportStatus payload on
// any validation failure. Errors from BuildClassificationHead() and from
// score calibration initialization are propagated unchanged.
absl::Status ClassificationPostprocessor::Init(
    std::unique_ptr<ClassificationOptions> options) {
  // Sanity check options.
  if (options->max_results() == 0) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        "Invalid `max_results` option: value must be != 0",
        TfLiteSupportStatus::kInvalidArgumentError);
  }
  if (options->class_name_allowlist_size() > 0 &&
      options->class_name_denylist_size() > 0) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        "`class_name_allowlist` and `class_name_denylist` are mutually "
        "exclusive options.",
        TfLiteSupportStatus::kInvalidArgumentError);
  }

  ASSIGN_OR_RETURN(classification_head_,
                   BuildClassificationHead(*engine_->metadata_extractor(),
                                           *GetTensorMetadata(),
                                           options->display_names_locale()));

  // Sanity check output tensor dimensions: BxN or BxHxWxN with H=W=1, and B=1.
  const TfLiteTensor* output_tensor = GetTensor();
  const int num_dimensions = output_tensor->dims->size;
  if (num_dimensions == 4) {
    if (output_tensor->dims->data[1] != 1 ||
        output_tensor->dims->data[2] != 1) {
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument,
          absl::StrFormat("Unexpected WxH sizes for output index %d: got "
                          "%dx%d, expected 1x1.",
                          tensor_indices_.at(0), output_tensor->dims->data[2],
                          output_tensor->dims->data[1]),
          TfLiteSupportStatus::kInvalidOutputTensorDimensionsError);
    }
  } else if (num_dimensions != 2) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrFormat(
            "Unexpected number of dimensions for output index %d: got %dD, "
            "expected either 2D (BxN with B=1) or 4D (BxHxWxN with B=1, W=1, "
            "H=1).",
            tensor_indices_.at(0), num_dimensions),
        TfLiteSupportStatus::kInvalidOutputTensorDimensionsError);
  }
  // Safe to index data[0]: num_dimensions is known to be 2 or 4 here.
  if (output_tensor->dims->data[0] != 1) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrFormat("The output array is expected to have a batch size "
                        "of 1. Got %d for output index %d.",
                        output_tensor->dims->data[0], tensor_indices_.at(0)),
        TfLiteSupportStatus::kInvalidOutputTensorDimensionsError);
  }

  // The class count is always the innermost dimension.
  const int num_classes = output_tensor->dims->data[num_dimensions - 1];
  // If label map is not set, build a default one based on model
  // introspection. This happens if a model with partial or no metadata was
  // provided through the `model_file_with_metadata` options field.
  if (classification_head_.label_map_items.empty()) {
    // One default-constructed (unnamed) item per class.
    classification_head_.label_map_items.resize(num_classes);
  }
  // `label_map_items` is not modified past this point, so this count stays
  // valid for the rest of the function (it is reused for `num_results_`).
  const int num_label_map_items =
      static_cast<int>(classification_head_.label_map_items.size());
  if (num_classes != num_label_map_items) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrFormat("Got %d class(es) for output index %d, expected %d "
                        "according to the label map.",
                        num_classes, tensor_indices_.at(0),
                        num_label_map_items),
        TfLiteSupportStatus::kMetadataInconsistencyError);
  }
  if (output_tensor->type != kTfLiteUInt8 &&
      output_tensor->type != kTfLiteFloat32) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrFormat("Type mismatch for output tensor %s. Requested one "
                        "of these types: "
                        "kTfLiteUint8/kTfLiteFloat32, got %s.",
                        output_tensor->name,
                        TfLiteTypeGetName(output_tensor->type)),
        TfLiteSupportStatus::kInvalidOutputTensorTypeError);
  }

  // Set class name set.
  if (options->class_name_denylist_size() != 0 ||
      options->class_name_allowlist_size() != 0) {
    // Before processing class names allowlist or denylist from the input
    // options create a set with _all_ known class names from the label map(s).
    absl::flat_hash_set<std::string> head_class_names;
    for (const auto& item : classification_head_.label_map_items) {
      if (!item.name.empty()) {
        head_class_names.insert(item.name);
      }
    }

    if (head_class_names.empty()) {
      std::string name = classification_head_.name;
      if (name.empty()) {
        name = absl::StrFormat("#%d", tensor_indices_.at(0));
      }
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument,
          absl::StrFormat(
              "Using `class_name_allowlist` or `class_name_denylist` "
              "requires labels to be present but none was found for "
              "classification head: %s",
              name),
          TfLiteSupportStatus::kMetadataMissingLabelsError);
    }

    // The two lists are mutually exclusive (checked above), so exactly one of
    // them is non-empty here.
    class_name_set_.is_allowlist = options->class_name_allowlist_size() > 0;
    const auto& class_names = class_name_set_.is_allowlist
                                  ? options->class_name_allowlist()
                                  : options->class_name_denylist();

    // Note: duplicate or unknown classes are just ignored.
    class_name_set_.values.clear();
    for (const auto& class_name : class_names) {
      if (head_class_names.contains(class_name)) {
        class_name_set_.values.insert(class_name);
      }
    }

    if (class_name_set_.values.empty()) {
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument,
          absl::StrFormat(
              "Invalid class names specified via `class_name_%s`: none match "
              "with model labels.",
              class_name_set_.is_allowlist ? "allowlist" : "denylist"),
          TfLiteSupportStatus::kInvalidArgumentError);
    }
  }

  // Set score calibration. absl::make_unique never returns nullptr (allocation
  // failure does not yield a null pointer), so no null check is needed.
  if (classification_head_.calibration_params.has_value()) {
    score_calibration_ = absl::make_unique<ScoreCalibration>();
    RETURN_IF_ERROR(score_calibration_->InitializeFromParameters(
        classification_head_.calibration_params.value()));
  }

  // A negative `max_results` means "return all classes". Both branches are
  // `int`, avoiding the implicit int/size_t promotion of a mixed ternary.
  num_results_ = options->max_results() >= 0
                     ? std::min(num_label_map_items, options->max_results())
                     : num_label_map_items;
  // An explicitly set option overrides the threshold from the head.
  score_threshold_ = options->has_score_threshold()
                         ? options->score_threshold()
                         : classification_head_.score_threshold;

  return absl::OkStatus();
}