in tensorflow_lite_support/cc/task/processor/classification_postprocessor.h [101:194]
absl::Status ClassificationPostprocessor::Postprocess(T* classifications) {
  // Turns the raw scores of the classification output tensor into a ranked
  // list of classes written to `classifications`:
  //   1. reads one score per label-map entry (dequantizing uint8 outputs as
  //      scale * (q - zero_point), otherwise reading float scores),
  //   2. optionally applies score calibration,
  //   3. keeps at most `num_results_` classes whose score is at least
  //      `score_threshold_`, in descending score order, honoring the
  //      allowlist / denylist in `class_name_set_` when it is non-empty,
  //   4. fills label information via FillResultsFromLabelMaps().
  //
  // Returns the error produced by core::AssertAndReturnTypedTensor (via
  // ASSIGN_OR_RETURN) if the output tensor data cannot be accessed with the
  // expected type, otherwise the status of FillResultsFromLabelMaps().
  const auto& head = classification_head_;
  classifications->set_head_index(tensor_indices_.at(0));

  // NOTE(review): assumes the output tensor holds at least `num_labels`
  // elements; presumably validated when the postprocessor is created —
  // confirm.
  const int num_labels = static_cast<int>(head.label_map_items.size());
  std::vector<std::pair<int, float>> score_pairs;
  score_pairs.reserve(head.label_map_items.size());

  const TfLiteTensor* output_tensor = GetTensor();
  if (output_tensor->type == kTfLiteUInt8) {
    ASSIGN_OR_RETURN(const uint8* output_data,
                     core::AssertAndReturnTypedTensor<uint8>(output_tensor));
    // Quantization parameters are loop-invariant.
    const auto scale = output_tensor->params.scale;
    const auto zero_point = output_tensor->params.zero_point;
    for (int j = 0; j < num_labels; ++j) {
      score_pairs.emplace_back(
          j, scale * (static_cast<int>(output_data[j]) - zero_point));
    }
  } else {
    ASSIGN_OR_RETURN(const float* output_data,
                     core::AssertAndReturnTypedTensor<float>(output_tensor));
    for (int j = 0; j < num_labels; ++j) {
      score_pairs.emplace_back(j, output_data[j]);
    }
  }

  // Optional score calibration.
  if (score_calibration_ != nullptr) {
    for (auto& score_pair : score_pairs) {
      const std::string& class_name =
          head.label_map_items[score_pair.first].name;
      // In ComputeCalibratedScore, score_pair.second is set to the
      // default_score value from metadata [1] if the category (1) has no
      // score calibration data or (2) has a very low confident uncalibrated
      // score, i.e. lower than the `min_uncalibrated_score` threshold.
      // Otherwise, score_pair.second is calculated based on the selected
      // score transformation function, and the value is guaranteed to be in
      // the range of [0, scale], where scale is a label-dependent sigmoid
      // parameter.
      //
      // [1]:
      // https://github.com/tensorflow/tflite-support/blob/af26cb6952ccdeee0e849df2b93dbe7e57f6bc48/tensorflow_lite_support/metadata/metadata_schema.fbs#L453
      score_pair.second = score_calibration_->ComputeCalibratedScore(
          class_name, score_pair.second);
    }
  }

  // Clamp the requested number of results to [0, num_labels]. Otherwise
  // `score_pairs.begin() + num_results_` in the partial sort and the
  // `score_pairs[j]` reads below would go out of bounds (undefined behavior)
  // whenever num_results_ exceeded the number of labels or was negative.
  int max_results = num_results_;
  if (max_results > num_labels) max_results = num_labels;
  if (max_results < 0) max_results = 0;

  // Higher score is better.
  const auto by_descending_score = [](const std::pair<int, float>& a,
                                      const std::pair<int, float>& b) {
    return a.second > b.second;
  };

  if (class_name_set_.values.empty()) {
    // No class-name filtering: only the first `max_results` entries can ever
    // be emitted, so a partial sort is enough.
    absl::c_partial_sort(score_pairs, score_pairs.begin() + max_results,
                         by_descending_score);
    for (int j = 0; j < max_results; ++j) {
      const float score = score_pairs[j].second;
      // Scores are in descending order: all later ones are below threshold.
      if (score < score_threshold_) {
        break;
      }
      auto* cl = classifications->add_classes();
      cl->set_index(score_pairs[j].first);
      cl->set_score(score);
    }
  } else {
    // Filtering may skip arbitrary entries, so the whole list must be sorted.
    absl::c_sort(score_pairs, by_descending_score);
    for (int j = 0; j < num_labels; ++j) {
      const float score = score_pairs[j].second;
      if (score < score_threshold_ ||
          classifications->classes_size() >= max_results) {
        break;
      }
      const int class_index = score_pairs[j].first;
      const std::string& class_name = head.label_map_items[class_index].name;
      const bool class_name_found =
          class_name_set_.values.contains(class_name);
      // Allowlist: keep only listed names. Denylist: drop listed names.
      const bool keep =
          class_name_set_.is_allowlist ? class_name_found : !class_name_found;
      if (!keep) {
        continue;
      }
      auto* cl = classifications->add_classes();
      cl->set_index(class_index);
      cl->set_score(score);
    }
  }
  return FillResultsFromLabelMaps(classifications);
}