in tensorflow_decision_forests/keras/wrapper/wrapper.cc [228:576]
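// Generates the Python source of the Keras wrapper classes: one
// `core.CoreModel` subclass per registered Yggdrasil Decision Forests
// learner, with the learner's generic hyper-parameters exposed as typed,
// documented constructor arguments.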
ydf::utils::StatusOr<std::string> GenKerasPythonWrapper() {
const auto prefix = "";
std::string imports = absl::Substitute(R"(
from $0tensorflow_decision_forests.keras import core
from $0yggdrasil_decision_forests.model import abstract_model_pb2 # pylint: disable=unused-import
from $0yggdrasil_decision_forests.learner import abstract_learner_pb2
)",
prefix);
std::string wrapper =
absl::Substitute(R"(r"""Wrapper around each learning algorithm.
This file is generated automatically by running the following commands:
bazel build -c opt //tensorflow_decision_forests/keras:wrappers
bazel-bin/tensorflow_decision_forests/keras/wrappers_wrapper_main\
> tensorflow_decision_forests/keras/wrappers_pre_generated.py
Please don't change this file directly. Instead, change the source. The
documentation source is contained in the "GetGenericHyperParameterSpecification"
method of each learner e.g. GetGenericHyperParameterSpecification in
learner/gradient_boosted_trees/gradient_boosted_trees.cc contains the
documentation (and meta-data) used to generate this file.
"""
from typing import Optional, List, Set
import tensorflow as tf
$0
TaskType = "abstract_model_pb2.Task" # pylint: disable=invalid-name
AdvancedArguments = core.AdvancedArguments
)",
imports);
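// Generate one Keras wrapper class per registered YDF learner.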
for (const auto& learner_key : ydf::model::AllRegisteredLearners()) {
const auto class_name = LearnerKeyToClassName(learner_key);
// Get a learner instance.
std::unique_ptr<ydf::model::AbstractLearner> learner;
ydf::model::proto::TrainingConfig train_config;
train_config.set_learner(learner_key);
train_config.set_label("my_label");
RETURN_IF_ERROR(GetLearner(train_config, &learner));
ASSIGN_OR_RETURN(const auto specifications,
learner->GetGenericHyperParameterSpecification());
// Python documentation.
std::string fields_documentation;
// Constructor arguments.
std::string fields_constructor;
// Use of the constructor arguments to populate the parameter dictionary.
std::string fields_dict;
// Sort the fields alphabetically.
std::vector<std::string> field_names;
field_names.reserve(specifications.fields_size());
for (const auto& field : specifications.fields()) {
field_names.push_back(field.first);
}
std::sort(field_names.begin(), field_names.end());
for (const auto& field_name : field_names) {
const auto& field_def = specifications.fields().find(field_name)->second;
if (field_def.documentation().deprecated()) {
// Deprecated fields are not exported.
continue;
}
// Constructor argument.
if (!fields_constructor.empty()) {
absl::StrAppend(&fields_constructor, ",\n");
}
// Type of the attribute.
std::string attr_py_type;
// Default value of the attribute.
std::string attr_py_default_value;
if (ydf::utils::HyperParameterIsBoolean(field_def)) {
// Boolean values are stored as categorical.
attr_py_type = "bool";
attr_py_default_value =
(field_def.categorical().default_value() == "true") ? "True"
: "False";
} else {
switch (field_def.Type_case()) {
case ydf::model::proto::GenericHyperParameterSpecification::Value::
kCategorical: {
attr_py_type = "str";
absl::SubstituteAndAppend(&attr_py_default_value, "\"$0\"",
field_def.categorical().default_value());
} break;
case ydf::model::proto::GenericHyperParameterSpecification::Value::
kInteger:
attr_py_type = "int";
absl::StrAppend(&attr_py_default_value,
field_def.integer().default_value());
break;
case ydf::model::proto::GenericHyperParameterSpecification::Value::
kReal:
attr_py_type = "float";
absl::StrAppend(&attr_py_default_value,
PythonFloat(field_def.real().default_value()));
break;
case ydf::model::proto::GenericHyperParameterSpecification::Value::
kCategoricalList:
attr_py_type = "List[str]";
attr_py_default_value = "None";
break;
case ydf::model::proto::GenericHyperParameterSpecification::Value::
TYPE_NOT_SET:
return absl::InvalidArgumentError(
absl::Substitute("Missing type for field $0", field_name));
}
}
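// For illustration (hypothetical field): an integer hyper-parameter
// "num_trees" with a default value of 300 maps to attr_py_type = "int" and
// attr_py_default_value = "300", which later yields the constructor
// argument:
//   num_trees: Optional[int] = 300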
// If the parameter is conditional on a parent parameter's value, and the
// default value of the parent parameter does not satisfy the condition,
// the default value is set to None.
if (field_def.has_conditional()) {
const auto& conditional = field_def.conditional();
const auto& parent_field =
specifications.fields().find(conditional.control_field());
if (parent_field == specifications.fields().end()) {
return absl::InvalidArgumentError(
absl::Substitute("Unknown conditional field $0 for field $1",
conditional.control_field(), field_name));
}
ASSIGN_OR_RETURN(const auto condition,
ydf::utils::SatisfyDefaultCondition(
parent_field->second, conditional));
if (!condition) {
attr_py_default_value = "None";
}
}
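// For example (hypothetical names): if "sparse_oblique_normalization" is
// only meaningful when its parent "split_axis" equals "SPARSE_OBLIQUE",
// and the parent's default is "AXIS_ALIGNED", the child's default is
// emitted as None.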
// Constructor argument.
absl::SubstituteAndAppend(&fields_constructor,
" $0: Optional[$1] = $2", field_name,
attr_py_type, attr_py_default_value);
// Assignment to the parameter dictionary.
absl::SubstituteAndAppend(
&fields_dict, "      \"$0\": $0,\n", field_name);
// Documentation
if (field_def.documentation().description().empty()) {
// Refer to the proto.
absl::SubstituteAndAppend(&fields_documentation, " $0: See $1\n",
field_name,
field_def.documentation().proto_path());
} else {
// Actual documentation.
absl::StrAppend(
&fields_documentation,
FormatDocumentation(
absl::StrCat(field_name, ": ",
field_def.documentation().description(),
" Default: ", attr_py_default_value, "."),
/*leading_spaces_first_line=*/4,
/*leading_spaces_next_lines=*/6));
}
}
// Pre-configured hyper-parameters.
std::string predefined_hp_doc;
std::string predefined_hp_list;
ASSIGN_OR_RETURN(std::tie(predefined_hp_doc, predefined_hp_list),
BuildPredefinedHyperParameter(learner.get()));
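// `predefined_hp_doc` documents each hyper-parameter template (e.g. a name
// such as "benchmark_rank1@v1"; the exact names depend on the learner), and
// `predefined_hp_list` is the Python expression returned by the generated
// `predefined_hyperparameters()` method below.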
const auto free_text_documentation =
FormatDocumentation(specifications.documentation().description(),
/*leading_spaces_first_line=*/2 - 2,
/*leading_spaces_next_lines=*/2);
const auto nice_learner_name = LearnerKeyToNiceLearnerName(learner_key);
absl::SubstituteAndAppend(
&wrapper, R"(
class $0(core.CoreModel):
r"""$6 learning algorithm.
$5
Usage example:
```python
import tensorflow_decision_forests as tfdf
import pandas as pd
dataset = pd.read_csv("project/dataset.csv")
tf_dataset = tfdf.keras.pd_dataframe_to_tf_dataset(dataset, label="my_label")
model = tfdf.keras.$0()
model.fit(tf_dataset)
print(model.summary())
```
Attributes:
task: Task to solve (e.g. Task.CLASSIFICATION, Task.REGRESSION,
Task.RANKING, Task.CATEGORICAL_UPLIFT).
features: Specify the list and semantics of the input features of the model.
If not specified, all the available features will be used. If specified
and if `exclude_non_specified_features=True`, only the features in
`features` will be used by the model. If "preprocessing" is used,
`features` corresponds to the output of the preprocessing. In this case,
it is recommended for the preprocessing to return a dictionary of tensors.
exclude_non_specified_features: If true, only use the features specified in
`features`.
preprocessing: Functional keras model or @tf.function to apply on the input
features before training the model. This preprocessing model can consume
and return tensors, lists of tensors or dictionaries of tensors. If
specified, the model only "sees" the output of the preprocessing (and not
the raw input). Can be used to prepare the features or to stack multiple
models on top of each other. Unlike preprocessing done in the tf.dataset,
the operations in "preprocessing" are serialized with the model.
postprocessing: Like "preprocessing" but applied on the model output.
ranking_group: Only for `task=Task.RANKING`. Name of a tf.string feature that
identifies queries in a query/document ranking task. The ranking group
is not automatically added to the set of features if
`exclude_non_specified_features=false`.
uplift_treatment: Only for task=Task.CATEGORICAL_UPLIFT. Name of an integer
feature that identifies the treatment in an uplift problem. The value 0 is
reserved for the control treatment.
temp_directory: Temporary directory used to store the model Assets after the
training, and possibly as a work directory during the training. This
temporary directory is necessary for the model to be exported after
training e.g. `model.save(path)`. If not specified, `temp_directory` is
set to a temporary directory using `tempfile.TemporaryDirectory`. This
directory is deleted when the model python object is garbage-collected.
verbose: Verbosity mode. 0 = silent, 1 = small details, 2 = full details.
hyperparameter_template: Override the default value of the hyper-parameters.
If None (default), the default parameters of the library are used. If set,
`hyperparameter_template` refers to one of the following
preconfigured hyper-parameter sets. These sets outperform the default
hyper-parameters (either generally or in specific scenarios).
You can omit the version (e.g. remove "@v5") to use the latest version of
the template. In this case, the hyper-parameters can change between
releases (not recommended for training in production).
$7
advanced_arguments: Advanced control of the model that most users won't need
to use. See `AdvancedArguments` for details.
num_threads: Number of threads used to train the model. Different learning
algorithms use multi-threading differently and with different degrees of
efficiency. If `None`, `num_threads` will be automatically set to the
number of processors (up to a maximum of 32; or set to 6 if the number of
processors is not available).
Making `num_threads` significantly larger than the number of processors
can slow down the training. The default value logic might change in
the future.
name: The name of the model.
max_vocab_count: Default maximum size of the vocabulary for CATEGORICAL and
CATEGORICAL_SET features stored as strings. If more unique values exist,
only the most frequent values are kept, and the remaining values are
considered as out-of-vocabulary. The value `max_vocab_count` defined in a
`FeatureUsage` (if any) takes precedence.
try_resume_training: If true, the model training resumes from the checkpoint
stored in the `temp_directory` directory. If `temp_directory` does not
contain any model checkpoint, the training starts from the beginning.
Resuming training is useful in the following situations: (1) The training
was interrupted by the user (e.g. ctrl+c or "stop" button in a
notebook), (2) the training job was interrupted (e.g. rescheduling), and
(3) the hyper-parameters of the model were changed such that an initially
completed training is now incomplete (e.g. increasing the number of
trees).
Note: Training can only be resumed if the training dataset is exactly the
same (i.e. no reshuffle in the tf.data.Dataset).
check_dataset: If set to true, test if the dataset is well configured for
the training: (1) check that the dataset does not contain any `repeat`
operation, (2) check that the dataset contains a `batch` operation,
(3) check that the dataset has a large enough batch size (min. 100 if the
dataset contains more than 1k examples, or if the number of examples is
not available). If set to false, do not run any test.
$2
"""
@core._list_explicit_arguments
def __init__(self,
task: Optional[TaskType] = core.Task.CLASSIFICATION,
features: Optional[List[core.FeatureUsage]] = None,
exclude_non_specified_features: Optional[bool] = False,
preprocessing: Optional["tf.keras.models.Functional"] = None,
postprocessing: Optional["tf.keras.models.Functional"] = None,
ranking_group: Optional[str] = None,
uplift_treatment: Optional[str] = None,
temp_directory: Optional[str] = None,
verbose: int = 1,
hyperparameter_template: Optional[str] = None,
advanced_arguments: Optional[AdvancedArguments] = None,
num_threads: Optional[int] = None,
name: Optional[str] = None,
max_vocab_count: Optional[int] = 2000,
try_resume_training: Optional[bool] = True,
check_dataset: Optional[bool] = True,
$3,
explicit_args: Optional[Set[str]] = None):
learner_params = {
$4
}
if hyperparameter_template is not None:
learner_params = core._apply_hp_template(learner_params,
hyperparameter_template, self.predefined_hyperparameters(),
explicit_args)
super($0, self).__init__(task=task,
learner="$1",
learner_params=learner_params,
features=features,
exclude_non_specified_features=exclude_non_specified_features,
preprocessing=preprocessing,
postprocessing=postprocessing,
ranking_group=ranking_group,
uplift_treatment=uplift_treatment,
temp_directory=temp_directory,
verbose=verbose,
advanced_arguments=advanced_arguments,
num_threads=num_threads,
name=name,
max_vocab_count=max_vocab_count,
try_resume_training=try_resume_training,
check_dataset=check_dataset)
@staticmethod
def predefined_hyperparameters() -> List[core.HyperParameterTemplate]:
return $8
@staticmethod
def capabilities() -> abstract_learner_pb2.LearnerCapabilities:
return abstract_learner_pb2.LearnerCapabilities(
support_partial_cache_dataset_format=$9)
)",
/*$0*/ class_name, /*$1*/ learner_key,
/*$2*/ fields_documentation,
/*$3*/ fields_constructor, /*$4*/ fields_dict,
/*$5*/ free_text_documentation,
/*$6*/ nice_learner_name,
/*$7*/ FormatDocumentation(predefined_hp_doc, 6, 6),
/*$8*/ predefined_hp_list,
/*$9*/ learner->Capabilities().support_partial_cache_dataset_format()
? "True"
: "False");
}
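// As an illustration, a learner key such as "RANDOM_FOREST" would typically
// produce a `RandomForestModel` class trained with learner="RANDOM_FOREST"
// (the exact class name is decided by LearnerKeyToClassName above).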
// TODO(gbm): Text serialize the proto
return wrapper;
}
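// Minimal usage sketch (hypothetical driver; the real entry point is the
// "wrapper_main" binary mentioned in the generated file header):
//
//   #include <iostream>
//
//   int main() {
//     const auto wrapper = GenKerasPythonWrapper();
//     if (!wrapper.ok()) {
//       std::cerr << wrapper.status().message() << std::endl;
//       return 1;
//     }
//     std::cout << wrapper.value();
//     return 0;
//   }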