in tensorflow_decision_forests/tensorflow/ops/training/kernel.cc [132:223]
tf::Status FeatureSet::Link(
    tf::OpKernelContext* ctx, const std::string& concat_feature_ids,
    const std::string& label_id, const std::string& weight_id,
    const dataset::proto::DataSpecification* const existing_dataspec,
    const DatasetType dataset_type) {
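  // Collect the ids of all the resources to link: the (sorted) input
  // features plus, when provided, the label and weight columns.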
  std::vector<std::string> feature_ids;
  if (!concat_feature_ids.empty()) {
    feature_ids = absl::StrSplit(concat_feature_ids, ',');
    std::sort(feature_ids.begin(), feature_ids.end());
  }
  if (!label_id.empty()) {
    feature_ids.push_back(label_id);
  }
  if (!weight_id.empty()) {
    feature_ids.push_back(weight_id);
  }
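  // Resolve each id to its feature resource and record it according to its
  // role (label, weight, or input feature) and its concrete type.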
  for (const auto& feature_id : feature_ids) {
    std::string resource_id = feature_id;
    switch (dataset_type) {
      case DatasetType::kTraining:
        break;
      case DatasetType::kValidation:
        // See "_FEATURE_RESOURCE_VALIDATION_SUFFIX" in "core.py".
        absl::StrAppend(&resource_id, "__VALIDATION");
        break;
    }
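    // Retrieve the feature resource registered under this id in the op's
    // resource manager; the lookup fails if the resource does not exist.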
    AbstractFeatureResource* feature;
    TF_RETURN_IF_ERROR(
        ctx->resource_manager()->Lookup<AbstractFeatureResource, true>(
            kModelContainer, resource_id, &feature));
    const int feature_idx =
        existing_dataspec ? dataset::GetColumnIdxFromName(
                                feature->feature_name(), *existing_dataspec)
                          : NumFeatures();
    if (feature_id == label_id) {
      label_feature_idx_ = feature_idx;
      label_feature_ = feature->feature_name();
    } else if (feature_id == weight_id) {
      weight_feature_ = feature->feature_name();
    } else {
      input_features_.push_back(feature->feature_name());
    }
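    // Determine the concrete feature type; exactly one of the casts below is
    // expected to succeed.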
    auto* numerical_feature =
        dynamic_cast<SimpleMLNumericalFeature::Resource*>(feature);
    auto* categorical_string_feature =
        dynamic_cast<SimpleMLCategoricalStringFeature::Resource*>(feature);
    auto* categorical_int_feature =
        dynamic_cast<SimpleMLCategoricalIntFeature::Resource*>(feature);
    auto* hash_feature = dynamic_cast<SimpleMLHashFeature::Resource*>(feature);
    auto* categorical_set_string_feature =
        dynamic_cast<SimpleMLCategoricalSetStringFeature::Resource*>(feature);
    auto* categorical_set_int_feature =
        dynamic_cast<SimpleMLCategoricalSetIntFeature::Resource*>(feature);
    if (numerical_feature) {
      numerical_features_.push_back({feature_idx, numerical_feature});
    } else if (categorical_string_feature) {
      categorical_string_features_.push_back(
          {feature_idx, categorical_string_feature});
    } else if (categorical_int_feature) {
      categorical_int_features_.push_back(
          {feature_idx, categorical_int_feature});
    } else if (categorical_set_string_feature) {
      categorical_set_string_features_.push_back(
          {feature_idx, categorical_set_string_feature});
    } else if (categorical_set_int_feature) {
      categorical_set_int_features_.push_back(
          {feature_idx, categorical_set_int_feature});
    } else if (hash_feature) {
      hash_features_.push_back({feature_idx, hash_feature});
    } else {
      return tf::Status(tf::error::Code::INVALID_ARGUMENT,
                        absl::StrCat("Unsupported type for feature \"",
                                     feature->feature_name(), "\""));
    }
  }
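  // Defensive check: if a weight id was given, it must have been matched in
  // the loop above.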
  if (!weight_id.empty() && weight_feature_.empty()) {
    return tf::Status(tf::error::Code::INVALID_ARGUMENT,
                      absl::StrCat("Weight feature not found: ", weight_id));
  }
  return tf::Status::OK();
}
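For readers unfamiliar with the resource pattern above, here is a minimal standalone sketch (not part of kernel.cc) of registering a resource and retrieving it through TensorFlow's ResourceMgr, which is what the Lookup call in Link relies on. MyFeatureResource, the "demo" container, and "feature_a" are illustrative placeholders, not identifiers from this code base.

#include <iostream>
#include <string>
#include <utility>

#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/platform/status.h"

namespace tf = ::tensorflow;

// A trivial resource holding a feature name, loosely analogous to the
// AbstractFeatureResource subclasses used above (hypothetical class).
class MyFeatureResource : public tf::ResourceBase {
 public:
  explicit MyFeatureResource(std::string feature_name)
      : feature_name_(std::move(feature_name)) {}
  std::string DebugString() const override { return feature_name_; }
  const std::string& feature_name() const { return feature_name_; }

 private:
  std::string feature_name_;
};

int main() {
  tf::ResourceMgr resource_manager;
  const std::string container = "demo";  // Placeholder container name.

  // Register the resource under (container, name); the manager takes
  // ownership of the reference passed in.
  TF_CHECK_OK(resource_manager.Create<MyFeatureResource>(
      container, "feature_a", new MyFeatureResource("feature_a")));

  // Look it up again; Lookup adds a reference that the caller must release.
  MyFeatureResource* feature = nullptr;
  TF_CHECK_OK(resource_manager.Lookup<MyFeatureResource>(
      container, "feature_a", &feature));
  std::cout << feature->feature_name() << std::endl;
  feature->Unref();
  return 0;
}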