ydf::utils::StatusOr<std::string> GenKerasPythonWrapper()

in tensorflow_decision_forests/keras/wrapper/wrapper.cc [228:576]


// Generates the Python source code of the Keras wrapper module.
//
// The output starts with a module docstring and the Python imports. It then
// contains one `core.CoreModel` subclass per learner returned by
// `ydf::model::AllRegisteredLearners()`. For each learner, the following are
// derived from the learner's generic hyper-parameter specification:
//   - the hyper-parameter constructor arguments;
//   - the `learner_params` dictionary;
//   - the "Attributes" docstring section.
// Deprecated fields are skipped.
//
// Returns an error if a learner cannot be instantiated, if a hyper-parameter
// field has no type, if a field is conditional on an unknown control field, or
// if any of the helper calls below (specification retrieval, default-condition
// evaluation, predefined hyper-parameter building) fails.
ydf::utils::StatusOr<std::string> GenKerasPythonWrapper() {
  const auto prefix = "";

  std::string imports = absl::Substitute(R"(
from $0tensorflow_decision_forests.keras import core
from $0yggdrasil_decision_forests.model import abstract_model_pb2  # pylint: disable=unused-import
from $0yggdrasil_decision_forests.learner import abstract_learner_pb2
)",
                                         prefix);

  std::string wrapper =
      absl::Substitute(R"(r"""Wrapper around each learning algorithm.

This file is generated automatically by running the following commands:
  bazel build -c opt //tensorflow_decision_forests/keras:wrappers
  bazel-bin/tensorflow_decision_forests/keras/wrappers_wrapper_main\
    > tensorflow_decision_forests/keras/wrappers_pre_generated.py

Please don't change this file directly. Instead, changes the source. The
documentation source is contained in the "GetGenericHyperParameterSpecification"
method of each learner e.g. GetGenericHyperParameterSpecification in
learner/gradient_boosted_trees/gradient_boosted_trees.cc contains the
documentation (and meta-data) used to generate this file.
"""

from typing import Optional, List, Set
import tensorflow as tf
$0
TaskType = "abstract_model_pb2.Task"  # pylint: disable=invalid-name
AdvancedArguments = core.AdvancedArguments

)",
                       imports);

  for (const auto& learner_key : ydf::model::AllRegisteredLearners()) {
    const auto class_name = LearnerKeyToClassName(learner_key);

    // Get a learner instance.
    std::unique_ptr<ydf::model::AbstractLearner> learner;
    ydf::model::proto::TrainingConfig train_config;
    train_config.set_learner(learner_key);
    train_config.set_label("my_label");
    RETURN_IF_ERROR(GetLearner(train_config, &learner));
    ASSIGN_OR_RETURN(const auto specifications,
                     learner->GetGenericHyperParameterSpecification());

    // Python documentation of the hyper-parameters ("Attributes" section).
    std::string fields_documentation;
    // Hyper-parameter constructor arguments. Every entry, including the last
    // one, is terminated by ",\n". An empty list (e.g. all the fields are
    // deprecated) therefore still yields a syntactically valid Python
    // signature.
    std::string fields_constructor;
    // Entries of the `learner_params` dictionary, one per constructor argument.
    std::string fields_dict;

    // Sort the fields alphabetically so the generated code does not depend on
    // the iteration order of the specification's field map.
    std::vector<std::string> field_names;
    field_names.reserve(specifications.fields_size());
    for (const auto& field : specifications.fields()) {
      field_names.push_back(field.first);
    }
    std::sort(field_names.begin(), field_names.end());

    for (const auto& field_name : field_names) {
      const auto& field_def = specifications.fields().find(field_name)->second;

      if (field_def.documentation().deprecated()) {
        // Deprecated fields are not exported.
        continue;
      }

      // Python type of the attribute.
      std::string attr_py_type;
      // Python literal of the attribute's default value.
      std::string attr_py_default_value;

      if (ydf::utils::HyperParameterIsBoolean(field_def)) {
        // Boolean values are stored as categorical.
        attr_py_type = "bool";
        attr_py_default_value =
            (field_def.categorical().default_value() == "true") ? "True"
                                                                : "False";
      } else {
        switch (field_def.Type_case()) {
          case ydf::model::proto::GenericHyperParameterSpecification::Value::
              kCategorical: {
            attr_py_type = "str";
            absl::SubstituteAndAppend(&attr_py_default_value, "\"$0\"",
                                      field_def.categorical().default_value());
          } break;
          case ydf::model::proto::GenericHyperParameterSpecification::Value::
              kInteger:
            attr_py_type = "int";
            absl::StrAppend(&attr_py_default_value,
                            field_def.integer().default_value());
            break;
          case ydf::model::proto::GenericHyperParameterSpecification::Value::
              kReal:
            attr_py_type = "float";
            absl::StrAppend(&attr_py_default_value,
                            PythonFloat(field_def.real().default_value()));
            break;
          case ydf::model::proto::GenericHyperParameterSpecification::Value::
              kCategoricalList:
            attr_py_type = "List[str]";
            attr_py_default_value = "None";
            break;
          case ydf::model::proto::GenericHyperParameterSpecification::Value::
              TYPE_NOT_SET:
            return absl::InvalidArgumentError(
                absl::Substitute("Missing type for field $0", field_name));
        }
      }

      // If the parameter is conditional on a parent parameter's value, and the
      // default value of the parent parameter does not satisfy the condition,
      // the default value is set to None.
      if (field_def.has_conditional()) {
        const auto& conditional = field_def.conditional();
        const auto& parent_field =
            specifications.fields().find(conditional.control_field());
        if (parent_field == specifications.fields().end()) {
          return absl::InvalidArgumentError(
              absl::Substitute("Unknown conditional field $0 for field $1",
                               conditional.control_field(), field_name));
        }
        ASSIGN_OR_RETURN(const auto condition,
                         ydf::utils::SatisfyDefaultCondition(
                             parent_field->second, conditional));
        if (!condition) {
          attr_py_default_value = "None";
        }
      }

      // Constructor argument (terminated by ",\n", see `fields_constructor`).
      absl::SubstituteAndAppend(&fields_constructor,
                                "      $0: Optional[$1] = $2,\n", field_name,
                                attr_py_type, attr_py_default_value);

      // Assignment to the parameter dictionary.
      absl::SubstituteAndAppend(
          &fields_dict, "                      \"$0\" : $0,\n", field_name);

      // Documentation
      if (field_def.documentation().description().empty()) {
        // Refer to the proto.
        absl::SubstituteAndAppend(&fields_documentation, "    $0: See $1\n",
                                  field_name,
                                  field_def.documentation().proto_path());
      } else {
        // Actual documentation.
        absl::StrAppend(
            &fields_documentation,
            FormatDocumentation(
                absl::StrCat(field_name, ": ",
                             field_def.documentation().description(),
                             " Default: ", attr_py_default_value, "."),
                /*leading_spaces_first_line=*/4,
                /*leading_spaces_next_lines=*/6));
      }
    }

    // Pre-configured hyper-parameters.
    std::string predefined_hp_doc;
    std::string predefined_hp_list;
    ASSIGN_OR_RETURN(std::tie(predefined_hp_doc, predefined_hp_list),
                     BuildPredefinedHyperParameter(learner.get()));

    // The class docstring template already indents the first line of the
    // free-text documentation by two spaces ("  $5" below), hence 0 here.
    const auto free_text_documentation =
        FormatDocumentation(specifications.documentation().description(),
                            /*leading_spaces_first_line=*/0,
                            /*leading_spaces_next_lines=*/2);

    const auto nice_learner_name = LearnerKeyToNiceLearnerName(learner_key);

    // Note: "$3" is placed directly before `explicit_args` (without a trailing
    // comma) because every entry of `fields_constructor` already ends with
    // ",\n".
    absl::SubstituteAndAppend(
        &wrapper, R"(
class $0(core.CoreModel):
  r"""$6 learning algorithm.

  $5
  Usage example:

  ```python
  import tensorflow_decision_forests as tfdf
  import pandas as pd

  dataset = pd.read_csv("project/dataset.csv")
  tf_dataset = tfdf.keras.pd_dataframe_to_tf_dataset(dataset, label="my_label")

  model = tfdf.keras.$0()
  model.fit(tf_dataset)

  print(model.summary())
  ```

  Attributes:
    task: Task to solve (e.g. Task.CLASSIFICATION, Task.REGRESSION,
      Task.RANKING, Task.CATEGORICAL_UPLIFT).
    features: Specify the list and semantic of the input features of the model.
      If not specified, all the available features will be used. If specified
      and if `exclude_non_specified_features=True`, only the features in
      `features` will be used by the model. If "preprocessing" is used,
      `features` corresponds to the output of the preprocessing. In this case,
      it is recommended for the preprocessing to return a dictionary of tensors.
    exclude_non_specified_features: If true, only use the features specified in
      `features`.
    preprocessing: Functional keras model or @tf.function to apply on the input
      feature before the model to train. This preprocessing model can consume
      and return tensors, list of tensors or dictionary of tensors. If
      specified, the model only "sees" the output of the preprocessing (and not
      the raw input). Can be used to prepare the features or to stack multiple
      models on top of each other. Unlike preprocessing done in the tf.dataset,
      the operation in "preprocessing" are serialized with the model.
    postprocessing: Like "preprocessing" but applied on the model output.
    ranking_group: Only for `task=Task.RANKING`. Name of a tf.string feature that
      identifies queries in a query/document ranking task. The ranking group
      is not added automatically for the set of features if
      `exclude_non_specified_features=false`.
    uplift_treatment: Only for task=Task.CATEGORICAL_UPLIFT. Name of an integer
      feature that identifies the treatment in an uplift problem. The value 0 is
      reserved for the control treatment.
    temp_directory: Temporary directory used to store the model Assets after the
      training, and possibly as a work directory during the training. This
      temporary directory is necessary for the model to be exported after
      training e.g. `model.save(path)`. If not specified, `temp_directory` is
      set to a temporary directory using `tempfile.TemporaryDirectory`. This
      directory is deleted when the model python object is garbage-collected.
    verbose: Verbosity mode. 0 = silent, 1 = small details, 2 = full details.
    hyperparameter_template: Override the default value of the hyper-parameters.
      If None (default) the default parameters of the library are used. If set,
      `default_hyperparameter_template` refers to one of the following
      preconfigured hyper-parameter sets. Those sets outperforms the default
      hyper-parameters (either generally or in specific scenarios).
      You can omit the version (e.g. remove "@v5") to use the last version of
      the template. In this case, the hyper-parameter can change in between
      releases (not recommended for training in production).
$7
    advanced_arguments: Advanced control of the model that most users won't need
      to use. See `AdvancedArguments` for details.
    num_threads: Number of threads used to train the model. Different learning
      algorithms use multi-threading differently and with different degree of
      efficiency. If `None`, `num_threads` will be automatically set to the
      number of processors (up to a maximum of 32; or set to 6 if the number of
      processors is not available).
      Making `num_threads` significantly larger than the number of processors
      can slow-down the training speed. The default value logic might change in
      the future.
    name: The name of the model.
    max_vocab_count: Default maximum size of the vocabulary for CATEGORICAL and
      CATEGORICAL_SET features stored as strings. If more unique values exist,
      only the most frequent values are kept, and the remaining values are
      considered as out-of-vocabulary. The value `max_vocab_count` defined in a
      `FeatureUsage` (if any) takes precedence.
    try_resume_training: If true, the model training resumes from the checkpoint
      stored in the `temp_directory` directory. If `temp_directory` does not
      contain any model checkpoint, the training start from the beginning.
      Resuming training is useful in the following situations: (1) The training
        was interrupted by the user (e.g. ctrl+c or "stop" button in a
        notebook). (2) the training job was interrupted (e.g. rescheduling), ond
        (3) the hyper-parameter of the model were changed such that an initially
        completed training is now incomplete (e.g. increasing the number of
        trees).
      Note: Training can only be resumed if the training datasets is exactly the
        same (i.e. no reshuffle in the tf.data.Dataset).
    check_dataset: If set to true, test if the dataset is well configured for
      the training: (1) Check if the dataset does contains any `repeat`
        operations, (2) Check if the dataset does contain a `batch` operation,
        (3) Check if the dataset has a large enough batch size (min 100 if the
        dataset contains more than 1k examples or if the number of examples is
        not available) If set to false, do not run any test.
$2
  """

  @core._list_explicit_arguments
  def __init__(self,
      task: Optional[TaskType] = core.Task.CLASSIFICATION,
      features: Optional[List[core.FeatureUsage]] = None,
      exclude_non_specified_features: Optional[bool] = False,
      preprocessing: Optional["tf.keras.models.Functional"] = None,
      postprocessing: Optional["tf.keras.models.Functional"] = None,
      ranking_group: Optional[str] = None,
      uplift_treatment: Optional[str] = None,
      temp_directory: Optional[str] = None,
      verbose: int = 1,
      hyperparameter_template: Optional[str] = None,
      advanced_arguments: Optional[AdvancedArguments] = None,
      num_threads: Optional[int] = None,
      name: Optional[str] = None,
      max_vocab_count : Optional[int] = 2000,
      try_resume_training: Optional[bool] = True,
      check_dataset: Optional[bool] = True,
$3      explicit_args: Optional[Set[str]] = None):

    learner_params = {
$4
      }

    if hyperparameter_template is not None:
      learner_params = core._apply_hp_template(learner_params,
        hyperparameter_template, self.predefined_hyperparameters(),
        explicit_args)

    super($0, self).__init__(task=task,
      learner="$1",
      learner_params=learner_params,
      features=features,
      exclude_non_specified_features=exclude_non_specified_features,
      preprocessing=preprocessing,
      postprocessing=postprocessing,
      ranking_group=ranking_group,
      uplift_treatment=uplift_treatment,
      temp_directory=temp_directory,
      verbose=verbose,
      advanced_arguments=advanced_arguments,
      num_threads=num_threads,
      name=name,
      max_vocab_count=max_vocab_count,
      try_resume_training=try_resume_training,
      check_dataset=check_dataset)

  @staticmethod
  def predefined_hyperparameters() -> List[core.HyperParameterTemplate]:
    return $8

  @staticmethod
  def capabilities() -> abstract_learner_pb2.LearnerCapabilities:
    return abstract_learner_pb2.LearnerCapabilities(
      support_partial_cache_dataset_format=$9)
)",
        /*$0*/ class_name, /*$1*/ learner_key,
        /*$2*/ fields_documentation,
        /*$3*/ fields_constructor, /*$4*/ fields_dict,
        /*$5*/ free_text_documentation,
        /*$6*/ nice_learner_name,
        /*$7*/ FormatDocumentation(predefined_hp_doc, 6, 6),
        /*$8*/ predefined_hp_list,
        /*$9*/ learner->Capabilities().support_partial_cache_dataset_format()
            ? "True"
            : "False");
  }

  // TODO(gbm): Text serialize the proto

  return wrapper;
}