def fit_on_dataset_path()

in tensorflow_decision_forests/keras/core.py [0:0]


  def fit_on_dataset_path(
      self,
      train_path: str,
      label_key: str,
      weight_key: Optional[str] = None,
      ranking_key: Optional[str] = None,
      valid_path: Optional[str] = None,
      dataset_format: Optional[str] = "csv",
      max_num_scanned_rows_to_accumulate_statistics: Optional[int] = 100_000,
      try_resume_training: Optional[bool] = True,
      input_model_signature_fn: Optional[tf_core.InputModelSignatureFn] = (
          tf_core.build_default_input_model_signature)):
    """Trains the model on a dataset stored on disk.

    This solution is generally more efficient and easier than loading the
    dataset with a tf.Dataset both for local and distributed training.

    Usage example:

      # Local training
      model = keras.GradientBoostedTreesModel()
      model.fit_on_dataset_path(
        train_path="/path/to/dataset.csv",
        label_key="label",
        dataset_format="csv")
      model.save("/model/path")

      # Distributed training
      with tf.distribute.experimental.ParameterServerStrategy(...).scope():
        model = keras.DistributedGradientBoostedTreesModel()
      model.fit_on_dataset_path(
        train_path="/path/to/dataset@10",
        label_key="label",
        dataset_format="tfrecord+tfe")
      model.save("/model/path")

    Args:
      train_path: Path to the training dataset. Supports comma separated files,
        shard and glob notation.
      label_key: Name of the label column.
      weight_key: Name of the weighting column.
      ranking_key: Name of the ranking column.
      valid_path: Path to the validation dataset. If not provided, or if the
        learning algorithm does not support/need a validation dataset,
        `valid_path` is ignored.
      dataset_format: Format of the dataset. Should be one of the registered
        dataset formats (see
        https://github.com/google/yggdrasil-decision-forests/blob/main/documentation/user_manual.md#dataset-path-and-format
        for more details). The format "csv" is always available but it is
        generally only suited for small datasets. Must be a string: the value
        is used as the prefix of the typed dataset paths
        ("<dataset_format>:<path>").
      max_num_scanned_rows_to_accumulate_statistics: Maximum number of examples
        to scan to determine the statistics of the features (i.e. the dataspec,
        e.g. mean value, dictionaries). (Currently) the "first" examples of the
        dataset are scanned (e.g. the first examples of the dataset if it is a
        single file). Therefore, it is important that the sampled dataset is
        relatively uniformly sampled, notably the scanned examples should
        contain all the possible categorical values (otherwise the unseen
        values will be treated as out-of-vocabulary). If set to None, the
        entire dataset is scanned. This parameter has no effect if the dataset
        is stored in a format that already contains those values.
      try_resume_training: If true, tries to resume training from the model
        checkpoint stored in the `temp_directory` directory. If `temp_directory`
        does not contain any model checkpoint, starts the training from the
        beginning. Works in the following three situations: (1) the training
        was interrupted by the user (e.g. ctrl+c), (2) the training job was
        interrupted (e.g. rescheduling), and (3) the hyper-parameters of the
        model were changed such that an initially completed training is now
        incomplete (e.g. increasing the number of trees).
      input_model_signature_fn: A lambda that returns the
        (Dense,Sparse,Ragged)TensorSpec (or structure of TensorSpec e.g.
        dictionary, list) corresponding to the input signature of the model. If
        not specified, the input model signature is created by
        "build_default_input_model_signature". For example, specify
        "input_model_signature_fn" if a numerical input feature (which is
        consumed as DenseTensorSpec(float32) by default) will be fed
        differently (e.g. RaggedTensor(int64)).

    Returns:
      A Keras `History` object attached to this model (also stored in
      `self.history`). One epoch is recorded for each training log entry of
      the trained model that carries an evaluation; the epoch index is the
      number of trees of that entry and the logs are the evaluation converted
      to a dictionary. The history is empty if the model has no training logs.

    Raises:
      TypeError: If `dataset_format` is not a string (e.g. None).
      ValueError: If the task of the model is not classification, regression
        or ranking, or if a TF distribution strategy is active while the model
        does not support the partial cache dataset format.
    """

    # Fail fast, before any state of the model is modified. Without this
    # check, a non-string format only fails later on the string concatenation
    # that builds the typed dataset path, with a much less explicit message.
    if not isinstance(dataset_format, str):
      raise TypeError(
          "\"dataset_format\" should be a string (e.g. \"csv\"). "
          f"Got {dataset_format!r} instead.")

    self._time_begin_training = datetime.now()

    if self._verbose >= 1:
      tf_logging.info("Training model on dataset %s", train_path)

    self._clear_function_cache()

    # Call "compile" if the user forgot to do so.
    if not self._is_compiled:
      self.compile()

    # The model is exported by the training op in the "model" sub-directory of
    # the training directory.
    train_model_path = self._temp_directory
    model_path = os.path.join(train_model_path, "model")

    # Create the dataspec guide.
    guide = data_spec_pb2.DataSpecificationGuide(
        ignore_columns_without_guides=self._exclude_non_specified,
        max_num_scanned_rows_to_accumulate_statistics=(
            max_num_scanned_rows_to_accumulate_statistics))
    guide.default_column_guide.categorial.max_vocab_count = self._max_vocab_count

    # One column guide per user specified feature. The guide of the feature is
    # copied so that the feature definition itself is left untouched.
    self._normalized_input_keys = []
    for feature in self._features:
      col_guide = copy.deepcopy(feature.guide)
      col_guide.column_name_pattern = tf_core.normalize_inputs_regexp(
          feature.name)
      guide.column_guides.append(col_guide)
      self._normalized_input_keys.append(feature.name)

    # The semantic of the label column depends on the task.
    label_guide = data_spec_pb2.ColumnGuide(
        column_name_pattern=tf_core.normalize_inputs_regexp(label_key))

    if self._task == Task.CLASSIFICATION:
      label_guide.type = data_spec_pb2.CATEGORICAL
      # No pruning of the label values.
      label_guide.categorial.min_vocab_frequency = 0
      label_guide.categorial.max_vocab_count = -1
    elif self._task in (Task.REGRESSION, Task.RANKING):
      label_guide.type = data_spec_pb2.NUMERICAL
    else:
      raise ValueError(
          f"Non implemented task {self._task} with \"fit_on_dataset_path\"."
          " Use a different task or train with \"fit\".")
    guide.column_guides.append(label_guide)

    if ranking_key:
      ranking_guide = data_spec_pb2.ColumnGuide(
          column_name_pattern=tf_core.normalize_inputs_regexp(ranking_key),
          type=data_spec_pb2.HASH)
      guide.column_guides.append(ranking_guide)

    if weight_key:
      weight_guide = data_spec_pb2.ColumnGuide(
          column_name_pattern=tf_core.normalize_inputs_regexp(weight_key),
          type=data_spec_pb2.NUMERICAL)
      guide.column_guides.append(weight_guide)

    # Deployment configuration. The number of threads of the model is only
    # used if the user did not set one explicitly in the advanced arguments.
    deployment_config = copy.deepcopy(
        self._advanced_arguments.yggdrasil_deployment_config)
    if not deployment_config.HasField("num_threads"):
      deployment_config.num_threads = self._num_threads

    distribution_config = tf_core.get_distribution_configuration(
        self.distribute_strategy)

    if distribution_config is not None and not self.capabilities(
    ).support_partial_cache_dataset_format:
      raise ValueError(
          f"The model {type(self)} does not support training with a TF "
          "Distribution strategy (i.e. model.capabilities()."
          "support_partial_cache_dataset_format == False). If the dataset "
          "is small, simply remove the distribution strategy scope (i.e. `with "
          "strategy.scope():` around the model construction). If the dataset "
          "is large, use a distributed version of the model. For Example, use "
          "DistributedGradientBoostedTreesModel instead of "
          "GradientBoostedTreesModel.")

    with tf_logging.capture_cpp_log_context(verbose=self._verbose >= 2):
      # Typed paths i.e. "<format>:<path>".
      typed_train_path = dataset_format + ":" + train_path
      typed_valid_path = (dataset_format + ":" +
                          valid_path) if valid_path else None

      # Train the model.
      tf_core.train_on_file_dataset(
          train_dataset_path=typed_train_path,
          valid_dataset_path=typed_valid_path,
          feature_ids=self._normalized_input_keys,
          label_id=label_key,
          weight_id=weight_key,
          model_id=self._training_model_id,
          model_dir=train_model_path,
          learner=self._learner,
          task=self._task,
          generic_hparms=tf_core.hparams_dict_to_generic_proto(
              self._learner_params),
          ranking_group=ranking_key,
          keep_model_in_resource=True,
          guide=guide,
          training_config=self._advanced_arguments.yggdrasil_training_config,
          deployment_config=deployment_config,
          working_cache_path=os.path.join(self._temp_directory,
                                          "working_cache"),
          distribution_config=distribution_config,
          try_resume_training=try_resume_training)

      self._time_end_training = datetime.now()
      if self._verbose >= 1:
        self._print_timer_training()
        tf_logging.info("Compiling model")

      # Request and store a description of the model. The training was run
      # with "keep_model_in_resource=True": the model is always unloaded, even
      # if the description cannot be retrieved, so that it is not leaked.
      try:
        self._description = training_op.SimpleMLShowModel(
            model_identifier=self._training_model_id).numpy().decode("utf-8")
      finally:
        training_op.SimpleMLUnloadModel(
            model_identifier=self._training_model_id)

    # Build the model's graph.
    inspector = inspector_lib.make_inspector(model_path)
    self._set_from_yggdrasil_model(
        inspector,
        model_path,
        input_model_signature_fn=input_model_signature_fn)

    # Build the model history from the training logs of the model.
    history = tf.keras.callbacks.History()
    history.model = self
    history.on_train_begin()

    training_logs = inspector.training_logs()
    if training_logs is not None:
      for src_logs in training_logs:
        # Log entries without evaluation are skipped.
        if src_logs.evaluation is not None:
          history.on_epoch_end(src_logs.num_trees,
                               src_logs.evaluation.to_dict())
    self.history = history

    return self.history