absl::Status ClassificationPostprocessor::Postprocess()

in tensorflow_lite_support/cc/task/processor/classification_postprocessor.h [101:194]


absl::Status ClassificationPostprocessor::Postprocess(T* classifications) {
  // Turns the output tensor of this classification head into a ranked,
  // filtered list of classes stored in `classifications`:
  //   1. Read one score per label-map item from the output tensor. uint8
  //      outputs are dequantized with the tensor's scale/zero_point; any other
  //      type is read through AssertAndReturnTypedTensor<float>, and its error
  //      status (if any) is returned.
  //   2. If `score_calibration_` is set, replace each score with its
  //      calibrated value.
  //   3. Emit, in descending score order, at most `num_results_` classes whose
  //      score is >= `score_threshold_`, honoring the allowlist/denylist in
  //      `class_name_set_` when it is non-empty.
  //   4. Return the status of FillResultsFromLabelMaps(classifications).
  const auto& head = classification_head_;
  classifications->set_head_index(tensor_indices_.at(0));

  // Computed once as `int` so loop indices and comparisons against
  // `num_results_` do not mix signed and unsigned types.
  const int num_labels = static_cast<int>(head.label_map_items.size());

  std::vector<std::pair<int, float>> score_pairs;
  score_pairs.reserve(head.label_map_items.size());

  // NOTE(review): `num_labels` elements are read from the output tensor, so the
  // tensor is assumed to hold at least that many scores; presumably validated
  // when the postprocessor is initialized — confirm.
  const TfLiteTensor* output_tensor = GetTensor();
  if (output_tensor->type == kTfLiteUInt8) {
    ASSIGN_OR_RETURN(const uint8* output_data,
                     core::AssertAndReturnTypedTensor<uint8>(output_tensor));
    // Dequantize: real_value = scale * (quantized_value - zero_point).
    for (int j = 0; j < num_labels; ++j) {
      score_pairs.emplace_back(
          j, output_tensor->params.scale * (static_cast<int>(output_data[j]) -
                                            output_tensor->params.zero_point));
    }
  } else {
    ASSIGN_OR_RETURN(const float* output_data,
                     core::AssertAndReturnTypedTensor<float>(output_tensor));
    for (int j = 0; j < num_labels; ++j) {
      score_pairs.emplace_back(j, output_data[j]);
    }
  }

  // Optional score calibration.
  if (score_calibration_ != nullptr) {
    for (auto& score_pair : score_pairs) {
      const std::string& class_name =
          head.label_map_items[score_pair.first].name;

      // In ComputeCalibratedScore, score_pair.second is set to the
      // default_score value from metadata [1] if the category (1) has no
      // score calibration data or (2) has a very low confident uncalibrated
      // score, i.e. lower than the `min_uncalibrated_score` threshold.
      // Otherwise, score_pair.second is calculated based on the selected
      // score transformation function, and the value is guaranteed to be in
      // the range of [0, scale], where scale is a label-dependent sigmoid
      // parameter.
      //
      // [1]:
      // https://github.com/tensorflow/tflite-support/blob/af26cb6952ccdeee0e849df2b93dbe7e57f6bc48/tensorflow_lite_support/metadata/metadata_schema.fbs#L453
      score_pair.second = score_calibration_->ComputeCalibratedScore(
          class_name, score_pair.second);
    }
  }

  // Clamp the requested result count to [0, num_labels]. Otherwise
  // `begin() + max_results` and `score_pairs[j]` below could step outside
  // `score_pairs`, which is undefined behavior.
  int max_results = num_results_;
  if (max_results > num_labels) {
    max_results = num_labels;
  }
  if (max_results < 0) {
    max_results = 0;
  }

  // Orders score pairs by descending score (higher score is better).
  const auto by_descending_score = [](const std::pair<int, float>& a,
                                      const std::pair<int, float>& b) {
    return a.second > b.second;
  };

  if (class_name_set_.values.empty()) {
    // Without a class filter only the first `max_results` entries are ever
    // read, so a partial sort is enough.
    absl::c_partial_sort(score_pairs, score_pairs.begin() + max_results,
                         by_descending_score);

    for (int j = 0; j < max_results; ++j) {
      const float score = score_pairs[j].second;
      // Entries are sorted, so everything after this one is below the
      // threshold as well.
      if (score < score_threshold_) {
        break;
      }
      auto* cl = classifications->add_classes();
      cl->set_index(score_pairs[j].first);
      cl->set_score(score);
    }
  } else {
    // The filter may skip any number of entries, so sort everything.
    absl::c_sort(score_pairs, by_descending_score);

    const bool is_allowlist = class_name_set_.is_allowlist;
    for (const auto& score_pair : score_pairs) {
      const float score = score_pair.second;
      if (score < score_threshold_ ||
          classifications->classes_size() >= max_results) {
        break;
      }

      const int class_index = score_pair.first;
      const std::string& class_name = head.label_map_items[class_index].name;

      // Allowlist: keep only listed names. Denylist: drop listed names.
      // In both modes, skip when membership disagrees with the list mode.
      const bool class_name_found = class_name_set_.values.contains(class_name);
      if (class_name_found != is_allowlist) {
        continue;
      }

      auto* cl = classifications->add_classes();
      cl->set_index(class_index);
      cl->set_score(score);
    }
  }
  return FillResultsFromLabelMaps(classifications);
}