in tensorflow_decision_forests/keras/core.py [0:0]
def _train_model(self):
  """Effectively train the model.

  Runs the actual Yggdrasil training on the data previously collected by the
  training graph, either locally or with a distribution strategy, then loads
  the resulting model back into memory as the inference engine.

  Side effects:
    - Updates the timing fields (`_time_end_data_feed`, `_time_begin_training`,
      `_time_end_training`) and optionally prints timers / log messages.
    - Writes the trained model under `self._temp_directory`.
    - Sets `self._description` to a text description of the trained model.
    - Marks the model as trained (`self._is_trained`) and stores the loaded
      inference model in `self._model`.

  Raises:
    Exception: If the training graph was not built before calling this
      method (i.e. `_normalized_input_keys` is still None).
      NOTE(review): a more specific exception type (e.g. RuntimeError) would
      be preferable, but changing it could break callers that catch
      `Exception` by type — left as-is.
  """
  if self._normalized_input_keys is None:
    # Guard: the training graph (which populates _normalized_input_keys)
    # must have been built by an earlier phase.
    raise Exception("The training graph was not built.")

  # The data-feeding phase ends here; record its duration.
  self._time_end_data_feed = datetime.now()
  if self._verbose >= 1:
    self._print_timer_feed_data()
    tf_logging.info("Training model")

  self._time_begin_training = datetime.now()

  # Directory where the Yggdrasil trainer exports the model; the final model
  # is expected in the "model" sub-directory (loaded at the end).
  train_model_path = self._temp_directory
  model_path = os.path.join(train_model_path, "model")

  # Create the dataspec guide.
  # The guide constrains how Yggdrasil infers the dataset specification.
  # Note: "categorial" (sic) is the actual field name in the Yggdrasil
  # DataSpecificationGuide proto — not a typo to fix here.
  guide = data_spec_pb2.DataSpecificationGuide()
  guide.default_column_guide.categorial.max_vocab_count = self._max_vocab_count

  # Attach one column guide per user-specified feature, matched by a regexp
  # over the normalized (escaped) feature name.
  for feature in self._features:
    col_guide = copy.deepcopy(feature.guide)
    col_guide.column_name_pattern = tf_core.normalize_inputs_regexp(
        feature.name)
    guide.column_guides.append(col_guide)

  # Deployment configuration
  # Start from the user-provided deployment config; only fill in the thread
  # count if the user did not set it explicitly.
  deployment_config = copy.deepcopy(
      self._advanced_arguments.yggdrasil_deployment_config)
  if not deployment_config.HasField("num_threads"):
    deployment_config.num_threads = self._num_threads

  # None => no distribution strategy => plain in-process training.
  distribution_config = tf_core.get_distribution_configuration(
      self.distribute_strategy)

  # Forward C++ training logs to Python logging when verbose enough.
  with tf_logging.capture_cpp_log_context(verbose=self._verbose >= 2):
    if distribution_config is None:
      # Train the model.
      # The model will be exported to "train_model_path".
      #
      # Note: It would be possible to train and load the model without saving
      # the model to file.
      tf_core.train(
          input_ids=self._normalized_input_keys,
          label_id=_LABEL,
          weight_id=_WEIGHTS if self._weighted_training else None,
          model_id=self._training_model_id,
          model_dir=train_model_path,
          learner=self._learner,
          task=self._task,
          generic_hparms=tf_core.hparams_dict_to_generic_proto(
              self._learner_params),
          # Ranking group / uplift treatment columns are only meaningful for
          # their respective tasks.
          ranking_group=_RANK_GROUP if self._task == Task.RANKING else None,
          uplift_treatment=_UPLIFT_TREATMENT
          if self._task == Task.CATEGORICAL_UPLIFT else None,
          # Keep the trained model in a TF resource so it can be queried
          # (SimpleMLShowModel) and unloaded below.
          keep_model_in_resource=True,
          guide=guide,
          training_config=self._advanced_arguments.yggdrasil_training_config,
          deployment_config=deployment_config,
          try_resume_training=self._try_resume_training,
          has_validation_dataset=self._has_validation_dataset)
    else:
      # Distributed training: first flush/finalize the partial dataset cache
      # collected by the workers, then train from that file dataset.
      tf_core.finalize_distributed_dataset_collection(
          cluster_coordinator=self._cluster_coordinator,
          # Label (and optional weight) columns are part of the cached
          # dataset alongside the input features.
          input_ids=self._normalized_input_keys + [_LABEL] +
          ([_WEIGHTS] if self._weighted_training else []),
          model_id=self._training_model_id,
          dataset_path=self._distributed_partial_dataset_cache_path())

      tf_core.train_on_file_dataset(
          # "partial_dataset_cache:" is the dataset-format prefix understood
          # by the Yggdrasil reader for the cache written above.
          train_dataset_path="partial_dataset_cache:" +
          self._distributed_partial_dataset_cache_path(),
          valid_dataset_path=None,
          feature_ids=self._normalized_input_keys,
          label_id=_LABEL,
          weight_id=_WEIGHTS if self._weighted_training else None,
          model_id=self._training_model_id,
          model_dir=train_model_path,
          learner=self._learner,
          task=self._task,
          generic_hparms=tf_core.hparams_dict_to_generic_proto(
              self._learner_params),
          ranking_group=_RANK_GROUP if self._task == Task.RANKING else None,
          uplift_treatment=_UPLIFT_TREATMENT
          if self._task == Task.CATEGORICAL_UPLIFT else None,
          keep_model_in_resource=True,
          guide=guide,
          training_config=self._advanced_arguments.yggdrasil_training_config,
          deployment_config=deployment_config,
          working_cache_path=os.path.join(self._temp_directory,
                                          "working_cache"),
          distribution_config=distribution_config,
          try_resume_training=self._try_resume_training)

    # Request and store a description of the model.
    self._description = training_op.SimpleMLShowModel(
        model_identifier=self._training_model_id).numpy().decode("utf-8")
    # Release the in-resource copy of the model; the on-disk export under
    # `model_path` is what gets loaded below.
    training_op.SimpleMLUnloadModel(model_identifier=self._training_model_id)

  # `_is_trained` is assigned (not plain-set), presumably a tf.Variable so
  # the trained state is tracked/serialized with the Keras model — confirm.
  self._is_trained.assign(True)

  self._time_end_training = datetime.now()
  if self._verbose >= 1:
    self._print_timer_training()
    tf_logging.info("Compiling model")

  # Load and optimize the model in memory.
  # Register the model as a SavedModel asset.
  self._model = tf_op.ModelV2(model_path=model_path, verbose=False)