in adanet/core/estimator.py [0:0]
def train(self,
input_fn,
hooks=None,
steps=None,
max_steps=None,
saving_listeners=None):
# pyformat: disable
"""Trains a model given training data :code:`input_fn`.
NOTE: If a given :code:`input_fn` raises an :code:`OutOfRangeError`, then
*all* training will exit. The best practice is to make the training dataset
repeat forever, so that model search can run for more than one iteration.
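For example, an :code:`input_fn` along these lines keeps the dataset
repeating (a minimal sketch; the feature name, values, and batch size are
illustrative only):

    def repeating_input_fn():
      features = {"x": [[0.], [1.], [2.], [3.]]}
      labels = [[0], [1], [1], [0]]
      dataset = tf.data.Dataset.from_tensor_slices((features, labels))
      return dataset.repeat().batch(2)
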
Args:
input_fn: A function that provides input data for training as minibatches.
See [Premade Estimators](
https://tensorflow.org/guide/premade_estimators#create_input_functions)
for more information. The function should construct and return one of
the following:
* A :code:`tf.data.Dataset` object: Outputs of `Dataset` object must
be a tuple `(features, labels)` with same constraints as below.
* A tuple `(features, labels)`: Where `features` is a
:code:`tf.Tensor` or a dictionary of string feature name to
`Tensor` and `labels` is a :code:`Tensor` or a dictionary of string
label name to `Tensor`. Both `features` and `labels` are consumed by
`model_fn`, and should satisfy its input expectations.
hooks: List of :code:`tf.train.SessionRunHook` subclass instances. Used
for callbacks inside the training loop.
steps: Number of steps for which to train the model. If :code:`None`,
train forever or until `input_fn` generates the
:code:`tf.errors.OutOfRange` error or :code:`StopIteration` exception.
`steps` works incrementally: two calls to `train(steps=10)` train for 20
steps in total. If :code:`OutOfRange` or :code:`StopIteration` occurs in
the middle, training stops before 20 steps. If you don't want this
incremental behavior, set `max_steps` instead (see the example at the end
of this docstring). If set, `max_steps` must be :code:`None`.
max_steps: Number of total steps for which to train the model. If
:code:`None`, train forever or until `input_fn` generates the
:code:`tf.errors.OutOfRange` error or :code:`StopIteration` exception.
If set, `steps` must be :code:`None`. If :code:`OutOfRange` or
:code:`StopIteration` occurs in the middle, training stops before
`max_steps` steps. Two calls to `train(steps=100)` means 200 training
iterations, whereas two calls to `train(max_steps=100)` means the second
call will not do any iterations, since the first call did all 100 steps.
saving_listeners: List of :code:`CheckpointSaverListener` objects. Used
for callbacks that run immediately before or after checkpoint saves.
Returns:
`self`, for chaining.
Raises:
ValueError: If both `steps` and `max_steps` are not `None`.
ValueError: If either `steps` or `max_steps` is `<= 0`.
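Example (a minimal sketch; the `estimator` and `input_fn` names are assumed
to be defined elsewhere, and the step counts are illustrative):

    # `steps` is relative: these two calls train 200 steps in total.
    estimator.train(input_fn=input_fn, steps=100)
    estimator.train(input_fn=input_fn, steps=100)
    # `max_steps` is absolute: this call stops once the global step reaches
    # 300, so it runs at most 100 additional steps here.
    estimator.train(input_fn=input_fn, max_steps=300)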
"""
# pyformat: enable
if (steps is not None) and (max_steps is not None):
raise ValueError("Can not provide both steps and max_steps.")
if steps is not None and steps <= 0:
raise ValueError("Must specify steps > 0, given: {}".format(steps))
latest_checkpoint = tf.train.latest_checkpoint(self.model_dir)
latest_global_steps = self._checkpoint_global_step(latest_checkpoint)
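# When `steps` is given, convert it into an absolute `max_steps` relative to
# the global step recorded in the latest checkpoint, so the loop below only
# needs to reason about `max_steps`.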
if steps is not None:
max_steps = latest_global_steps + steps
# Each iteration of this AdaNet loop represents an `_Iteration`. The
# current iteration number is stored as a variable in the checkpoint so
# that training can be stopped and restarted at any time.
with monkey_patch_default_variable_placement_strategy():
while True:
latest_checkpoint = tf.train.latest_checkpoint(self.model_dir)
latest_global_steps = self._checkpoint_global_step(latest_checkpoint)
current_iteration = self._checkpoint_iteration_number(latest_checkpoint)
logging.info("Beginning training AdaNet iteration %s",
current_iteration)
self._iteration_ended = False
# Delegate training to a temporary estimator instead of super to make
# passing arguments more functional (via params).
temp_estimator = self._create_temp_estimator(
config=self.config,
is_inside_training_loop=True,
checkpoint_path=latest_checkpoint,
hooks=hooks)
result = temp_estimator.train(
input_fn=input_fn,
hooks=hooks,
max_steps=max_steps,
saving_listeners=saving_listeners)
# In TensorFlow v2.0.0.rc1 and below, saving listeners are attached to
# the first CheckpointSaverHook each time train is called. Instead, we
# pass in the saving_listeners in the first AdaNet iteration only.
if not tf_compat.version_greater_or_equal("2.0.0.rc1"):
saving_listeners = None
logging.info("Finished training Adanet iteration %s", current_iteration)
# If training ended because the maximum number of training steps was
# reached, exit training.
latest_checkpoint = tf.train.latest_checkpoint(self.model_dir)
global_steps = self._checkpoint_global_step(latest_checkpoint)
if max_steps is not None and global_steps >= max_steps:
logging.info("Training ended after %s global steps", global_steps)
return result
# If training ended for any reason other than the iteration ending,
# exit training.
if not self._iteration_ended:
logging.info("Training stop requested")
return result
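# Stop if the maximum number of AdaNet iterations has been reached, even
# if fewer than the requested `steps` global steps have run (in which case
# a warning is logged below).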
max_iterations = self._max_iterations
if max_iterations and current_iteration + 1 >= max_iterations:
logging.info(
"Training ended after reaching the maximum number of AdaNet iterations")
if steps is not None and global_steps - latest_global_steps < steps:
logging.warning(
"Both `max_iterations` and `steps` were specified, but "
"`max_iterations` takes precedence over `steps`")
return result
logging.info("Beginning bookkeeping phase for iteration %s",
current_iteration)
# The chief prepares the next AdaNet iteration, and increments the
# iteration number by 1.
if self.config.is_chief:
with self._force_replication_strategy():
self._execute_bookkeeping_phase(
input_fn,
current_iteration,
train_hooks=hooks or [],
checkpoint_path=latest_checkpoint)
# This inner loop serves mainly for synchronizing the workers with the
# chief during distributed training. Workers that finish training early
# wait for the chief to prepare the next iteration and increment the
# iteration number. Workers that are slow to finish training quickly
# move on to the next iteration. And workers that go offline and come
# back online after training has ended terminate gracefully.
wait_for_chief = not self.config.is_chief
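# Bound how long a worker waits for the chief before giving up and exiting.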
timer = _CountDownTimer(self._worker_wait_timeout_secs)
while wait_for_chief:
# Fetch the latest checkpoint.
latest_checkpoint = tf.train.latest_checkpoint(self.model_dir)
# If the chief hits max_steps, it will stop training itself and not
# increment the iteration number, so this is how the worker knows to
# exit if it wakes up and the chief is gone.
# TODO: Support steps parameter.
if max_steps is not None and self._checkpoint_global_step(latest_checkpoint) >= max_steps:
return result
# In distributed training, a worker may end training before the chief
# overwrites the checkpoint with the incremented iteration number. If
# that is the case, it should wait for the chief to do so. Otherwise
# the worker will get stuck waiting for its weights to be initialized.
next_iteration = self._checkpoint_iteration_number(latest_checkpoint)
if next_iteration > current_iteration:
break
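# The chief has not advanced the iteration number yet, so log progress and
# keep waiting.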
logging.info("Iteration number in latest checkpoint: %d",
next_iteration)
# Check timeout when waiting for potentially downed chief.
if timer.secs_remaining() == 0:
logging.error(
"Chief job did not prepare iteration %d after %s secs. It "
"may have been preempted, been turned down, or crashed. This "
"worker is now exiting training.", current_iteration + 1,
self._worker_wait_timeout_secs)
return result
logging.info("Waiting for chief to prepare iteration %d",
current_iteration + 1)
time.sleep(self._worker_wait_secs)
# Stagger starting workers to prevent training instability.
# Mimics behavior of tf.estimator.train_and_evaluate.
if not self.config.is_chief and self.config.task_type == "worker":
task_id = self.config.task_id or 0
# Stagger each worker by up to `self._max_worker_delay_secs` seconds.
delay_secs = min(self._max_worker_delay_secs,
(task_id + 1.) * self._delay_secs_per_worker)
if delay_secs > 0.:
logging.info("Waiting %d secs before continuing training.",
delay_secs)
time.sleep(delay_secs)
logging.info("Finished bookkeeping phase for iteration %s",
current_iteration)
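# Loop back to the top of the `while True` loop to train the next AdaNet
# iteration from the freshly prepared checkpoint.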