in research/carls/base/input_context_helper.cc [556:614]
// Copies `input` into `*input_context`, keeping at most
// `max_values_per_feature` values for each feature.
//
// Features with no more than `max_values_per_feature` values are copied
// unchanged. For larger features only the values with the largest leading
// weight are kept (selected with TopN using a descending-weight comparator),
// in the order TopN::Extract() yields them.
//
// `max_values_per_feature` must be positive (enforced by RET_CHECK_TRUE).
// `*input_context` is cleared before `input` is read, so the two must not
// refer to the same object.
absl::Status Prune(const InputContext& input, const int max_values_per_feature,
                   InputContext* input_context) {
  RET_CHECK_TRUE(max_values_per_feature > 0);

  // Ranking key of a value: the first entry of the `weight` list of its typed
  // payload, or 0 if it has none. If several payloads are set, the precedence
  // is uint64 > int64 > float > bytes.
  const auto leading_weight = [](const FeatureValue& fv) -> float {
    if (fv.has_uint64_feature() && !fv.uint64_feature().weight().empty()) {
      return fv.uint64_feature().weight(0);
    }
    if (fv.has_int64_feature() && !fv.int64_feature().weight().empty()) {
      return fv.int64_feature().weight(0);
    }
    if (fv.has_float_feature() && !fv.float_feature().weight().empty()) {
      return fv.float_feature().weight(0);
    }
    if (fv.has_bytes_feature() && !fv.bytes_feature().weight().empty()) {
      return fv.bytes_feature().weight(0);
    }
    return 0;
  };
  // Strict "greater" on weight, so TopN retains the heaviest values.
  auto cmp = [leading_weight](const FeatureValue& lhs,
                              const FeatureValue& rhs) {
    return leading_weight(lhs) > leading_weight(rhs);
  };

  input_context->Clear();
  for (const auto& pair : input.feature()) {
    const auto& name = pair.first;
    const auto& input_feature = pair.second;
    // Resolve the output feature once rather than once per appended value.
    auto* output_values =
        (*input_context->mutable_feature())[name].mutable_feature_value();
    if (input_feature.feature_value_size() <= max_values_per_feature) {
      *output_values = input_feature.feature_value();
      continue;
    }
    TopN<FeatureValue, decltype(cmp)> topn_result(max_values_per_feature, cmp);
    for (const auto& feature_value : input_feature.feature_value()) {
      topn_result.push(feature_value);
    }
    // TopN::Extract() returns a heap-allocated vector that the caller owns and
    // must delete; move the kept values out of it, then free it.
    std::vector<FeatureValue>* top_values = topn_result.Extract();
    output_values->Reserve(static_cast<int>(top_values->size()));
    for (auto& feature_value : *top_values) {
      *output_values->Add() = std::move(feature_value);
    }
    delete top_values;
  }
  return absl::OkStatus();
}