in tensorflow_decision_forests/tensorflow/ops/training/kernel.cc [306:480]
// Builds the data spec of `dataset` from the features held by this FeatureSet,
// then creates the dataset columns from that data spec.
//
// The function runs in four phases:
//   1. Append one data spec column per feature. Each column gets its name and
//      type from the feature kind, and the column guide derived from `guide`
//      is applied to it.
//   2. Check that every feature reports the same number of examples.
//   3. Create the dataset columns (CreateColumnsFromDataspec) and record the
//      number of rows in the data spec. This function does not itself write
//      the feature values into the dataset columns.
//   4. Accumulate per-column statistics from the feature values (hash
//      features are skipped) and finalize the data spec.
//
// Args:
//   ctx: Kernel context. Not referenced by this function body.
//   guide: Data specification guide applied to every column (e.g. it can
//     configure categorical vocabulary pruning).
//   dataset: Output dataset. NOTE(review): columns are appended to its data
//     spec and then addressed by feature index, so the data spec presumably
//     has to be empty on entry — TODO confirm with callers.
//
// Returns:
//   INVALID_ARGUMENT if the features disagree on the number of examples, or if
//   there are no examples (this includes the case of a FeatureSet without
//   features). Otherwise, any error raised while applying the guide, creating
//   the columns, or accumulating the statistics.
tf::Status FeatureSet::InitializeDatasetFromFeatures(
    tf::OpKernelContext* ctx,
    const dataset::proto::DataSpecificationGuide& guide,
    dataset::VerticalDataset* dataset) {
  // -1 means "not observed yet". The first feature visited defines the
  // expected number of examples (and the number of batches that is logged).
  int64_t num_batches = -1;
  int64_t num_examples = -1;

  // Records the number of examples / batches of a feature, and fails if its
  // number of examples differs from the one of the previously visited
  // features. `feature_name` is only used to make the error actionable.
  const auto set_num_examples =
      [&num_examples, &num_batches](
          const absl::string_view feature_name,
          const int64_t observed_num_examples,
          const int64_t observed_num_batches) -> tf::Status {
    if (num_examples == -1) {
      num_examples = observed_num_examples;
      num_batches = observed_num_batches;
      return tf::Status::OK();
    }
    if (num_examples != observed_num_examples) {
      // The first sentence is kept unchanged for compatibility with existing
      // error matching; the second one names the inconsistent feature.
      return tf::Status(
          tf::error::Code::INVALID_ARGUMENT,
          absl::Substitute("Inconsistent number of training examples for the "
                           "different input features $0 != $1. The feature "
                           "\"$2\" has $1 examples while the previously "
                           "processed features have $0 examples.",
                           num_examples, observed_num_examples, feature_name));
    }
    return tf::Status::OK();
  };

  // One data spec column per feature. The columns are then addressed by
  // feature index in the callbacks below.
  for (int feature_idx = 0; feature_idx < NumFeatures(); feature_idx++) {
    dataset->mutable_data_spec()->add_columns();
  }

  // Apply the guide on a column. The type of the column should be set.
  const auto apply_guide = [&](const absl::string_view feature_name,
                               dataset::proto::Column* col) -> tf::Status {
    dataset::proto::ColumnGuide col_guide;
    dataset::BuildColumnGuide(feature_name, guide, &col_guide);
    return utils::FromUtilStatus(
        dataset::UpdateSingleColSpecWithGuideInfo(col_guide, col));
  };

  // Phase 1 & 2: column name / type / guide, and example count consistency.
  TF_RETURN_IF_ERROR(IterateFeatures(
      [&](SimpleMLNumericalFeature::Resource* feature, const int feature_idx) {
        auto* col = dataset->mutable_data_spec()->mutable_columns(feature_idx);
        col->set_name(feature->feature_name());
        col->set_type(dataset::proto::ColumnType::NUMERICAL);
        TF_RETURN_IF_ERROR(apply_guide(feature->feature_name(), col));
        return set_num_examples(feature->feature_name(), feature->data().size(),
                                feature->NumBatches());
      },
      [&](SimpleMLCategoricalStringFeature::Resource* feature,
          const int feature_idx) {
        auto* col = dataset->mutable_data_spec()->mutable_columns(feature_idx);
        col->set_name(feature->feature_name());
        col->set_type(dataset::proto::ColumnType::CATEGORICAL);
        TF_RETURN_IF_ERROR(apply_guide(feature->feature_name(), col));
        TF_RETURN_IF_ERROR(set_num_examples(feature->feature_name(),
                                            feature->indexed_data().size(),
                                            feature->NumBatches()));
        // Don't prune the label feature vocabulary. This runs after
        // apply_guide so that it overrides any pruning set by the guide.
        if (feature->feature_name() == label_feature_) {
          col->mutable_categorical()->set_min_value_count(1);
          col->mutable_categorical()->set_max_number_of_unique_values(-1);
        }
        return tf::Status::OK();
      },
      [&](SimpleMLCategoricalIntFeature::Resource* feature,
          const int feature_idx) {
        auto* col = dataset->mutable_data_spec()->mutable_columns(feature_idx);
        col->set_name(feature->feature_name());
        col->set_type(dataset::proto::ColumnType::CATEGORICAL);
        TF_RETURN_IF_ERROR(apply_guide(feature->feature_name(), col));
        col->mutable_categorical()->set_is_already_integerized(true);
        return set_num_examples(feature->feature_name(), feature->data().size(),
                                feature->NumBatches());
      },
      [&](SimpleMLCategoricalSetStringFeature::Resource* feature,
          const int feature_idx) {
        auto* col = dataset->mutable_data_spec()->mutable_columns(feature_idx);
        col->set_name(feature->feature_name());
        col->set_type(dataset::proto::ColumnType::CATEGORICAL_SET);
        TF_RETURN_IF_ERROR(apply_guide(feature->feature_name(), col));
        return set_num_examples(feature->feature_name(),
                                feature->num_examples(),
                                feature->num_batches());
      },
      [&](SimpleMLCategoricalSetIntFeature::Resource* feature,
          const int feature_idx) {
        auto* col = dataset->mutable_data_spec()->mutable_columns(feature_idx);
        col->set_name(feature->feature_name());
        col->set_type(dataset::proto::ColumnType::CATEGORICAL_SET);
        TF_RETURN_IF_ERROR(apply_guide(feature->feature_name(), col));
        col->mutable_categorical()->set_is_already_integerized(true);
        return set_num_examples(feature->feature_name(),
                                feature->num_examples(),
                                feature->num_batches());
      },
      [&](SimpleMLHashFeature::Resource* feature, const int feature_idx) {
        auto* col = dataset->mutable_data_spec()->mutable_columns(feature_idx);
        col->set_name(feature->feature_name());
        col->set_type(dataset::proto::ColumnType::HASH);
        TF_RETURN_IF_ERROR(apply_guide(feature->feature_name(), col));
        return set_num_examples(feature->feature_name(), feature->data().size(),
                                feature->NumBatches());
      }));

  LOG(INFO) << "Number of batches: " << num_batches;
  LOG(INFO) << "Number of examples: " << num_examples;
  // Also catches the "no feature" case, where num_examples is still -1.
  if (num_examples <= 0) {
    return tf::Status(tf::error::Code::INVALID_ARGUMENT,
                      "No training examples available.");
  }

  // Phase 3: create the dataset columns and record the number of rows.
  TF_RETURN_IF_ERROR_FROM_ABSL_STATUS(dataset->CreateColumnsFromDataspec());
  dataset->mutable_data_spec()->set_created_num_rows(num_examples);

  // Phase 4: accumulate the column statistics from the feature values.
  dataset::proto::DataSpecificationAccumulator accumulator;
  dataset::InitializeDataspecAccumulator(dataset->data_spec(), &accumulator);
  TF_RETURN_IF_ERROR(IterateFeatures(
      [&](SimpleMLNumericalFeature::Resource* feature, const int feature_idx) {
        auto* col = dataset->mutable_data_spec()->mutable_columns(feature_idx);
        auto* col_acc = accumulator.mutable_columns(feature_idx);
        for (const auto value : feature->data()) {
          TF_RETURN_IF_ERROR_FROM_ABSL_STATUS(
              dataset::UpdateNumericalColumnSpec(value, col, col_acc));
        }
        return tf::Status::OK();
      },
      [&](SimpleMLCategoricalStringFeature::Resource* feature,
          const int feature_idx) {
        auto* col = dataset->mutable_data_spec()->mutable_columns(feature_idx);
        auto* col_acc = accumulator.mutable_columns(feature_idx);
        // Feature values are stored as indices; map them back to the original
        // strings before updating the string vocabulary statistics.
        const auto& reverse_index = feature->reverse_index();
        for (const auto indexed_value : feature->indexed_data()) {
          TF_RETURN_IF_ERROR_FROM_ABSL_STATUS(
              dataset::UpdateCategoricalStringColumnSpec(
                  reverse_index[indexed_value], col, col_acc));
        }
        return tf::Status::OK();
      },
      [&](SimpleMLCategoricalIntFeature::Resource* feature,
          const int feature_idx) {
        auto* col = dataset->mutable_data_spec()->mutable_columns(feature_idx);
        auto* col_acc = accumulator.mutable_columns(feature_idx);
        for (const auto value : feature->data()) {
          TF_RETURN_IF_ERROR_FROM_ABSL_STATUS(
              dataset::UpdateCategoricalIntColumnSpec(value, col, col_acc));
        }
        return tf::Status::OK();
      },
      [&](SimpleMLCategoricalSetStringFeature::Resource* feature,
          const int feature_idx) {
        auto* col = dataset->mutable_data_spec()->mutable_columns(feature_idx);
        auto* col_acc = accumulator.mutable_columns(feature_idx);
        for (const auto& value : feature->values()) {
          TF_RETURN_IF_ERROR_FROM_ABSL_STATUS(
              dataset::UpdateCategoricalStringColumnSpec(value, col, col_acc));
        }
        return tf::Status::OK();
      },
      [&](SimpleMLCategoricalSetIntFeature::Resource* feature,
          const int feature_idx) {
        auto* col = dataset->mutable_data_spec()->mutable_columns(feature_idx);
        auto* col_acc = accumulator.mutable_columns(feature_idx);
        for (const auto value : feature->values()) {
          TF_RETURN_IF_ERROR_FROM_ABSL_STATUS(
              dataset::UpdateCategoricalIntColumnSpec(value, col, col_acc));
        }
        return tf::Status::OK();
      },
      [&](SimpleMLHashFeature::Resource* feature, const int feature_idx) {
        // Nothing to do.
        return tf::Status::OK();
      }));

  // NOTE(review): an empty guide is passed here rather than `guide`; the
  // per-column guide settings were already written into the columns by
  // apply_guide above — confirm this is intended.
  dataset::FinalizeComputeSpec({}, accumulator, dataset->mutable_data_spec());
  return tf::Status::OK();
}