in tensorflow_decision_forests/tensorflow/core.py
def train_on_file_dataset(
train_dataset_path: str,
valid_dataset_path: Optional[str],
feature_ids: List[str],
label_id: str,
weight_id: Optional[str],
model_id: str,
learner: str,
task: Optional[TaskType] = Task.CLASSIFICATION,
generic_hparms: Optional[
abstract_learner_pb2.GenericHyperParameters] = None,
ranking_group: Optional[str] = None,
uplift_treatment: Optional[str] = None,
training_config: Optional[abstract_learner_pb2.TrainingConfig] = None,
deployment_config: Optional[abstract_learner_pb2.DeploymentConfig] = None,
guide: Optional[data_spec_pb2.DataSpecificationGuide] = None,
model_dir: Optional[str] = None,
keep_model_in_resource: Optional[bool] = True,
working_cache_path: Optional[str] = None,
distribution_config: Optional[DistributionConfiguration] = None,
try_resume_training: Optional[bool] = False) -> tf.Operation:
"""Trains a model on dataset stored on file.
The input arguments and overall logic of this OP is similar to the ":train"
CLI or the "learner->Train()" method of Yggdrasil Decision Forests (in fact,
this OP simply calls "learner->Train()").
Similarly as the `train` method, the implementation the learning algorithm
should be added as a dependency to the binary. Similarly, the implementation
the dataset format should be added as a dependency to the
binary.
In the case of distributed training, `train_on_file_dataset` should only be
called by the `chief` process, and `deployment_config` should contain the
address of the workers.
Args:
train_dataset_path: Path to the training dataset.
valid_dataset_path: Path to the validation dataset.
feature_ids: Ids/names of the input features.
label_id: Id/name of the label feature.
weight_id: Id/name of the weight feature.
model_id: Id of the model.
learner: Key of the learner.
task: Task to solve.
generic_hparms: Hyper-parameters of the learner.
ranking_group: Id of the ranking group feature. Only for ranking.
uplift_treatment: Id of the uplift treatment group feature. Only for uplift.
training_config: Training configuration.
deployment_config: Deployment configuration (e.g. where to train the model).
guide: Dataset specification guide.
model_dir: If specified, export the trained model into this directory.
keep_model_in_resource: If true, keep the model as a training model
resource.
working_cache_path: Path to the working cache directory. If set, and if the
training is distributed, all the workers should have write access to this
cache.
distribution_config: Socket addresses of the workers for distributed
training.
try_resume_training: Try to resume the training from the
"working_cache_path" directory. If the "working_cache_path" does not
contain any checkpoint, the training starts from the beginning.
Returns:
The OP that triggers the training.
"""
if generic_hparms is None:
generic_hparms = abstract_learner_pb2.GenericHyperParameters()
if training_config is None:
training_config = abstract_learner_pb2.TrainingConfig()
else:
training_config = copy.deepcopy(training_config)
if deployment_config is None:
deployment_config = abstract_learner_pb2.DeploymentConfig()
else:
deployment_config = copy.deepcopy(deployment_config)
if guide is None:
guide = data_spec_pb2.DataSpecificationGuide()
if ranking_group is not None:
training_config.ranking_group = ranking_group
if uplift_treatment is not None:
training_config.uplift_treatment = uplift_treatment
# Set the method arguments into the proto configs.
training_config.learner = learner
training_config.task = task
training_config.label = label_id
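# If a weight feature is specified, it is interpreted as a numerical
# example weight.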
if weight_id is not None:
training_config.weight_definition.attribute = weight_id
training_config.weight_definition.numerical.SetInParent()
for feature_id in feature_ids:
training_config.features.append(normalize_inputs_regexp(feature_id))
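# Yggdrasil selects input features with regular expressions;
# `normalize_inputs_regexp` escapes each feature name so it is matched
# literally.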
if working_cache_path is not None:
deployment_config.cache_path = working_cache_path
if try_resume_training:
if working_cache_path is None:
raise ValueError("Cannot train a model with `try_resume_training=True` "
"without a working cache directory.")
deployment_config.try_resume_training = True
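# Configure distributed training with the "TF_DIST" distribute
# implementation; the worker addresses come either from
# `distribution_config` or from an environment variable.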
if distribution_config is not None:
deployment_config.try_resume_training = True
deployment_config.distribute.implementation_key = "TF_DIST"
if distribution_config.workers_addresses is not None:
dst_addresses = deployment_config.distribute.Extensions[
tf_distribution_pb2.tf_distribution].addresses
dst_addresses.addresses[:] = distribution_config.workers_addresses
else:
# Assume the worker addresses are provided through an environment variable.
deployment_config.distribute.Extensions[
tf_distribution_pb2.tf_distribution].environment_variable.SetInParent(
)
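# Create and return the training OP. Unset optional values are passed as
# empty strings, and the configuration protos are passed serialized.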
return training_op.SimpleMLModelTrainerOnFile(
train_dataset_path=train_dataset_path,
valid_dataset_path=valid_dataset_path if valid_dataset_path else "",
model_id=model_id if keep_model_in_resource else "",
model_dir=model_dir or "",
hparams=generic_hparms.SerializeToString(),
training_config=training_config.SerializeToString(),
deployment_config=deployment_config.SerializeToString(),
guide=guide.SerializeToString())
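
Usage sketch (hypothetical values, not from the library): the typed dataset
path, feature names, model id, and learner key below are illustrative, and in
practice this helper is invoked by the TF-DF Keras model wrappers rather than
called directly.

train_op = train_on_file_dataset(
    train_dataset_path="csv:/tmp/train.csv",
    valid_dataset_path=None,
    feature_ids=["f1", "f2"],
    label_id="label",
    weight_id=None,
    model_id="my_model",
    learner="RANDOM_FOREST",
    task=Task.CLASSIFICATION,
    model_dir="/tmp/model")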