explicit SimpleMLModelTrainer()

in tensorflow_decision_forests/tensorflow/ops/training/kernel.cc [659:747]


  // Builds the kernel from its op attributes.
  //
  // Reads every attribute of the op into the corresponding member, and
  // de-serializes the proto-valued attributes ("guide", "hparams",
  // "training_config" and "deployment_config").
  //
  // Any failure (missing attribute, proto that cannot be de-serialized,
  // invalid "task" value, or a forbidden field set in "training_config") is
  // reported through the OP_REQUIRES / OP_REQUIRES_OK macros, which register
  // the error on `ctx` and return from the constructor immediately. Members
  // located after the failing step are therefore left untouched.
  explicit SimpleMLModelTrainer(tf::OpKernelConstruction* ctx) : OpKernel(ctx) {
    // Attributes copied verbatim into members.
    OP_REQUIRES_OK(ctx, ctx->GetAttr("feature_ids", &feature_ids_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("label_id", &label_id_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("weight_id", &weight_id_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("model_dir", &model_dir_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("model_id", &model_id_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("learner", &learner_));

    // "guide" is received as a serialized proto.
    std::string serialized_guide;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("guide", &serialized_guide));
    if (!guide_.ParseFromString(serialized_guide)) {
      OP_REQUIRES_OK(ctx, tf::Status(tf::error::INVALID_ARGUMENT,
                                     "Cannot de-serialize guide proto."));
    }

    // "hparams" is received as a serialized proto.
    std::string serialized_hparams;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("hparams", &serialized_hparams));
    if (!hparams_.ParseFromString(serialized_hparams)) {
      OP_REQUIRES_OK(ctx, tf::Status(tf::error::INVALID_ARGUMENT,
                                     "Cannot de-serialize hparams proto."));
    }

    // "task" is received as an integer; it is validated against the Task enum
    // before being cast.
    int task_idx = 0;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("task", &task_idx));
    OP_REQUIRES(ctx, model::proto::Task_IsValid(task_idx),
                tf::Status(tf::error::INVALID_ARGUMENT, "Unknown task"));
    task_ = static_cast<model::proto::Task>(task_idx);

    {
      std::string serialized_training_config;
      OP_REQUIRES_OK(
          ctx, ctx->GetAttr("training_config", &serialized_training_config));
      if (!training_config_.MergeFromString(serialized_training_config)) {
        OP_REQUIRES_OK(
            ctx, tf::Status(tf::error::INVALID_ARGUMENT,
                            "Cannot de-serialize training_config proto."));
      }

      // The fields below must not be set in the user-provided training
      // config: as the error messages state, they are either provided through
      // dedicated Op parameters or generated automatically for this Op.
      if (training_config_.has_task()) {
        OP_REQUIRES_OK(
            ctx,
            tf::Status(tf::error::INVALID_ARGUMENT,
                       "The \"task\" should not be set in the training_config, "
                       "instead set it as the Op parameter \"task\"."));
      }
      if (training_config_.has_learner()) {
        OP_REQUIRES_OK(
            ctx,
            tf::Status(
                tf::error::INVALID_ARGUMENT,
                "The \"learner\" should not be set in the training_config, "
                "instead set it as the Op parameter \"learner\"."));
      }
      if (training_config_.has_label()) {
        OP_REQUIRES_OK(
            ctx, tf::Status(
                     tf::error::INVALID_ARGUMENT,
                     "The \"label\" should not be set in the training_config, "
                     "instead set it as the Op parameter \"label_id\"."));
      }
      if (training_config_.has_weight_definition()) {
        OP_REQUIRES_OK(ctx,
                       tf::Status(tf::error::INVALID_ARGUMENT,
                                  "The \"weight_definition\" should not be "
                                  "set in the training_config."));
      }
      if (training_config_.features_size() > 0) {
        OP_REQUIRES_OK(
            ctx,
            tf::Status(
                tf::error::INVALID_ARGUMENT,
                "The \"features\" should not be set in the training_config, "
                "for this Op they are generated automatically."));
      }
    }

    {
      std::string serialized_deployment_config;
      OP_REQUIRES_OK(ctx, ctx->GetAttr("deployment_config",
                                       &serialized_deployment_config));
      if (!deployment_config_.MergeFromString(serialized_deployment_config)) {
        OP_REQUIRES_OK(
            ctx, tf::Status(tf::error::INVALID_ARGUMENT,
                            "Cannot de-serialize deployment_config proto."));
      }
    }

    OP_REQUIRES_OK(
        ctx, ctx->GetAttr("has_validation_dataset", &has_validation_dataset_));
  }