in tensorflow_decision_forests/tensorflow/ops/training/kernel.cc [659:747]
// Reads the Op attributes and validates them.
//
// Attributes read: "feature_ids", "label_id", "weight_id", "model_dir",
// "model_id", "learner", "has_validation_dataset", plus the serialized protos
// "guide", "hparams", "training_config" and "deployment_config", and "task"
// (an integer that must be a valid model::proto::Task value).
//
// Any failure records an INVALID_ARGUMENT (or the GetAttr error) on `ctx` and
// returns early, which fails kernel construction. Failure cases are:
//   - a serialized proto cannot be parsed;
//   - "task" is not a valid model::proto::Task;
//   - "training_config" sets task, learner, label, weight_definition or
//     features.
explicit SimpleMLModelTrainer(tf::OpKernelConstruction* ctx) : OpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("feature_ids", &feature_ids_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("label_id", &label_id_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("weight_id", &weight_id_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("model_dir", &model_dir_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("model_id", &model_id_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("learner", &learner_));

  std::string serialized_guide;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("guide", &serialized_guide));
  OP_REQUIRES(ctx, guide_.ParseFromString(serialized_guide),
              tf::Status(tf::error::INVALID_ARGUMENT,
                         "Cannot de-serialize guide proto."));

  std::string hparams;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("hparams", &hparams));
  OP_REQUIRES(ctx, hparams_.ParseFromString(hparams),
              tf::Status(tf::error::INVALID_ARGUMENT,
                         "Cannot de-serialize hparams proto."));

  // The task arrives as a plain integer attribute; reject values that do not
  // map to a model::proto::Task enumerator before casting.
  int task_idx = 0;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("task", &task_idx));
  OP_REQUIRES(ctx, model::proto::Task_IsValid(task_idx),
              tf::Status(tf::error::INVALID_ARGUMENT, "Unknown task"));
  task_ = static_cast<model::proto::Task>(task_idx);

  {
    std::string serialized_training_config;
    OP_REQUIRES_OK(
        ctx, ctx->GetAttr("training_config", &serialized_training_config));
    // NOTE(review): MergeFromString merges into the current contents of
    // training_config_ (presumably still default-constructed here), unlike
    // the ParseFromString calls above.
    OP_REQUIRES(ctx,
                training_config_.MergeFromString(serialized_training_config),
                tf::Status(tf::error::INVALID_ARGUMENT,
                           "Cannot de-serialize training_config proto."));

    // Reject fields the Op supplies itself. Per the messages below:
    //   - task, learner and label come from dedicated Op parameters;
    //   - features are generated automatically.
    // weight_definition is presumably derived from "weight_id" (confirm).
    OP_REQUIRES(
        ctx, !training_config_.has_task(),
        tf::Status(tf::error::INVALID_ARGUMENT,
                   "The \"task\" should not be set in the training_config, "
                   "instead set it as the Op parameter \"task\"."));
    OP_REQUIRES(
        ctx, !training_config_.has_learner(),
        tf::Status(tf::error::INVALID_ARGUMENT,
                   "The \"learner\" should not be set in the training_config, "
                   "instead set it as the Op parameter \"learner\"."));
    OP_REQUIRES(
        ctx, !training_config_.has_label(),
        tf::Status(tf::error::INVALID_ARGUMENT,
                   "The \"label\" should not be set in the training_config, "
                   "instead set it as the Op parameter \"label_id\"."));
    OP_REQUIRES(ctx, !training_config_.has_weight_definition(),
                tf::Status(tf::error::INVALID_ARGUMENT,
                           "The \"weight_definition\" should not be "
                           "set in the training_config."));
    OP_REQUIRES(
        ctx, training_config_.features_size() == 0,
        tf::Status(tf::error::INVALID_ARGUMENT,
                   "The \"features\" should not be set in the training_config, "
                   "for this Op they are generated automatically."));
  }

  {
    std::string serialized_deployment_config;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("deployment_config",
                                     &serialized_deployment_config));
    OP_REQUIRES(
        ctx, deployment_config_.MergeFromString(serialized_deployment_config),
        tf::Status(tf::error::INVALID_ARGUMENT,
                   "Cannot de-serialize deployment_config proto."));
  }

  OP_REQUIRES_OK(
      ctx, ctx->GetAttr("has_validation_dataset", &has_validation_dataset_));
}