def _train_model()

in tensorflow_decision_forests/keras/core.py [0:0]


  def _train_model(self):
    """Effectively trains the model.

    Trains a Yggdrasil model on the inputs identified by
    `self._normalized_input_keys`. It then stores the model's text description
    in `self._description`, sets `self._is_trained` to True, and loads the
    trained model from `<self._temp_directory>/model` into `self._model`.

    Two training paths exist:
      * Local training (`tf_core.train`) when the distribution strategy yields
        no distribution configuration.
      * Distributed training (`tf_core.train_on_file_dataset` on the partial
        dataset cache) otherwise.

    Raises:
      Exception: If `self._normalized_input_keys` is None, i.e. the training
        graph was not built.
    """

    # NOTE(review): a generic `Exception` is raised here; a more specific
    # subclass (e.g. RuntimeError) would let callers catch it precisely.
    if self._normalized_input_keys is None:
      raise Exception("The training graph was not built.")

    # Data feeding is over at this point. The timestamp is presumably consumed
    # by `_print_timer_feed_data()` to report the feeding duration.
    self._time_end_data_feed = datetime.now()
    if self._verbose >= 1:
      self._print_timer_feed_data()
      tf_logging.info("Training model")

    self._time_begin_training = datetime.now()

    # The trainer writes into the temporary directory. The model is later
    # loaded from its "model" sub-directory (see the ModelV2 call at the end).
    train_model_path = self._temp_directory
    model_path = os.path.join(train_model_path, "model")

    # Create the dataspec guide.
    # The default column guide caps the categorical dictionary size.
    guide = data_spec_pb2.DataSpecificationGuide()
    guide.default_column_guide.categorial.max_vocab_count = self._max_vocab_count
    for feature in self._features:
      # Copy so the user's feature guide is not mutated.
      col_guide = copy.deepcopy(feature.guide)
      # Match the column by its normalized name. Presumably the training ops
      # only see inputs under their normalized keys.
      col_guide.column_name_pattern = tf_core.normalize_inputs_regexp(
          feature.name)
      guide.column_guides.append(col_guide)

    # Deployment configuration
    # Work on a copy so the advanced arguments stay unchanged. An explicit
    # `num_threads` in the advanced deployment config takes precedence over
    # `self._num_threads`.
    deployment_config = copy.deepcopy(
        self._advanced_arguments.yggdrasil_deployment_config)
    if not deployment_config.HasField("num_threads"):
      deployment_config.num_threads = self._num_threads

    # A None result selects local (non-distributed) training below.
    distribution_config = tf_core.get_distribution_configuration(
        self.distribute_strategy)

    # C++ logs are routed through tf_logging. Its verbose mode is only enabled
    # when self._verbose >= 2.
    with tf_logging.capture_cpp_log_context(verbose=self._verbose >= 2):

      if distribution_config is None:
        # Train the model.
        # The model will be exported to "train_model_path".
        #
        # Note: It would be possible to train and load the model without saving
        # the model to file.
        tf_core.train(
            input_ids=self._normalized_input_keys,
            label_id=_LABEL,
            weight_id=_WEIGHTS if self._weighted_training else None,
            model_id=self._training_model_id,
            model_dir=train_model_path,
            learner=self._learner,
            task=self._task,
            generic_hparms=tf_core.hparams_dict_to_generic_proto(
                self._learner_params),
            # Task-specific special columns, only set for the matching task.
            ranking_group=_RANK_GROUP if self._task == Task.RANKING else None,
            uplift_treatment=_UPLIFT_TREATMENT
            if self._task == Task.CATEGORICAL_UPLIFT else None,
            # Keep the trained model in a resource keyed by
            # `self._training_model_id`, so its description can be queried
            # after training.
            keep_model_in_resource=True,
            guide=guide,
            training_config=self._advanced_arguments.yggdrasil_training_config,
            deployment_config=deployment_config,
            try_resume_training=self._try_resume_training,
            has_validation_dataset=self._has_validation_dataset)

      else:
        # Distributed training.
        # Finalize the dataset cache for all collected columns: features,
        # label, and weights when weighted. Presumably the cache was written
        # by the workers during data feeding.
        tf_core.finalize_distributed_dataset_collection(
            cluster_coordinator=self._cluster_coordinator,
            input_ids=self._normalized_input_keys + [_LABEL] +
            ([_WEIGHTS] if self._weighted_training else []),
            model_id=self._training_model_id,
            dataset_path=self._distributed_partial_dataset_cache_path())

        # The "partial_dataset_cache:" prefix presumably tells the trainer
        # which on-disk dataset format to read.
        # NOTE(review): `has_validation_dataset` is not forwarded here and
        # `valid_dataset_path` is None. Confirm that distributed training is
        # meant to ignore validation data.
        tf_core.train_on_file_dataset(
            train_dataset_path="partial_dataset_cache:" +
            self._distributed_partial_dataset_cache_path(),
            valid_dataset_path=None,
            feature_ids=self._normalized_input_keys,
            label_id=_LABEL,
            weight_id=_WEIGHTS if self._weighted_training else None,
            model_id=self._training_model_id,
            model_dir=train_model_path,
            learner=self._learner,
            task=self._task,
            generic_hparms=tf_core.hparams_dict_to_generic_proto(
                self._learner_params),
            ranking_group=_RANK_GROUP if self._task == Task.RANKING else None,
            uplift_treatment=_UPLIFT_TREATMENT
            if self._task == Task.CATEGORICAL_UPLIFT else None,
            keep_model_in_resource=True,
            guide=guide,
            training_config=self._advanced_arguments.yggdrasil_training_config,
            deployment_config=deployment_config,
            # Scratch space for distributed training, kept inside the
            # temporary directory.
            working_cache_path=os.path.join(self._temp_directory,
                                            "working_cache"),
            distribution_config=distribution_config,
            try_resume_training=self._try_resume_training)

      # Request and store a description of the model.
      # The model was kept in a resource (keep_model_in_resource=True). Read its
      # description, then release the resource; the model is re-loaded from
      # disk below.
      self._description = training_op.SimpleMLShowModel(
          model_identifier=self._training_model_id).numpy().decode("utf-8")
      training_op.SimpleMLUnloadModel(model_identifier=self._training_model_id)

      self._is_trained.assign(True)

      self._time_end_training = datetime.now()

      if self._verbose >= 1:
        self._print_timer_training()
        tf_logging.info("Compiling model")

      # Load and optimize the model in memory.
      # Register the model as a SavedModel asset.
      # NOTE(review): assumes the trainer exported the model to
      # `<self._temp_directory>/model`; confirm with tf_core.train.
      self._model = tf_op.ModelV2(model_path=model_path, verbose=False)