tf::Status FeatureSet::InitializeDatasetFromFeatures()

in tensorflow_decision_forests/tensorflow/ops/training/kernel.cc [306:480]


// Builds the dataspec of `dataset` from the features of this FeatureSet and
// allocates the dataset columns.
//
// Steps:
//   1. One dataspec column is appended per feature. A first pass over the
//      features sets each column's name and type, applies `guide` to it, and
//      checks that all the features report the same number of examples.
//   2. The dataset columns are created from the dataspec, and the number of
//      rows is recorded in the dataspec.
//   3. A second pass accumulates statistics from the feature values (hash
//      features are skipped), and the dataspec is finalized from them.
//
// The vocabulary of the label feature (`label_feature_`) is not pruned.
//
// Args:
//   ctx: Op kernel context. Not used by this function.
//   guide: Dataspec guide; the per-column guide is derived from it with
//     BuildColumnGuide.
//   dataset: Output dataset whose dataspec and columns are populated.
//
// Returns INVALID_ARGUMENT if the features disagree on their number of
// examples, or if the number of examples is not positive (e.g. no features).
tf::Status FeatureSet::InitializeDatasetFromFeatures(
    tf::OpKernelContext* ctx,
    const dataset::proto::DataSpecificationGuide& guide,
    dataset::VerticalDataset* dataset) {
  // Number of examples / batches reported by the first visited feature. -1
  // means "not observed yet".
  int64_t num_batches = -1;
  int64_t num_examples = -1;

  // Records the number of examples of the first feature, and checks that every
  // subsequent feature reports the same number of examples. Only the number of
  // examples is checked; the number of batches is kept for logging only.
  const auto set_num_examples =
      [&num_examples, &num_batches](
          const absl::string_view feature_name,
          const int64_t observed_num_examples,
          const int64_t observed_num_batches) -> tf::Status {
    if (num_examples == -1) {
      num_examples = observed_num_examples;
      num_batches = observed_num_batches;
      return tf::Status::OK();
    }
    if (num_examples != observed_num_examples) {
      // Name the offending feature so the user knows which input is wrong.
      return tf::Status(
          tf::error::Code::INVALID_ARGUMENT,
          absl::Substitute("Inconsistent number of training examples for the "
                           "different input features $0 != $1. The feature "
                           "\"$2\" has $1 examples while the previously "
                           "visited features have $0 examples.",
                           num_examples, observed_num_examples, feature_name));
    }
    return tf::Status::OK();
  };

  // Columns are allocated upfront so that the callbacks below can address
  // them by feature index.
  for (int feature_idx = 0; feature_idx < NumFeatures(); feature_idx++) {
    dataset->mutable_data_spec()->add_columns();
  }

  // Apply the guide on a column. The type of the column should be set.
  const auto apply_guide = [&](const absl::string_view feature_name,
                               dataset::proto::Column* col) -> tf::Status {
    dataset::proto::ColumnGuide col_guide;
    dataset::BuildColumnGuide(feature_name, guide, &col_guide);
    return utils::FromUtilStatus(
        dataset::UpdateSingleColSpecWithGuideInfo(col_guide, col));
  };

  // Initializes the dataspec column `feature_idx` of `feature`: sets its name
  // and `type`, then applies the guide. On success, `*col` points to the
  // column so that the caller can add type-specific settings on top of the
  // guide.
  const auto init_column = [&](auto* feature, const int feature_idx,
                               const dataset::proto::ColumnType type,
                               dataset::proto::Column** col) -> tf::Status {
    *col = dataset->mutable_data_spec()->mutable_columns(feature_idx);
    (*col)->set_name(feature->feature_name());
    (*col)->set_type(type);
    return apply_guide(feature->feature_name(), *col);
  };

  // First pass: column definitions and number of examples.
  TF_RETURN_IF_ERROR(IterateFeatures(
      [&](SimpleMLNumericalFeature::Resource* feature, const int feature_idx) {
        dataset::proto::Column* col = nullptr;
        TF_RETURN_IF_ERROR(init_column(
            feature, feature_idx, dataset::proto::ColumnType::NUMERICAL, &col));
        return set_num_examples(feature->feature_name(), feature->data().size(),
                                feature->NumBatches());
      },
      [&](SimpleMLCategoricalStringFeature::Resource* feature,
          const int feature_idx) {
        dataset::proto::Column* col = nullptr;
        TF_RETURN_IF_ERROR(init_column(feature, feature_idx,
                                       dataset::proto::ColumnType::CATEGORICAL,
                                       &col));
        TF_RETURN_IF_ERROR(set_num_examples(feature->feature_name(),
                                            feature->indexed_data().size(),
                                            feature->NumBatches()));

        // Don't prune the label feature vocabulary. This overrides whatever
        // the guide configured.
        if (feature->feature_name() == label_feature_) {
          col->mutable_categorical()->set_min_value_count(1);
          col->mutable_categorical()->set_max_number_of_unique_values(-1);
        }

        return tf::Status::OK();
      },
      [&](SimpleMLCategoricalIntFeature::Resource* feature,
          const int feature_idx) {
        dataset::proto::Column* col = nullptr;
        TF_RETURN_IF_ERROR(init_column(feature, feature_idx,
                                       dataset::proto::ColumnType::CATEGORICAL,
                                       &col));
        col->mutable_categorical()->set_is_already_integerized(true);
        return set_num_examples(feature->feature_name(), feature->data().size(),
                                feature->NumBatches());
      },
      [&](SimpleMLCategoricalSetStringFeature::Resource* feature,
          const int feature_idx) {
        dataset::proto::Column* col = nullptr;
        TF_RETURN_IF_ERROR(
            init_column(feature, feature_idx,
                        dataset::proto::ColumnType::CATEGORICAL_SET, &col));
        return set_num_examples(feature->feature_name(),
                                feature->num_examples(),
                                feature->num_batches());
      },
      [&](SimpleMLCategoricalSetIntFeature::Resource* feature,
          const int feature_idx) {
        dataset::proto::Column* col = nullptr;
        TF_RETURN_IF_ERROR(
            init_column(feature, feature_idx,
                        dataset::proto::ColumnType::CATEGORICAL_SET, &col));
        col->mutable_categorical()->set_is_already_integerized(true);
        return set_num_examples(feature->feature_name(),
                                feature->num_examples(),
                                feature->num_batches());
      },
      [&](SimpleMLHashFeature::Resource* feature, const int feature_idx) {
        dataset::proto::Column* col = nullptr;
        TF_RETURN_IF_ERROR(init_column(
            feature, feature_idx, dataset::proto::ColumnType::HASH, &col));
        return set_num_examples(feature->feature_name(), feature->data().size(),
                                feature->NumBatches());
      }));

  LOG(INFO) << "Number of batches: " << num_batches;
  LOG(INFO) << "Number of examples: " << num_examples;

  // Also catches the case where no feature was visited (num_examples == -1).
  if (num_examples <= 0) {
    return tf::Status(tf::error::Code::INVALID_ARGUMENT,
                      "No training examples available.");
  }

  TF_RETURN_IF_ERROR_FROM_ABSL_STATUS(dataset->CreateColumnsFromDataspec());

  dataset->mutable_data_spec()->set_created_num_rows(num_examples);

  dataset::proto::DataSpecificationAccumulator accumulator;
  dataset::InitializeDataspecAccumulator(dataset->data_spec(), &accumulator);

  // Second pass: accumulate the column statistics from the feature values.
  TF_RETURN_IF_ERROR(IterateFeatures(
      [&](SimpleMLNumericalFeature::Resource* feature, const int feature_idx) {
        auto* col = dataset->mutable_data_spec()->mutable_columns(feature_idx);
        auto* col_acc = accumulator.mutable_columns(feature_idx);
        for (const auto value : feature->data()) {
          TF_RETURN_IF_ERROR_FROM_ABSL_STATUS(
              dataset::UpdateNumericalColumnSpec(value, col, col_acc));
        }
        return tf::Status::OK();
      },
      [&](SimpleMLCategoricalStringFeature::Resource* feature,
          const int feature_idx) {
        auto* col = dataset->mutable_data_spec()->mutable_columns(feature_idx);
        auto* col_acc = accumulator.mutable_columns(feature_idx);
        // Values are stored as indices; map them back to their string form.
        const auto& reverse_index = feature->reverse_index();
        for (const auto indexed_value : feature->indexed_data()) {
          TF_RETURN_IF_ERROR_FROM_ABSL_STATUS(
              dataset::UpdateCategoricalStringColumnSpec(
                  reverse_index[indexed_value], col, col_acc));
        }
        return tf::Status::OK();
      },
      [&](SimpleMLCategoricalIntFeature::Resource* feature,
          const int feature_idx) {
        auto* col = dataset->mutable_data_spec()->mutable_columns(feature_idx);
        auto* col_acc = accumulator.mutable_columns(feature_idx);
        for (const auto value : feature->data()) {
          TF_RETURN_IF_ERROR_FROM_ABSL_STATUS(
              dataset::UpdateCategoricalIntColumnSpec(value, col, col_acc));
        }
        return tf::Status::OK();
      },
      [&](SimpleMLCategoricalSetStringFeature::Resource* feature,
          const int feature_idx) {
        auto* col = dataset->mutable_data_spec()->mutable_columns(feature_idx);
        auto* col_acc = accumulator.mutable_columns(feature_idx);
        for (const auto& value : feature->values()) {
          TF_RETURN_IF_ERROR_FROM_ABSL_STATUS(
              dataset::UpdateCategoricalStringColumnSpec(value, col, col_acc));
        }
        return tf::Status::OK();
      },
      [&](SimpleMLCategoricalSetIntFeature::Resource* feature,
          const int feature_idx) {
        auto* col = dataset->mutable_data_spec()->mutable_columns(feature_idx);
        auto* col_acc = accumulator.mutable_columns(feature_idx);
        for (const auto value : feature->values()) {
          TF_RETURN_IF_ERROR_FROM_ABSL_STATUS(
              dataset::UpdateCategoricalIntColumnSpec(value, col, col_acc));
        }
        return tf::Status::OK();
      },
      [&](SimpleMLHashFeature::Resource* feature, const int feature_idx) {
        // Nothing to do.
        return tf::Status::OK();
      }));

  dataset::FinalizeComputeSpec({}, accumulator, dataset->mutable_data_spec());

  return tf::Status::OK();
}