tf::Status FeatureSet::Link()

in tensorflow_decision_forests/tensorflow/ops/training/kernel.cc [132:223]


// Resolves the feature resources named by the arguments and registers them in
// this FeatureSet.
//
// Args:
//   ctx: Kernel context whose resource manager is queried for the features.
//   concat_feature_ids: Comma-separated ids of the input features. May be
//     empty. The ids are processed in sorted order, followed by the label and
//     then the weight.
//   label_id: Id of the label feature. Skipped if empty.
//   weight_id: Id of the weight feature. Skipped if empty.
//   existing_dataspec: If non-null, the index of each feature is the index of
//     the column with the same name in this dataspec. If null, the index is
//     the value of NumFeatures() at the time the feature is processed.
//   dataset_type: For kValidation, "__VALIDATION" is appended to every
//     resource id before the lookup.
//
// Returns:
//   The lookup error if a resource cannot be retrieved, INVALID_ARGUMENT if a
//   resource is not one of the supported feature types or if "weight_id" is
//   set but no weight feature name was recorded, and OK otherwise.
//
// Note: On error, the features processed before the failure remain registered.
tf::Status FeatureSet::Link(
    tf::OpKernelContext* ctx, const std::string& concat_feature_ids,
    const std::string& label_id, const std::string& weight_id,
    const dataset::proto::DataSpecification* const existing_dataspec,
    const DatasetType dataset_type) {
  std::vector<std::string> feature_ids;

  if (!concat_feature_ids.empty()) {
    feature_ids = absl::StrSplit(concat_feature_ids, ',');
    // Only the input features are sorted; the label and the weight are always
    // appended last, in that order.
    std::sort(feature_ids.begin(), feature_ids.end());
  }

  if (!label_id.empty()) {
    feature_ids.push_back(label_id);
  }

  if (!weight_id.empty()) {
    feature_ids.push_back(weight_id);
  }

  for (const auto& feature_id : feature_ids) {
    std::string resource_id = feature_id;
    switch (dataset_type) {
      case DatasetType::kTraining:
        break;
      case DatasetType::kValidation:
        // See "_FEATURE_RESOURCE_VALIDATION_SUFFIX" in "core.py".
        absl::StrAppend(&resource_id, "__VALIDATION");
        break;
    }

    // On success, "Lookup" gives the caller one reference on the resource.
    AbstractFeatureResource* feature;
    TF_RETURN_IF_ERROR(
        ctx->resource_manager()->Lookup<AbstractFeatureResource, true>(
            kModelContainer, resource_id, &feature));

    const int feature_idx =
        existing_dataspec ? dataset::GetColumnIdxFromName(
                                feature->feature_name(), *existing_dataspec)
                          : NumFeatures();

    // Record the role of the feature. The comparison uses the id without the
    // dataset-type suffix.
    if (feature_id == label_id) {
      label_feature_idx_ = feature_idx;
      label_feature_ = feature->feature_name();
    } else if (feature_id == weight_id) {
      weight_feature_ = feature->feature_name();
    } else {
      input_features_.push_back(feature->feature_name());
    }

    // Store the resource in the list matching its concrete type. The casts are
    // tested lazily, in the same precedence order as before.
    // NOTE(review): the references held by these lists are presumably released
    // elsewhere in the class (not visible here) -- confirm.
    if (auto* numerical_feature =
            dynamic_cast<SimpleMLNumericalFeature::Resource*>(feature)) {
      numerical_features_.push_back({feature_idx, numerical_feature});
    } else if (auto* categorical_string_feature =
                   dynamic_cast<SimpleMLCategoricalStringFeature::Resource*>(
                       feature)) {
      categorical_string_features_.push_back(
          {feature_idx, categorical_string_feature});
    } else if (auto* categorical_int_feature =
                   dynamic_cast<SimpleMLCategoricalIntFeature::Resource*>(
                       feature)) {
      categorical_int_features_.push_back(
          {feature_idx, categorical_int_feature});
    } else if (auto* categorical_set_string_feature =
                   dynamic_cast<SimpleMLCategoricalSetStringFeature::Resource*>(
                       feature)) {
      categorical_set_string_features_.push_back(
          {feature_idx, categorical_set_string_feature});
    } else if (auto* categorical_set_int_feature =
                   dynamic_cast<SimpleMLCategoricalSetIntFeature::Resource*>(
                       feature)) {
      categorical_set_int_features_.push_back(
          {feature_idx, categorical_set_int_feature});
    } else if (auto* hash_feature =
                   dynamic_cast<SimpleMLHashFeature::Resource*>(feature)) {
      hash_features_.push_back({feature_idx, hash_feature});
    } else {
      // The resource is not stored in any list, so the reference obtained from
      // "Lookup" must be released here or it leaks. The message is built first
      // because it reads from "feature".
      const std::string error_message = absl::StrCat(
          "Unsupported type for feature \"", feature->feature_name(), "\"");
      feature->Unref();
      return tf::Status(tf::error::Code::INVALID_ARGUMENT, error_message);
    }
  }

  if (!weight_id.empty() && weight_feature_.empty()) {
    return tf::Status(tf::error::Code::INVALID_ARGUMENT,
                      absl::StrCat("Weight feature not found: ", weight_id));
  }

  return tf::Status::OK();
}