tf::Status FeatureSet::MoveExamplesFromFeaturesToDataset()

in tensorflow_decision_forests/tensorflow/ops/training/kernel.cc [482:647]


// Transfers the examples accumulated in the feature resources into `dataset`.
//
// `IterateFeatures` is invoked with one callback per feature resource type.
// Each callback:
//   1. registers the number of examples held by the feature (see
//      `set_num_rows` below),
//   2. fills the dataset column at `feature_idx` from the feature content,
//   3. clears the content of the feature resource.
//
// Args:
//   ctx: Not used in this function body.
//   dataset: Destination dataset. Its number of rows is set from the first
//     feature visited, and its columns are overwritten.
//
// Returns:
//   - INVALID_ARGUMENT if two features do not hold the same number of
//     examples, or if an integer categorical-set value is smaller than
//     `CategoricalColumn::kNaValue`.
//   - INTERNAL if a row split of an integer categorical-set feature points
//     outside of the feature values.
//   - Any error returned by `IterateFeatures`.
//   - OK otherwise.
//
// Note: When an error is returned, the features visited before the failing
// one have already been moved into `dataset` and cleared.
tf::Status FeatureSet::MoveExamplesFromFeaturesToDataset(
    tf::OpKernelContext* ctx, dataset::VerticalDataset* dataset) {
  // True until the first feature has been visited. The first feature defines
  // the number of rows of the dataset; every following feature is checked
  // against it.
  bool first_set_num_rows = true;
  const auto set_num_rows =
      [&first_set_num_rows, &dataset](
          const int64_t num_rows,
          const AbstractFeatureResource* feature) -> tf::Status {
    if (first_set_num_rows) {
      dataset->set_nrow(num_rows);
      // Without resetting the flag, every feature would overwrite the number
      // of rows and the consistency check below would never run.
      first_set_num_rows = false;
    } else if (dataset->nrow() != num_rows) {
      return tf::Status(
          tf::error::Code::INVALID_ARGUMENT,
          absl::Substitute(
              "Inconsistent number of observations "
              "between features for feature $0 != $1. For feature $2.",
              dataset->nrow(), num_rows, feature->feature_name()));
    }
    return tf::Status::OK();
  };

  TF_RETURN_IF_ERROR(IterateFeatures(
      // Numerical features: the value buffer is moved as-is into the column.
      [&](SimpleMLNumericalFeature::Resource* feature, const int feature_idx) {
        TF_RETURN_IF_ERROR(set_num_rows(feature->data().size(), feature));
        auto* col_data = dataset->MutableColumnWithCast<
            dataset::VerticalDataset::NumericalColumn>(feature_idx);
        *col_data->mutable_values() = std::move(*feature->mutable_data());
        // Leaves the moved-from buffer in a well defined (empty) state.
        feature->mutable_data()->clear();
        return tf::Status::OK();
      },
      // Categorical string features: each example stores an index into
      // `reverse_index()`; the indexed string is converted with the column
      // spec. An empty string is stored as a missing value.
      [&](SimpleMLCategoricalStringFeature::Resource* feature,
          const int feature_idx) {
        TF_RETURN_IF_ERROR(
            set_num_rows(feature->indexed_data().size(), feature));
        const auto& col_spec = dataset->data_spec().columns(feature_idx);
        auto* col_data = dataset->MutableColumnWithCast<
            dataset::VerticalDataset::CategoricalColumn>(feature_idx);
        col_data->Resize(0);
        const auto& reverse_index = feature->reverse_index();
        for (const auto& indexed_value : feature->indexed_data()) {
          const auto& value = reverse_index[indexed_value];
          if (value.empty()) {
            col_data->AddNA();
          } else {
            col_data->Add(dataset::CategoricalStringToValue(value, col_spec));
          }
        }
        // Note: Thread annotations don't work in lambdas.
        feature->non_mutex_protected_clear();
        return tf::Status::OK();
      },
      // Categorical integer features: values below `kNaValue` are clamped to
      // `kNaValue`, and values not covered by the dictionary are mapped to 0.
      [&](SimpleMLCategoricalIntFeature::Resource* feature,
          const int feature_idx) {
        TF_RETURN_IF_ERROR(set_num_rows(feature->data().size(), feature));
        const auto& col_spec = dataset->data_spec().columns(feature_idx);
        auto* col_data = dataset->MutableColumnWithCast<
            dataset::VerticalDataset::CategoricalColumn>(feature_idx);
        col_data->Resize(0);
        for (int value : feature->data()) {
          if (value < dataset::VerticalDataset::CategoricalColumn::kNaValue) {
            // Treated as missing value.
            value = dataset::VerticalDataset::CategoricalColumn::kNaValue;
          }
          if (value >= col_spec.categorical().number_of_unique_values()) {
            // Treated as out-of-dictionary.
            value = 0;
          }
          col_data->Add(value);
        }
        feature->mutable_data()->clear();
        return tf::Status::OK();
      },
      // Categorical-set string features: example `i` owns the values in
      // [row_splits()[i], row_splits()[i + 1]). Each set is converted, sorted
      // and de-duplicated before being stored.
      [&](SimpleMLCategoricalSetStringFeature::Resource* feature,
          const int feature_idx) {
        TF_RETURN_IF_ERROR(set_num_rows(feature->num_examples(), feature));
        const auto& col_spec = dataset->data_spec().columns(feature_idx);
        auto* col_data = dataset->MutableColumnWithCast<
            dataset::VerticalDataset::CategoricalSetColumn>(feature_idx);
        col_data->Resize(0);

        // Temporary buffer for the copy.
        std::vector<int> tmp_value;

        const int num_examples = feature->num_examples();
        for (int example_idx = 0; example_idx < num_examples; example_idx++) {
          // Get and convert the values.
          tmp_value.clear();
          const int begin_value_idx = feature->row_splits()[example_idx];
          const int end_value_idx = feature->row_splits()[example_idx + 1];
          // NOTE(review): unlike the integer-set callback below, `value_idx`
          // is not checked against the size of `values()` here -- confirm
          // that the row splits are always valid for this resource.
          for (int value_idx = begin_value_idx; value_idx < end_value_idx;
               value_idx++) {
            const auto& value_str = feature->values()[value_idx];
            const int32_t value =
                dataset::CategoricalStringToValue(value_str, col_spec);
            tmp_value.push_back(value);
          }

          // Store the values.
          std::sort(tmp_value.begin(), tmp_value.end());
          tmp_value.erase(std::unique(tmp_value.begin(), tmp_value.end()),
                          tmp_value.end());

          col_data->AddVector(tmp_value);
        }
        feature->non_mutex_protected_clear();
        return tf::Status::OK();
      },
      // Categorical-set integer features: same layout as the string variant.
      // Values below `kNaValue` are rejected, and values not covered by the
      // dictionary are mapped to 0.
      [&](SimpleMLCategoricalSetIntFeature::Resource* feature,
          const int feature_idx) {
        TF_RETURN_IF_ERROR(set_num_rows(feature->num_examples(), feature));
        const auto& col_spec = dataset->data_spec().columns(feature_idx);
        auto* col_data = dataset->MutableColumnWithCast<
            dataset::VerticalDataset::CategoricalSetColumn>(feature_idx);
        col_data->Resize(0);

        // Temporary buffer for the copy.
        std::vector<int> tmp_value;

        const int num_examples = feature->num_examples();
        for (int example_idx = 0; example_idx < num_examples; example_idx++) {
          // Get and check the values.
          tmp_value.clear();
          const int begin_value_idx = feature->row_splits()[example_idx];
          const int end_value_idx = feature->row_splits()[example_idx + 1];
          for (int value_idx = begin_value_idx; value_idx < end_value_idx;
               value_idx++) {
            // Guards against row splits pointing outside of the values.
            if (value_idx < 0 || value_idx >= feature->values().size()) {
              return tf::Status(tf::error::Code::INTERNAL, "Internal error");
            }
            auto value = feature->values()[value_idx];
            if (value < dataset::VerticalDataset::CategoricalColumn::kNaValue) {
              return tf::Status(
                  tf::error::Code::INVALID_ARGUMENT,
                  absl::StrCat("Integer categorical value should "
                               "be >= -1. Found value ",
                               value, " for feature ",
                               feature->feature_name()));
            }
            if (value >= col_spec.categorical().number_of_unique_values()) {
              // Treated as out-of-dictionary.
              value = 0;
            }
            tmp_value.push_back(value);
          }

          // Store the values.
          std::sort(tmp_value.begin(), tmp_value.end());
          tmp_value.erase(std::unique(tmp_value.begin(), tmp_value.end()),
                          tmp_value.end());

          col_data->AddVector(tmp_value);
        }
        feature->non_mutex_protected_clear();
        return tf::Status::OK();
      },
      // Hash features: the value buffer is moved as-is into the column.
      [&](SimpleMLHashFeature::Resource* feature, const int feature_idx) {
        TF_RETURN_IF_ERROR(set_num_rows(feature->data().size(), feature));
        auto* col_data =
            dataset
                ->MutableColumnWithCast<dataset::VerticalDataset::HashColumn>(
                    feature_idx);
        *col_data->mutable_values() = std::move(*feature->mutable_data());
        // Leaves the moved-from buffer in a well defined (empty) state.
        feature->mutable_data()->clear();
        return tf::Status::OK();
      }));

  return tf::Status::OK();
}