void Compute()

in tensorflow_decision_forests/tensorflow/ops/training/kernel.cc [751:908]


  // Trains a Yggdrasil model from the examples gathered for this op.
  //
  // Output 0 is a boolean scalar. It is set to false only when the request is
  // skipped because HasTrainingExamples() reports that nothing is available;
  // otherwise it is left at true.
  //
  // When "has_validation_dataset_" is set, a validation dataset is built as
  // well and must use the same label, weight and input features as the
  // training dataset. The trained model is saved under "model_dir_" and/or
  // registered in the resource manager under "model_id_" when those members
  // are non-empty. Any failing step is reported through "ctx" and ends the
  // call immediately.
  void Compute(tf::OpKernelContext* ctx) override {
    LOG(INFO) << "Start Yggdrasil model training";
    LOG(INFO) << "Collect training examples";

    // Scalar boolean output telling the caller whether training was skipped.
    tf::Tensor* success_output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, tf::TensorShape({}),
                                             &success_output));
    auto trained = success_output->scalar<bool>();
    const bool has_examples = HasTrainingExamples(ctx);
    trained() = has_examples;
    if (!has_examples) {
      LOG(INFO) << "Not training example available. Ignore training request.";
      return;
    }

    // Assemble the training dataset and retrieve the column roles.
    dataset::VerticalDataset train_dataset;
    std::string label;
    std::string weight;
    std::vector<std::string> features;
    OP_REQUIRES_OK(ctx, CreateTrainingDatasetFromFeatures(
                            ctx, DatasetType::kTraining, &train_dataset,
                            &label, &weight, &features));

    LOG(INFO) << "Training dataset:\n"
              << dataset::PrintHumanReadable(train_dataset.data_spec(), false);

    // Optional validation dataset. Stays null when no validation is requested.
    std::unique_ptr<dataset::VerticalDataset> valid_dataset;
    if (has_validation_dataset_) {
      LOG(INFO) << "Collect validation dataset";
      valid_dataset = absl::make_unique<dataset::VerticalDataset>();

      std::string valid_label;
      std::string valid_weight;
      std::vector<std::string> valid_features;
      OP_REQUIRES_OK(ctx, CreateTrainingDatasetFromFeatures(
                              ctx, DatasetType::kValidation,
                              valid_dataset.get(), &valid_label, &valid_weight,
                              &valid_features));

      LOG(INFO) << "Validation dataset:\n"
                << dataset::PrintHumanReadable(valid_dataset->data_spec(),
                                               false);

      // Both datasets must agree on the label, the weights and the features.
      const bool same_label = (valid_label == label);
      if (!same_label) {
        const std::string message = absl::StrCat(
            "Different label in the training and "
            "validation dataset: \"",
            label, "\" vs \"", valid_label, "\"");
        OP_REQUIRES_OK(ctx,
                       tf::Status(tf::error::INVALID_ARGUMENT, message));
      }

      const bool same_weight = (valid_weight == weight);
      if (!same_weight) {
        OP_REQUIRES_OK(
            ctx,
            tf::Status(
                tf::error::INVALID_ARGUMENT,
                "Different weights in the training and validation dataset."));
      }

      const bool same_features = (valid_features == features);
      if (!same_features) {
        OP_REQUIRES_OK(
            ctx,
            tf::Status(
                tf::error::INVALID_ARGUMENT,
                "Different features in the training and validation dataset."));
      }
    }

    LOG(INFO) << "Configure learner";
    // Start from the configuration stored on the op and fill in the fields
    // that depend on the collected dataset.
    model::proto::TrainingConfig config = training_config_;
    config.set_learner(learner_);
    config.set_label(label);
    config.set_task(task_);
    if (!weight.empty()) {
      LOG(INFO) << "Use example weight: " << weight
                << " from accumulator: " << weight_id_;
      auto* weight_definition = config.mutable_weight_definition();
      weight_definition->set_attribute(weight);
      // Touching the "numerical" sub-message marks the weight as numerical.
      weight_definition->mutable_numerical();
    }
    for (const auto& feature_name : features) {
      config.add_features(
          dataset::EscapeTrainingConfigFeatureName(feature_name));
    }

    // Instantiate the learner, then apply hyper-parameters and deployment.
    std::unique_ptr<model::AbstractLearner> learner;
    OP_REQUIRES_OK(ctx, utils::FromUtilStatus(GetLearner(config, &learner)));
    OP_REQUIRES_OK(
        ctx, utils::FromUtilStatus(learner->SetHyperParameters(hparams_)));
    *learner->mutable_deployment() = deployment_config_;
    if (!model_dir_.empty()) {
      learner->set_log_directory(tf::io::JoinPath(model_dir_, "train_logs"));
    }

    LOG(INFO) << "Training config:\n"
              << learner->training_config().DebugString();

    LOG(INFO) << "Deployment config:\n" << learner->deployment().DebugString();

    // Debugging aid (disabled): dump the dataset and the training
    // configuration so the same training can be replayed from the command
    // line, e.g. under a debugger:
    //
    // bazel run -c opt //third_party/yggdrasil_decision_forests/cli:train -- \
    //   --alsologtostderr --output=/tmp/model \
    //   --dataset=tfrecord+tfe:/tmp/dataset.tfe \
    //   --dataspec=/tmp/dataspec.pbtxt \
    //   --config=/tmp/train_config.pbtxt
    //
    // Requires the extra dependency:
    //   //third_party/yggdrasil_decision_forests/dataset:tf_example_io_tfrecord
    //
    /*
    CHECK_OK(SaveVerticalDataset(train_dataset,
                                 "tfrecord+tfe:/tmp/dataset.tfe"));
    CHECK_OK(file::SetTextProto("/tmp/dataspec.pbtxt",
                                train_dataset.data_spec(), file::Defaults()));
    CHECK_OK(file::SetTextProto("/tmp/train_config.pbtxt",
                                learner->training_config(),
                                file::Defaults()));
    */

