StatusOr ImageClassifier::Postprocess()

in tensorflow_lite_support/cc/task/vision/image_classifier.cc [387:507]


// Builds a ClassificationResult from the model's output tensors.
//
// One `Classifications` entry is appended per head, in head order. For every
// head the per-class scores are read from the matching tensor (dequantized
// with the tensor's scale / zero point when `has_uint8_outputs_` is set),
// optionally passed through the head's score calibration, ranked by
// descending score, and finally filtered by the score threshold, the
// `max_results` option and the class names held in `class_name_set_`.
//
// The frame buffer and region-of-interest arguments are not used.
//
// Returns the status built with StatusCode::kInternal when the number of
// tensors differs from `num_outputs_`; failures reported through
// ASSIGN_OR_RETURN / RETURN_IF_ERROR are handed back to the caller.
StatusOr<ClassificationResult> ImageClassifier::Postprocess(
    const std::vector<const TfLiteTensor*>& output_tensors,
    const FrameBuffer& /*frame_buffer*/, const BoundingBox& /*roi*/) {
  if (output_tensors.size() != num_outputs_) {
    return CreateStatusWithPayload(
        StatusCode::kInternal,
        absl::StrFormat("Expected %d output tensors, found %d", num_outputs_,
                        output_tensors.size()));
  }

  // (class index, score) pair; the index refers to `label_map_items`.
  using IndexedScore = std::pair<int, float>;

  // Orders entries so that the highest score comes first.
  const auto by_descending_score = [](const IndexedScore& lhs,
                                      const IndexedScore& rhs) {
    return lhs.second > rhs.second;
  };

  ClassificationResult result;
  // Scratch buffer shared by all heads; cleared at the start of each one.
  std::vector<IndexedScore> ranked;

  for (int head_index = 0; head_index < num_outputs_; ++head_index) {
    auto* classifications = result.add_classifications();
    classifications->set_head_index(head_index);

    const auto& head = classification_heads_[head_index];
    const auto num_classes = head.label_map_items.size();
    ranked.clear();
    ranked.reserve(num_classes);

    // Step 1: collect one raw score per label map entry.
    const TfLiteTensor* tensor = output_tensors[head_index];
    if (has_uint8_outputs_) {
      ASSIGN_OR_RETURN(const uint8* quantized,
                       AssertAndReturnTypedTensor<uint8>(tensor));
      const auto scale = tensor->params.scale;
      const auto zero_point = tensor->params.zero_point;
      for (int class_index = 0; class_index < num_classes; ++class_index) {
        ranked.emplace_back(
            class_index,
            scale * (static_cast<int>(quantized[class_index]) - zero_point));
      }
    } else {
      ASSIGN_OR_RETURN(const float* raw_scores,
                       AssertAndReturnTypedTensor<float>(tensor));
      for (int class_index = 0; class_index < num_classes; ++class_index) {
        ranked.emplace_back(class_index, raw_scores[class_index]);
      }
    }

    // Step 2: optional score calibration.
    if (score_calibrations_[head_index] != nullptr) {
      for (IndexedScore& entry : ranked) {
        // ComputeCalibratedScore yields the default_score value from
        // metadata [1] when the category (1) has no score calibration data or
        // (2) has a very low confidence uncalibrated score, i.e. one below
        // the `min_uncalibrated_score` threshold. In every other case the
        // selected score transformation function is applied and the outcome
        // is guaranteed to lie within [0, scale], where scale is a
        // label-dependent sigmoid parameter.
        //
        // [1]:
        // https://github.com/tensorflow/tflite-support/blob/af26cb6952ccdeee0e849df2b93dbe7e57f6bc48/tensorflow_lite_support/metadata/metadata_schema.fbs#L453
        const std::string& label = head.label_map_items[entry.first].name;
        entry.second = score_calibrations_[head_index]->ComputeCalibratedScore(
            label, entry.second);
      }
    }

    // Step 3: resolve the result cap and the score cut-off for this head.
    // A negative `max_results` leaves the number of results uncapped.
    int num_results = static_cast<int>(num_classes);
    if (options_->max_results() >= 0) {
      num_results = std::min(num_results, options_->max_results());
    }
    // The threshold set in the options takes precedence over the head's own.
    float score_threshold = head.score_threshold;
    if (options_->has_score_threshold()) {
      score_threshold = options_->score_threshold();
    }

    // Step 4: rank and emit.
    if (class_name_set_.values.empty()) {
      // No class name filter: only the best `num_results` entries need to be
      // ordered (higher score is better).
      absl::c_partial_sort(ranked, ranked.begin() + num_results,
                           by_descending_score);

      for (int rank = 0; rank < num_results; ++rank) {
        const IndexedScore& entry = ranked[rank];
        if (entry.second < score_threshold) {
          // Everything that follows scores lower still.
          break;
        }
        auto* category = classifications->add_classes();
        category->set_index(entry.first);
        category->set_score(entry.second);
      }
    } else {
      // Filtering may discard entries, so the full ranking is required
      // (higher score is better).
      absl::c_sort(ranked, by_descending_score);

      for (const IndexedScore& entry : ranked) {
        if (entry.second < score_threshold ||
            classifications->classes_size() >= num_results) {
          break;
        }

        const std::string& label = head.label_map_items[entry.first].name;
        const bool listed = class_name_set_.values.contains(label);

        // With a whitelist only listed names pass; otherwise listed names
        // are the ones rejected.
        const bool accepted = class_name_set_.is_whitelist ? listed : !listed;
        if (!accepted) {
          continue;
        }

        auto* category = classifications->add_classes();
        category->set_index(entry.first);
        category->set_score(entry.second);
      }
    }
  }

  RETURN_IF_ERROR(FillResultsFromLabelMaps(&result));

  return result;
}