absl::Status ImageClassifier::CheckAndSetClassNameSet()

in tensorflow_lite_support/cc/task/vision/image_classifier.cc [289:350]


// Populates `class_name_set_` from the whitelist or blacklist in `options_`.
//
// Returns OK without touching `class_name_set_` when neither list is given.
// Otherwise:
//  - every classification head must have at least one non-empty label name,
//    else kInvalidArgument / kMetadataMissingLabelsError is returned (and
//    `class_name_set_` is left unmodified);
//  - the whitelist is used if it is non-empty, otherwise the blacklist;
//  - requested names that are not known labels, and duplicates, are dropped;
//  - if nothing survives the filtering, kInvalidArgument /
//    kInvalidArgumentError is returned.
absl::Status ImageClassifier::CheckAndSetClassNameSet() {
  const bool has_whitelist = options_->class_name_whitelist_size() > 0;
  const bool has_blacklist = options_->class_name_blacklist_size() > 0;
  if (!has_whitelist && !has_blacklist) {
    // No filtering requested.
    return absl::OkStatus();
  }

  // Union of every non-empty label name across all heads. Each head must
  // contribute at least one name for whitelist/blacklist filtering to apply.
  absl::flat_hash_set<std::string> known_labels;
  int head_position = 0;
  for (const auto& head : classification_heads_) {
    bool head_has_label = false;
    for (const auto& item : head.label_map_items) {
      if (item.name.empty()) {
        continue;
      }
      known_labels.insert(item.name);
      head_has_label = true;
    }

    if (!head_has_label) {
      // Prefer the head's own name in the error; fall back to its position.
      const std::string head_label =
          head.name.empty() ? absl::StrFormat("#%d", head_position)
                            : head.name;
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument,
          absl::StrFormat(
              "Using `class_name_whitelist` or `class_name_blacklist` "
              "requires labels to be present but none was found for "
              "classification head: %s",
              head_label),
          TfLiteSupportStatus::kMetadataMissingLabelsError);
    }
    ++head_position;
  }

  // Only now mutate class_name_set_, so failures above leave it untouched.
  class_name_set_.is_whitelist = has_whitelist;
  const auto& requested_names = has_whitelist
                                    ? options_->class_name_whitelist()
                                    : options_->class_name_blacklist();

  // Keep only names the model actually knows; the set dedupes repeats.
  auto& selected = class_name_set_.values;
  selected.clear();
  for (const auto& requested : requested_names) {
    if (known_labels.contains(requested)) {
      selected.insert(requested);
    }
  }

  if (!selected.empty()) {
    return absl::OkStatus();
  }
  return CreateStatusWithPayload(
      StatusCode::kInvalidArgument,
      absl::StrFormat(
          "Invalid class names specified via `class_name_%s`: none match "
          "with model labels.",
          has_whitelist ? "whitelist" : "blacklist"),
      TfLiteSupportStatus::kInvalidArgumentError);
}