def train_on_file_dataset()

in tensorflow_decision_forests/tensorflow/core.py [0:0]


def train_on_file_dataset(
    train_dataset_path: str,
    valid_dataset_path: Optional[str],
    feature_ids: List[str],
    label_id: str,
    weight_id: Optional[str],
    model_id: str,
    learner: str,
    task: Optional[TaskType] = Task.CLASSIFICATION,
    generic_hparms: Optional[
        abstract_learner_pb2.GenericHyperParameters] = None,
    ranking_group: Optional[str] = None,
    uplift_treatment: Optional[str] = None,
    training_config: Optional[abstract_learner_pb2.TrainingConfig] = None,
    deployment_config: Optional[abstract_learner_pb2.DeploymentConfig] = None,
    guide: Optional[data_spec_pb2.DataSpecificationGuide] = None,
    model_dir: Optional[str] = None,
    keep_model_in_resource: Optional[bool] = True,
    working_cache_path: Optional[str] = None,
    distribution_config: Optional[DistributionConfiguration] = None,
    try_resume_training: Optional[bool] = False) -> tf.Operation:
  """Trains a model on a dataset stored on file.

  The input arguments and overall logic of this OP are similar to the ":train"
  CLI or the "learner->Train()" method of Yggdrasil Decision Forests (in fact,
  this OP simply calls "learner->Train()").

  As with the `train` method, the implementation of the learning algorithm
  should be added as a dependency to the binary. Similarly, the implementation
  of the dataset format should be added as a dependency to the binary.

  In the case of distributed training, `train_on_file_dataset` should only be
  called by the `chief` process, and `deployment_config` should contain the
  address of the workers.

  Args:
    train_dataset_path: Path to the training dataset.
    valid_dataset_path: Path to the validation dataset. None or empty means no
      validation dataset.
    feature_ids: Ids/names of the input features. Each one is passed through
      `normalize_inputs_regexp` before being added to the training config.
    label_id: Id/name of the label feature.
    weight_id: Id/name of the weight feature. If set, a numerical weight
      definition on this attribute is added to the training config.
    model_id: Id of the model.
    learner: Key of the learner.
    task: Task to solve. If None, the task already present in
      `training_config` (or the proto default) is kept.
    generic_hparms: Hyper-parameters of the learner.
    ranking_group: Id of the ranking group feature. Only for ranking.
    uplift_treatment: Id of the uplift treatment group feature. Only for uplift.
    training_config: Training configuration. The caller's proto is not
      modified; a deep copy is completed and used.
    deployment_config: Deployment configuration (e.g. where to train the model).
      The caller's proto is not modified; a deep copy is completed and used.
    guide: Dataset specification guide.
    model_dir: If specified, export the trained model into this directory.
    keep_model_in_resource: If true, keep the model as a training model
      resource (the `model_id` is forwarded to the OP; otherwise an empty id is
      passed).
    working_cache_path: Path to the working cache directory. If set, it
      overrides `deployment_config.cache_path`. If the training is
      distributed, all the workers should have write access to this cache.
    distribution_config: Configuration of the distributed training. If its
      `workers_addresses` is None, the worker addresses are expected to be
      provided through the environment. When set, resuming the training is
      always enabled.
    try_resume_training: Try to resume the training from the working cache
      directory (`working_cache_path`, or `deployment_config.cache_path` if the
      former is not set). If that directory does not contain any checkpoint,
      the training starts from scratch.

  Returns:
    The OP that triggers the training.

  Raises:
    ValueError: If `try_resume_training` is true and no working cache directory
      is configured (neither `working_cache_path` nor
      `deployment_config.cache_path`).
  """

  if generic_hparms is None:
    generic_hparms = abstract_learner_pb2.GenericHyperParameters()

  # The config protos are completed in place below; copy them so that the
  # caller's objects are never mutated.
  if training_config is None:
    training_config = abstract_learner_pb2.TrainingConfig()
  else:
    training_config = copy.deepcopy(training_config)

  if deployment_config is None:
    deployment_config = abstract_learner_pb2.DeploymentConfig()
  else:
    deployment_config = copy.deepcopy(deployment_config)

  if guide is None:
    guide = data_spec_pb2.DataSpecificationGuide()

  if ranking_group is not None:
    training_config.ranking_group = ranking_group

  if uplift_treatment is not None:
    training_config.uplift_treatment = uplift_treatment

  # Set the method arguments into the proto configs.
  training_config.learner = learner
  # `task` is Optional: assigning None to a proto enum field is invalid, so a
  # None task leaves the value already in `training_config` untouched.
  if task is not None:
    training_config.task = task
  training_config.label = label_id

  if weight_id is not None:
    training_config.weight_definition.attribute = weight_id
    training_config.weight_definition.numerical.SetInParent()

  for feature_id in feature_ids:
    training_config.features.append(normalize_inputs_regexp(feature_id))

  if working_cache_path is not None:
    deployment_config.cache_path = working_cache_path

  if try_resume_training:
    # Check the effective cache path: it may come either from
    # `working_cache_path` or from the caller-provided `deployment_config`.
    if not deployment_config.cache_path:
      raise ValueError("Cannot train a model with `try_resume_training=True` "
                       "without a working cache directory.")
    deployment_config.try_resume_training = True

  if distribution_config is not None:
    # Distributed training always enables resumption (presumably so that the
    # training can recover after a chief/worker restart — confirm).
    deployment_config.try_resume_training = True
    deployment_config.distribute.implementation_key = "TF_DIST"

    tf_dist = deployment_config.distribute.Extensions[
        tf_distribution_pb2.tf_distribution]
    if distribution_config.workers_addresses is not None:
      tf_dist.addresses.addresses[:] = distribution_config.workers_addresses
    else:
      # Assume the worker addresses are provided through the environment.
      tf_dist.environment_variable.SetInParent()

  return training_op.SimpleMLModelTrainerOnFile(
      train_dataset_path=train_dataset_path,
      valid_dataset_path=valid_dataset_path or "",
      # An empty model id tells the OP not to keep the model as a resource.
      model_id=model_id if keep_model_in_resource else "",
      model_dir=model_dir or "",
      hparams=generic_hparms.SerializeToString(),
      training_config=training_config.SerializeToString(),
      deployment_config=deployment_config.SerializeToString(),
      guide=guide.SerializeToString())