in twml/twml/trainers/trainer.py [0:0]
def train_and_evaluate(self, train_input_fn=None, eval_input_fn=None,
train_max_steps=None, eval_steps=None,
eval_delay=None, eval_period=None,
train_hooks=None, eval_hooks=None,
early_stop_metric=None, early_stop_patience=-1,
early_stop_minimize=True, early_stop_tolerance=0, exporters=None,
export_output_fn=None, max_duration=None):
"""
Train and evaluate the estimator for ``train_max_steps``
using ``tf.estimator.train_and_evaluate``.
With a cluster configuration provided in the ``TF_CONFIG`` environment variable, this method
can be used for distributed training (multi-node or multi-process).
Unlike the ``learn`` method, training is continuous with ``train_max_steps``.
For the distributed use case, evaluation happens periodically.
That is, after ``eval_delay`` seconds, an evaluation epoch of ``eval_steps`` steps
occurs every ``eval_period`` seconds. Evaluation happens on the most recent checkpoint.
TF defaults to saving checkpoints every 10 mins.
For the local use case, training runs for ``train_max_steps`` steps followed by a
single evaluation. For the local use case we therefore recommend using ``learn()`` instead,
as it provides early-stopping and multiple evaluations.
``train_and_evaluate`` will evaluate for ``eval_steps`` steps every ``eval_period`` seconds.
It will stop after ``train_max_steps`` is reached.
You must ensure that all workers/servers are assigned the same ``save_dir``.
.. Note::
If the ``TF_CONFIG`` environment variable is set, this function assumes it is running a distributed job.
Args:
train_input_fn:
Function to iterate through the training set. It is passed to estimator.train_and_evaluate.
eval_input_fn:
Function to iterate through the evaluation set. It is passed to estimator.train_and_evaluate.
train_max_steps:
Maximum number of global training steps to run.
Defaults to params.train_max_steps.
Non-positive or ``None`` values train indefinitely (use with caution).
eval_steps:
Number of steps per evaluation.
Defaults to params.eval_steps.
Non-positive or ``None`` values go through
the entire evaluation set for each evaluation.
Note that ``eval_steps`` should be high enough to minimize noise.
This is especially true for early-stopping.
eval_delay:
Start the first evaluation after ``eval_delay`` seconds. Defaults to params.eval_delay or 2*60s.
eval_period:
Run an evaluation every eval_period seconds. Defaults to params.eval_period or 10*60s.
exporters:
List of exporters called at the end of each evaluation run.
Defaults to ``None``.
export_output_fn:
The output format to use for exported models.
Only used if exporters is not None.
Early-stopping arguments:
early_stop_metric:
String specifying the metric to early-stop on. Required with positive
``early_stop_patience``. For example, 'accuracy', 'accuracy_0', 'loss', etc.
The string is used to extract the relevant tensor Op from the dict returned by
the get_eval_metric_ops method. For ``metrics`` passed to the constructor,
the string is one of those. For multi-class (that is, multi-metric)
metrics, the string may be appended with a ``_0``, ``_1``, etc. or one
of the ``multi_metric_names`` (one per class).
early_stop_patience:
Maximum number of epochs to wait for an improvement in the early_stop_metric
before stopping training. For example, a patience of 10 means that
training will have 10 epochs to improve the metric before it is stopped.
Whenever the metric is improved before running out of patience,
patience is reset to ``early_stop_patience``.
Defaults to -1 (that is, no early-stopping).
early_stop_minimize:
Set this to True (the default) for metrics that need to be minimized
(like ``loss``). Metrics like ``accuracy`` that need to be maximized
should set this to False.
early_stop_tolerance:
A non-negative tolerance for comparing early_stop_metric.
E.g., when maximizing, the condition is ``current_metric > best_metric + tolerance``.
Defaults to 0.
max_duration:
A float. When this argument is defined, the job will automatically terminate after
``max_duration`` seconds if it has not already completed.
Returns:
The directory where the checkpoints were saved.
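Example:
    A minimal usage sketch, assuming ``trainer`` is an already-constructed
    ``Trainer`` and ``train_input_fn``/``eval_input_fn`` are placeholder
    callables that provide the training and evaluation data::

        save_dir = trainer.train_and_evaluate(
            train_input_fn=train_input_fn,
            eval_input_fn=eval_input_fn,
            train_max_steps=100000,
            eval_steps=1000,
            eval_delay=2 * 60,
            eval_period=10 * 60,
            early_stop_metric='loss',
            early_stop_patience=5,
            early_stop_minimize=True)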
"""
logging.info("WARNING: Trainer.train_and_evaluate is an EXPERIMENTAL API.")
logging.info("Trainer.train_and_evaluate may change or be removed in future versions.")
if not callable(train_input_fn):
raise ValueError("Expecting callable train_input_fn function")
if not callable(eval_input_fn):
raise ValueError("Expecting callable eval_input_fn function")
self._exit_ps_after_training_complete()
# Maybe export in eval processes.
if self.is_evaluator():
if self.params.get("eval_name") is not None:
# Do not export if running special eval.
exporters = None
export_output_fn = None
elif exporters and export_output_fn:
self._export_output_fn = export_output_fn
else:
# Default option.
self._export_output_fn = None
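# Resolve hooks: prefer explicit arguments, then the Trainer's default hooks, then empty lists.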
train_hooks = self.get_train_hooks() if train_hooks is None else train_hooks
train_hooks = [] if train_hooks is None else train_hooks
eval_hooks = self.get_eval_hooks() if eval_hooks is None else eval_hooks
eval_hooks = [] if eval_hooks is None else eval_hooks
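# Fall back to params for any argument not passed explicitly.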
if train_max_steps is None:
train_max_steps = self.params.get('train_max_steps')
if eval_steps is None:
eval_steps = self.params.eval_steps
if eval_steps <= 0:
eval_steps = None
if eval_delay is None:
eval_delay = self.params.eval_delay
if eval_period is None:
eval_period = self.params.eval_period
if early_stop_patience > 0:
# when training hooks detect this file, they request a stop to training
early_stop_path = os.path.join(self._save_dir, 'earlystop_now.txt')
# prepare early stopping hook (which also handles logic here)
self._is_early_stopping = True
eval_early_stop_hook = twml.hooks.EarlyStopHook(
metric=early_stop_metric,
checkpoint_dir=self._save_dir,
patience=early_stop_patience,
minimize=early_stop_minimize,
tolerance=early_stop_tolerance,
get_estimator_spec_fn=lambda: self.current_estimator_spec,
file_path=early_stop_path,
exit_on_end=os.environ.get('TF_CONFIG') is not None) # only exit for distributed jobs
# add early stop hook to eval hooks
eval_hooks.append(eval_early_stop_hook)
# prepare the commensurate training hook
train_early_stop_hook = twml.hooks.StopIfExistsHook(early_stop_path)
train_hooks.append(train_early_stop_hook)
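# Optionally enforce a wall-clock limit on both training and evaluation.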
if max_duration is not None:
train_early_stop_duration_hook = twml.hooks.EarlyStopDuration(
max_duration=max_duration,
exit_on_end=False,
save_dir=self._save_dir,
overwrite=self.is_chief()
)
eval_early_stop_duration_hook = twml.hooks.EarlyStopDuration(
max_duration=max_duration,
exit_on_end=os.environ.get('TF_CONFIG') is not None,
save_dir=self._save_dir,
overwrite=False
) # only exit for distributed jobs
train_hooks.append(train_early_stop_duration_hook)
eval_hooks.append(eval_early_stop_duration_hook)
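# Build the train/eval specs and run tf.estimator.train_and_evaluate while tracking the experiment.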
with self.experiment_tracker.track_experiment(eval_hooks, lambda: self.current_estimator_spec):
train_spec = self.get_train_spec(train_input_fn, train_max_steps, train_hooks)
eval_spec = self.get_eval_spec(eval_input_fn, eval_steps,
eval_delay, eval_period,
eval_hooks, exporters)
self._train_and_evaluate(train_spec, eval_spec)
if self.is_chief():
self.write_state_to_disk(save_dir=self._save_dir, filename='_SUCCESS')
return self._save_dir