tensorflow_decision_forests/keras/core.py
def fit_on_dataset_path(
self,
train_path: str,
label_key: str,
weight_key: Optional[str] = None,
ranking_key: Optional[str] = None,
valid_path: Optional[str] = None,
dataset_format: Optional[str] = "csv",
max_num_scanned_rows_to_accumulate_statistics: Optional[int] = 100_000,
try_resume_training: Optional[bool] = True,
input_model_signature_fn: Optional[tf_core.InputModelSignatureFn] = (
tf_core.build_default_input_model_signature)):
"""Trains the model on a dataset stored on disk.
This solution is generally more efficient and easier than loading the
dataset with a tf.data.Dataset, both for local and distributed training.
Usage example:
# Local training
model = keras.GradientBoostedTreesModel()
model.fit_on_dataset_path(
train_path="/path/to/dataset.csv",
label_key="label",
dataset_format="csv")
model.save("/model/path")
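# Local training with example weights and a validation dataset. The file
# paths and the "weight" column name below are illustrative.
model = keras.GradientBoostedTreesModel()
model.fit_on_dataset_path(
train_path="/path/to/train.csv",
valid_path="/path/to/valid.csv",
label_key="label",
weight_key="weight",
dataset_format="csv")
model.save("/model/path")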
# Distributed training
with tf.distribute.experimental.ParameterServerStrategy(...).scope():
model = keras.DistributedGradientBoostedTreesModel()
model.fit_on_dataset_path(
train_path="/path/to/dataset@10",
label_key="label",
dataset_format="tfrecord+tfe")
model.save("/model/path")
Args:
train_path: Path to the training dataset. Supports comma-separated files,
shard and glob notation.
label_key: Name of the label column.
weight_key: Name of the weight column.
ranking_key: Name of the ranking column.
valid_path: Path to the validation dataset. If not provided, or if the
learning algorithm does not support/need a validation dataset,
`valid_path` is ignored.
dataset_format: Format of the dataset. Should be one of the registered
dataset formats (see
https://github.com/google/yggdrasil-decision-forests/blob/main/documentation/user_manual.md#dataset-path-and-format
for more details). The "csv" format is always available, but it is
generally only suited for small datasets.
max_num_scanned_rows_to_accumulate_statistics: Maximum number of examples
to scan to determine the statistics of the features, i.e. the dataspec
(e.g. mean values, dictionaries). Currently, the first examples of the
dataset are scanned (e.g. for a multi-file dataset, only the first file
may be read). Therefore, it is important that the dataset is relatively
uniformly shuffled; notably, the scanned examples should contain all the
possible categorical values (values not seen during scanning are treated
as out-of-vocabulary). If set to None, the entire dataset is scanned.
This parameter has no effect if the dataset is stored in a format that
already contains those statistics.
try_resume_training: If true, tries to resume training from the model
checkpoint stored in the `temp_directory` directory. If `temp_directory`
does not contain any model checkpoint, training starts from scratch.
This works in the following three situations: (1) the training was
interrupted by the user (e.g. ctrl+c), (2) the training job was
interrupted (e.g. rescheduling), and (3) the hyper-parameters of the
model were changed such that an initially completed training is now
incomplete (e.g. increasing the number of trees).
input_model_signature_fn: A lambda that returns the
(Dense,Sparse,Ragged)TensorSpec (or structure of TensorSpecs, e.g.
dictionary, list) corresponding to the input signature of the model. If
not specified, the input model signature is created by
"build_default_input_model_signature". For example, specify
"input_model_signature_fn" if a numerical input feature (which is
consumed as a DenseTensorSpec(float32) by default) should be fed
differently (e.g. as a RaggedTensor(int64)).
Returns:
A `History` object. Its `History.history` attribute is not yet
implemented for decision forests algorithms, and will be empty.
All other fields are filled as usual for `keras.Model.fit()`.
"""
self._time_begin_training = datetime.now()
if self._verbose >= 1:
tf_logging.info("Training model on dataset %s", train_path)
self._clear_function_cache()
# Call "compile" if the user forgot to do so.
if not self._is_compiled:
self.compile()
train_model_path = self._temp_directory
model_path = os.path.join(train_model_path, "model")
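# The learner exports the trained model in the "model" sub-directory of the
# training directory; it is re-loaded from there after training.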
# Create the dataspec guide.
guide = data_spec_pb2.DataSpecificationGuide(
ignore_columns_without_guides=self._exclude_non_specified,
max_num_scanned_rows_to_accumulate_statistics=max_num_scanned_rows_to_accumulate_statistics
)
guide.default_column_guide.categorial.max_vocab_count = self._max_vocab_count
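# Register a column guide for each input feature specified by the user.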
self._normalized_input_keys = []
for feature in self._features:
col_guide = copy.deepcopy(feature.guide)
col_guide.column_name_pattern = tf_core.normalize_inputs_regexp(
feature.name)
guide.column_guides.append(col_guide)
self._normalized_input_keys.append(feature.name)
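# Configure the label column guide according to the task.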
label_guide = data_spec_pb2.ColumnGuide(
column_name_pattern=tf_core.normalize_inputs_regexp(label_key))
if self._task == Task.CLASSIFICATION:
label_guide.type = data_spec_pb2.CATEGORICAL
label_guide.categorial.min_vocab_frequency = 0
label_guide.categorial.max_vocab_count = -1
elif self._task == Task.REGRESSION:
label_guide.type = data_spec_pb2.NUMERICAL
elif self._task == Task.RANKING:
label_guide.type = data_spec_pb2.NUMERICAL
else:
raise ValueError(
f"Non implemented task {self._task} with \"fit_on_dataset_path\"."
" Use a different task or train with \"fit\".")
guide.column_guides.append(label_guide)
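# Declare the ranking group column (the per-example group/query identifier
# used for ranking tasks).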
if ranking_key:
ranking_guide = data_spec_pb2.ColumnGuide(
column_name_pattern=tf_core.normalize_inputs_regexp(ranking_key),
type=data_spec_pb2.HASH)
guide.column_guides.append(ranking_guide)
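# Declare the example weight column.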
if weight_key:
weight_guide = data_spec_pb2.ColumnGuide(
column_name_pattern=tf_core.normalize_inputs_regexp(weight_key),
type=data_spec_pb2.NUMERICAL)
guide.column_guides.append(weight_guide)
# Deployment configuration
deployment_config = copy.deepcopy(
self._advanced_arguments.yggdrasil_deployment_config)
if not deployment_config.HasField("num_threads"):
deployment_config.num_threads = self._num_threads
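# Derive the distributed training configuration from the Keras distribution
# strategy, if any.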
distribution_config = tf_core.get_distribution_configuration(
self.distribute_strategy)
if distribution_config is not None and not self.capabilities(
).support_partial_cache_dataset_format:
raise ValueError(
f"The model {type(self)} does not support training with a TF "
"Distribution strategy (i.e. model.capabilities()."
"support_partial_cache_dataset_format == False). If the dataset "
"is small, simply remove the distribution strategy scope (i.e. `with "
"strategy.scope():` around the model construction). If the dataset "
"is large, use a distributed version of the model. For example, use "
"DistributedGradientBoostedTreesModel instead of "
"GradientBoostedTreesModel.")
with tf_logging.capture_cpp_log_context(verbose=self._verbose >= 2):
# Train the model.
tf_core.train_on_file_dataset(
train_dataset_path=dataset_format + ":" + train_path,
valid_dataset_path=(dataset_format + ":" +
valid_path) if valid_path else None,
feature_ids=self._normalized_input_keys,
label_id=label_key,
weight_id=weight_key,
model_id=self._training_model_id,
model_dir=train_model_path,
learner=self._learner,
task=self._task,
generic_hparms=tf_core.hparams_dict_to_generic_proto(
self._learner_params),
ranking_group=ranking_key,
keep_model_in_resource=True,
guide=guide,
training_config=self._advanced_arguments.yggdrasil_training_config,
deployment_config=deployment_config,
working_cache_path=os.path.join(self._temp_directory,
"working_cache"),
distribution_config=distribution_config,
try_resume_training=try_resume_training)
self._time_end_training = datetime.now()
if self._verbose >= 1:
self._print_timer_training()
if self._verbose >= 1:
tf_logging.info("Compiling model")
# Request and store a description of the model.
self._description = training_op.SimpleMLShowModel(
model_identifier=self._training_model_id).numpy().decode("utf-8")
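# Unload the model from the training op resources; it is re-loaded from disk
# below to build the inference graph.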
training_op.SimpleMLUnloadModel(model_identifier=self._training_model_id)
# Build the model's graph.
inspector = inspector_lib.make_inspector(model_path)
self._set_from_yggdrasil_model(
inspector,
model_path,
input_model_signature_fn=input_model_signature_fn)
# Build the model history.
history = tf.keras.callbacks.History()
history.model = self
history.on_train_begin()
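# Export the training logs (when available) as Keras epoch entries, indexed
# by the number of trees.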
training_logs = inspector.training_logs()
if training_logs is not None:
for src_logs in training_logs:
if src_logs.evaluation is not None:
history.on_epoch_end(src_logs.num_trees,
src_logs.evaluation.to_dict())
self.history = history
return self.history