absl::Status Prune()

in research/carls/base/input_context_helper.cc [556:614]


absl::Status Prune(const InputContext& input, const int max_values_per_feature,
                   InputContext* input_context) {
  RET_CHECK_TRUE(max_values_per_feature > 0);
  static auto cmp = [](const FeatureValue& lhs,
                       const FeatureValue& rhs) -> bool {
    float lhs_weight = 0;
    float rhs_weight = 0;
    // Bytes.
    if (lhs.has_bytes_feature() && !lhs.bytes_feature().weight().empty()) {
      lhs_weight = lhs.bytes_feature().weight(0);
    }
    if (rhs.has_bytes_feature() && !rhs.bytes_feature().weight().empty()) {
      rhs_weight = rhs.bytes_feature().weight(0);
    }
    // Float.
    if (lhs.has_float_feature() && !lhs.float_feature().weight().empty()) {
      lhs_weight = lhs.float_feature().weight(0);
    }
    if (rhs.has_float_feature() && !rhs.float_feature().weight().empty()) {
      rhs_weight = rhs.float_feature().weight(0);
    }
    // Int64.
    if (lhs.has_int64_feature() && !lhs.int64_feature().weight().empty()) {
      lhs_weight = lhs.int64_feature().weight(0);
    }
    if (rhs.has_int64_feature() && !rhs.int64_feature().weight().empty()) {
      rhs_weight = rhs.int64_feature().weight(0);
    }
    // Uint64.
    if (lhs.has_uint64_feature() && !lhs.uint64_feature().weight().empty()) {
      lhs_weight = lhs.uint64_feature().weight(0);
    }
    if (rhs.has_uint64_feature() && !rhs.uint64_feature().weight().empty()) {
      rhs_weight = rhs.uint64_feature().weight(0);
    }

    return lhs_weight > rhs_weight;
  };
  input_context->Clear();
  for (const auto& pair : input.feature()) {
    const auto& name = pair.first;
    const auto& input_feature = pair.second;
    if (input_feature.feature_value_size() <= max_values_per_feature) {
      *(*input_context->mutable_feature())[name].mutable_feature_value() =
          input_feature.feature_value();
      continue;
    }
    TopN<FeatureValue, decltype(cmp)> topn_result(max_values_per_feature, cmp);
    for (const auto& feature_value : input_feature.feature_value()) {
      topn_result.push(feature_value);
    }
    std::vector<FeatureValue> topn_results = std::move(*topn_result.Extract());
    for (auto& feature_value : topn_results) {
      *(*input_context->mutable_feature())[name].add_feature_value() =
          std::move(feature_value);
    }
  }
  return absl::OkStatus();
}