#ifdef TFDF_STOP_TRAINING_ON_INTERRUPT
    // Let the learner observe the interruption flag while training runs.
    OP_REQUIRES_OK(ctx, interruption::EnableUserInterruption());
    learner->set_stop_training_trigger(&interruption::stop_training);
#endif

    LOG(INFO) << "Train model";
    // Pass the validation dataset to the learner only when one was collected.
    utils::StatusOr<std::unique_ptr<model::AbstractModel>> model =
        (valid_dataset != nullptr)
            ? learner->TrainWithStatus(train_dataset, *valid_dataset)
            : learner->TrainWithStatus(train_dataset);

#ifdef TFDF_STOP_TRAINING_ON_INTERRUPT
    // Interruption handling is turned off before the training status is
    // examined, so it is disabled whether or not training succeeded.
    OP_REQUIRES_OK(ctx, interruption::DisableUserInterruption());
#endif

    OP_REQUIRES_OK(ctx, utils::FromUtilStatus(model.status()));

    // Write the model to disk when a model directory is configured.
    if (!model_dir_.empty()) {
      LOG(INFO) << "Export model in log directory: " << model_dir_;
      const auto model_path = tf::io::JoinPath(model_dir_, "model");
      OP_REQUIRES_OK(ctx, utils::FromUtilStatus(
                              SaveModel(model_path, model.value().get())));
    }

    // Register the model as a resource when a model id is configured. This is
    // done last because it moves the model out of "model".
    if (!model_id_.empty()) {
      LOG(INFO) << "Save model in resources";
      // NOTE(review): the container is presumably owned by the resource
      // manager once handed to Create() -- confirm against ResourceMgr docs.
      auto* container = new YggdrasilModelContainer();
      *container->mutable_model() = std::move(model.value());
      OP_REQUIRES_OK(ctx, ctx->resource_manager()->Create(
                              kModelContainer, model_id_, container));
    }
  }