in tensorflow_decision_forests/tensorflow/ops/training/kernel.cc [751:908]
// Trains a Yggdrasil model from the examples collected so far.
//
// Output 0 is a scalar boolean. It is set to true up front and flipped to
// false (with the kernel returning without training) when
// HasTrainingExamples() reports that there is nothing to train on.
//
// The steps are, in order: build the training dataset, optionally build the
// validation dataset and check that it uses the same label / weight / input
// features, configure the learner, train, and finally export the model to
// "<model_dir_>/model" (when model_dir_ is set) and to the resource manager
// under model_id_ (when model_id_ is set).
//
// Every fallible step is wrapped in OP_REQUIRES_OK, so a non-OK status is
// reported on `ctx` and Compute returns at that point.
void Compute(tf::OpKernelContext* ctx) override {
  LOG(INFO) << "Start Yggdrasil model training";
  LOG(INFO) << "Collect training examples";

  // Allocate the scalar "success" output and assume success until proven
  // otherwise.
  tf::Tensor* success_output = nullptr;
  OP_REQUIRES_OK(
      ctx, ctx->allocate_output(0, tf::TensorShape({}), &success_output));
  auto success_flag = success_output->scalar<bool>();
  success_flag() = true;

  if (!HasTrainingExamples(ctx)) {
    LOG(INFO) << "Not training example available. Ignore training request.";
    success_flag() = false;
    return;
  }

  // Training dataset, plus the names of the columns playing each role.
  dataset::VerticalDataset train_dataset;
  std::string label_name;
  std::string weight_name;
  std::vector<std::string> input_names;
  OP_REQUIRES_OK(ctx,
                 CreateTrainingDatasetFromFeatures(
                     ctx, DatasetType::kTraining, &train_dataset, &label_name,
                     &weight_name, &input_names));
  LOG(INFO) << "Training dataset:\n"
            << dataset::PrintHumanReadable(train_dataset.data_spec(), false);

  // Optional validation dataset. Stays null when no validation data is
  // configured; the training call below branches on that.
  std::unique_ptr<dataset::VerticalDataset> validation_dataset;
  if (has_validation_dataset_) {
    LOG(INFO) << "Collect validation dataset";
    validation_dataset = absl::make_unique<dataset::VerticalDataset>();
    std::string validation_label;
    std::string validation_weight;
    std::vector<std::string> validation_inputs;
    OP_REQUIRES_OK(ctx, CreateTrainingDatasetFromFeatures(
                            ctx, DatasetType::kValidation,
                            validation_dataset.get(), &validation_label,
                            &validation_weight, &validation_inputs));
    LOG(INFO) << "Validation dataset:\n"
              << dataset::PrintHumanReadable(validation_dataset->data_spec(),
                                             false);

    // The validation dataset must use the same label, weight and input
    // features as the training dataset. The first mismatch fails the op.
    if (validation_label != label_name) {
      const std::string label_mismatch = absl::StrCat(
          "Different label in the training and validation dataset: \"",
          label_name, "\" vs \"", validation_label, "\"");
      OP_REQUIRES_OK(
          ctx, tf::Status(tf::error::INVALID_ARGUMENT, label_mismatch));
    }
    if (validation_weight != weight_name) {
      OP_REQUIRES_OK(
          ctx, tf::Status(tf::error::INVALID_ARGUMENT,
                          "Different weights in the training and "
                          "validation dataset."));
    }
    if (validation_inputs != input_names) {
      OP_REQUIRES_OK(
          ctx, tf::Status(tf::error::INVALID_ARGUMENT,
                          "Different features in the training and "
                          "validation dataset."));
    }
  }

  LOG(INFO) << "Configure learner";
  // Start from the kernel's base training configuration and fill in the
  // fields that depend on the collected dataset.
  model::proto::TrainingConfig train_config = training_config_;
  train_config.set_learner(learner_);
  train_config.set_label(label_name);
  train_config.set_task(task_);
  if (!weight_name.empty()) {
    LOG(INFO) << "Use example weight: " << weight_name
              << " from accumulator: " << weight_id_;
    auto* weight_definition = train_config.mutable_weight_definition();
    weight_definition->set_attribute(weight_name);
    // Touching the "numerical" sub-message selects it for the weights.
    weight_definition->mutable_numerical();
  }
  for (const auto& feature_name : input_names) {
    train_config.add_features(
        dataset::EscapeTrainingConfigFeatureName(feature_name));
  }

  std::unique_ptr<model::AbstractLearner> learner;
  OP_REQUIRES_OK(ctx,
                 utils::FromUtilStatus(GetLearner(train_config, &learner)));
  OP_REQUIRES_OK(
      ctx, utils::FromUtilStatus(learner->SetHyperParameters(hparams_)));
  *learner->mutable_deployment() = deployment_config_;

  // Both the training logs and the exported model live under model_dir_.
  const bool has_model_dir = !model_dir_.empty();
  if (has_model_dir) {
    learner->set_log_directory(tf::io::JoinPath(model_dir_, "train_logs"));
  }

  LOG(INFO) << "Training config:\n"
            << learner->training_config().DebugString();
  LOG(INFO) << "Deployment config:\n" << learner->deployment().DebugString();

  // Debugging aid: the disabled snippet below dumps the dataset and the
  // training configuration so the same training can be replayed from the
  // command line, e.g. under a debugger:
  //
  //   bazel run -c opt //third_party/yggdrasil_decision_forests/cli:train --
  //     --alsologtostderr --output=/tmp/model
  //     --dataset=tfrecord+tfe:/tmp/dataset.tfe
  //     --dataspec=/tmp/dataspec.pbtxt
  //     --config=/tmp/train_config.pbtxt
  //
  // It requires the extra dependency:
  //   //third_party/yggdrasil_decision_forests/dataset:tf_example_io_tfrecord
  //
  /*
  CHECK_OK(SaveVerticalDataset(train_dataset,
                               "tfrecord+tfe:/tmp/dataset.tfe"));
  CHECK_OK(file::SetTextProto("/tmp/dataspec.pbtxt",
                              train_dataset.data_spec(), file::Defaults()));
  CHECK_OK(file::SetTextProto("/tmp/train_config.pbtxt",
                              learner->training_config(),
                              file::Defaults()));
  */

#ifdef TFDF_STOP_TRAINING_ON_INTERRUPT
  // Let the user interrupt the training; disabled again right after it.
  OP_REQUIRES_OK(ctx, interruption::EnableUserInterruption());
  learner->set_stop_training_trigger(&interruption::stop_training);
#endif

  LOG(INFO) << "Train model";
  utils::StatusOr<std::unique_ptr<model::AbstractModel>> trained_model;
  if (validation_dataset == nullptr) {
    trained_model = learner->TrainWithStatus(train_dataset);
  } else {
    trained_model =
        learner->TrainWithStatus(train_dataset, *validation_dataset);
  }

#ifdef TFDF_STOP_TRAINING_ON_INTERRUPT
  OP_REQUIRES_OK(ctx, interruption::DisableUserInterruption());
#endif

  // The training status is only inspected once interruption handling has
  // been disabled again.
  OP_REQUIRES_OK(ctx, utils::FromUtilStatus(trained_model.status()));

  // Export the model to disk.
  if (has_model_dir) {
    LOG(INFO) << "Export model in log directory: " << model_dir_;
    const std::string model_path = tf::io::JoinPath(model_dir_, "model");
    OP_REQUIRES_OK(ctx, utils::FromUtilStatus(SaveModel(
                            model_path, trained_model.value().get())));
  }

  // Export the model to the model resource. This must come after the disk
  // export because the model is moved out of `trained_model` here.
  if (!model_id_.empty()) {
    LOG(INFO) << "Save model in resources";
    auto* container = new YggdrasilModelContainer();
    *container->mutable_model() = std::move(trained_model.value());
    OP_REQUIRES_OK(ctx, ctx->resource_manager()->Create(
                            kModelContainer, model_id_, container));
  }
}