in tensorflow_estimator/python/estimator/tpu/tpu_estimator.py [0:0]
def _augment_model_fn(self, model_fn, batch_axis):
  """Returns a new model_fn, which wraps the TPU support."""

  def _model_fn(features, labels, mode, config, params):
    """An Estimator `model_fn` for TPUEstimator."""

    # `input_fn` is called in `train()`, `evaluate()`, and `predict()`,
    # but not in `export_saved_model()`.
    if self._is_input_fn_invoked:
      is_export_mode = False
    else:
      is_export_mode = True

    # Clear the bit.
    self._is_input_fn_invoked = None

    if is_export_mode:
      if mode == _INFERENCE_ON_TPU_MODE:
        _add_item_to_params(params, _USE_TPU_KEY, True)
        mode = model_fn_lib.ModeKeys.PREDICT
      else:
        _add_item_to_params(params, _USE_TPU_KEY, False)

    with self._ctx.with_mode(mode) as ctx:
      model_fn_wrapper = _ModelFnWrapper(model_fn, config, params, ctx)

      # examples_hook is added to training_hooks for both CPU and TPU
      # execution.
      if (self._log_every_n_steps is not None or
          self._log_every_n_secs is not None):
        examples_hook = ExamplesPerSecondHook(
            ctx.global_batch_size,
            # pylint:disable=g-long-ternary
            output_dir=(self.model_dir
                        if not config or config.save_summary_steps else None),
            # pylint:enable=g-long-ternary
            every_n_steps=self._log_every_n_steps,
            every_n_secs=self._log_every_n_secs)
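      # Rough illustration (not in the original source): the hook derives
      # throughput from the global batch size, so with
      # ctx.global_batch_size == 1024 and a measured 2.5 global steps per
      # second it would report roughly 1024 * 2.5 = 2560 examples/sec.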
      if ctx.is_running_on_cpu(is_export_mode=is_export_mode):
        tf.compat.v1.logging.info('Running %s on CPU/GPU', mode)
        estimator_spec = model_fn_wrapper.call_without_tpu(
            features, labels, is_export_mode=is_export_mode)
        if (self._log_every_n_steps is not None or
            self._log_every_n_secs is not None):
          estimator_spec = estimator_spec._replace(
              training_hooks=estimator_spec.training_hooks + (examples_hook,))
        return estimator_spec

      assert labels is None, '`labels` passed to `model_fn` must be `None`.'
      # TPUEstimator._call_input_fn passes `input_fn` as `features` here.
      assert callable(features), '`input_fn` is not callable.'
      input_fn = features

      tpu_init_ops = []
      if ctx.embedding_config and mode == model_fn_lib.ModeKeys.TRAIN:
        dummy_table_variables, dummy_table_variables_init = (
            tpu_embedding_gradient.create_dummy_table_variables(
                ctx.embedding_config.tpu_embedding))
        ctx.embedding_config.dummy_table_variables = dummy_table_variables
        tpu_init_ops.append(dummy_table_variables_init)

      input_holders = _InputPipeline(input_fn, batch_axis, ctx)
      enqueue_ops, dequeue_fn, input_hooks, run_infeed_loop_on_coordinator = (
          input_holders.generate_infeed_enqueue_ops_and_dequeue_fn())

      graph = tf.compat.v1.get_default_graph()
      for enqueue_op in enqueue_ops:
        if isinstance(enqueue_op, list):
          graph.get_collection_ref(_TPU_ENQUEUE_OPS).extend(enqueue_op)
        else:
          graph.add_to_collection(_TPU_ENQUEUE_OPS, enqueue_op)

      if mode == model_fn_lib.ModeKeys.TRAIN:
        compile_op, loss, host_call, scaffold_fn, training_hooks = (
            _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn))
        has_saver_hook = training_hooks and any(
            isinstance(hook, tf.compat.v1.train.CheckpointSaverHook)
            for hook in training_hooks)
        if ctx.embedding_config:
          g = tf.compat.v1.get_default_graph()
          table_to_config_dict = (
              ctx.embedding_config.tpu_embedding.table_to_config_dict)
          optimization_parameters = (
              ctx.embedding_config.tpu_embedding.optimization_parameters)
          if self._embedding_from_feature_columns:
            embedding_variable_name_by_table, slot_variable_names_by_table = (
                _tpu_estimator_embedding.get_full_variable_names(
                    g, table_to_config_dict, optimization_parameters))
          else:
            embedding_variable_name_by_table = None
            slot_variable_names_by_table = None
          embedding_variables_and_ops = (
              ctx.embedding_config.tpu_embedding.create_variables_and_ops(
                  embedding_variable_name_by_table,
                  slot_variable_names_by_table))
          tpu_init_ops.extend(embedding_variables_and_ops.load_ops())

        # scaffold_fn must be called after the variables for TPU embedding
        # have been created on the CPU, as the user might reinitialize those
        # from a checkpoint within scaffold_fn.
        scaffold = _get_scaffold(scaffold_fn)
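        # Illustrative only, not part of the original file: a typical user
        # `scaffold_fn` might rebuild a Saver or init_fn that restores the
        # embedding tables created above, which is why `_get_scaffold` runs
        # only after those CPU-side variables exist. A hedged sketch, with
        # `my_checkpoint` as a hypothetical path:
        #
        #   def scaffold_fn():
        #     saver = tf.compat.v1.train.Saver(sharded=True)
        #     def init_fn(scaffold, session):
        #       saver.restore(session, 'my_checkpoint')
        #     return tf.compat.v1.train.Scaffold(init_fn=init_fn, saver=saver)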
        host_ops = host_call.create_tpu_hostcall()

        shutdown_hooks = []
        shutdown_mode = os.environ.get('TF_TPU_GRACEFUL_SHUTDOWN_MODE',
                                       'reset_computation')
        if shutdown_mode:
          if shutdown_mode == 'shutdown_worker':
            finalizer_hooks = [
                session_support.ShutdownLameWorkers(),
            ]
          elif shutdown_mode == 'shutdown_all_workers':
            finalizer_hooks = [
                session_support.ShutdownAllWorkers(),
            ]
          elif shutdown_mode == 'reset_computation':
            finalizer_hooks = [
                session_support.ResetComputation(),
            ]
          elif not shutdown_mode:
            finalizer_hooks = []
          else:
            raise ValueError('Unknown TF_TPU_GRACEFUL_SHUTDOWN_MODE "%s"' %
                             shutdown_mode)

          if finalizer_hooks:
            if has_saver_hook:
              saver = _NotSaver(
                  'No save on shutdown when there are user-defined '
                  'CheckpointSaverHooks')
            else:
              saver = None  # Yes automatic save on shutdown.
            shutdown_hooks.append(
                session_support.GracefulShutdownHook(
                    checkpoint_prefix=self.model_dir + '/model.ckpt',
                    on_shutdown_hooks=finalizer_hooks,
                    saver=saver))
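        # For reference (not in the original source): the graceful-shutdown
        # behavior above is selected purely through the environment, e.g.
        # launching the trainer with
        #
        #   TF_TPU_GRACEFUL_SHUTDOWN_MODE=shutdown_all_workers python trainer.py
        #
        # picks the ShutdownAllWorkers finalizer. Per the branches above, the
        # other accepted values are 'shutdown_worker', 'reset_computation'
        # (the default), and '' to disable the shutdown hooks entirely.
        # `trainer.py` is a hypothetical entry point.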
        with tf.control_dependencies([loss]):
          global_step = tf.identity(tf.compat.v1.train.get_global_step())
        hooks = input_hooks + shutdown_hooks

        if ctx.feed_hook is not None:
          tf.compat.v1.logging.info(
              'Use user implemented tpu infeed outfeed session hook class.')
          infeed_outfeed_session_hook_class = ctx.feed_hook
        else:
          infeed_outfeed_session_hook_class = TPUInfeedOutfeedSessionHook

        hooks.extend([
            infeed_outfeed_session_hook_class(
                ctx,
                enqueue_ops,
                host_ops,
                tpu_compile_op=compile_op,
                run_infeed_loop_on_coordinator=(
                    run_infeed_loop_on_coordinator),
                rendezvous=self._rendezvous[mode],
                master=self._config.master,
                session_config=self._session_config,
                tpu_init_ops=tpu_init_ops,
                outfeed_every_n_steps=self._config.tpu_config
                .experimental_host_call_every_n_steps),
            InstallSignalHandlerHook()
        ])

        if _check_add_preemption_hook(self._config.cluster):
          hooks.extend(
              [preempted_hook.CloudTPUPreemptedHook(self._config.cluster)])

        if (self._log_every_n_steps is not None or
            self._log_every_n_secs is not None):
          if self._iterations_per_training_loop.unit == 'count':
            examples_hook._set_steps_per_run(  # pylint: disable=protected-access
                self._iterations_per_training_loop.value)
          hooks.append(
              tf.compat.v1.train.LoggingTensorHook(
                  {
                      'loss': tf.identity(loss),
                      'step': global_step,
                  },
                  every_n_iter=self._log_every_n_steps,
                  every_n_secs=self._log_every_n_secs))
          hooks.append(examples_hook)

        if training_hooks:
          hooks.extend(training_hooks)

        chief_hooks = []
        if (not has_saver_hook and
            (self._config.save_checkpoints_secs or
             self._config.save_checkpoints_steps)):
          checkpoint_hook = tf.compat.v1.train.CheckpointSaverHook(
              self.model_dir,
              save_secs=self._config.save_checkpoints_secs,
              save_steps=self._config.save_checkpoints_steps,
              scaffold=scaffold,
              save_graph_def=self._config.checkpoint_save_graph_def)
          if self._iterations_per_training_loop.unit == 'count':
            checkpoint_hook._set_steps_per_run(  # pylint: disable=protected-access
                self._iterations_per_training_loop.value)
          chief_hooks.append(checkpoint_hook)
        else:
          tf.compat.v1.logging.info('Bypassing TPUEstimator hook')

        tf.compat.v1.summary.scalar(model_fn_lib.LOSS_METRIC_KEY, loss)
        with tf.control_dependencies([loss]):
          update_ops = _sync_variables_ops(ctx)
          if ctx.embedding_config:
            update_ops.extend(embedding_variables_and_ops.retrieve_ops())

        # Validate the TPU training graph to catch basic errors.
        _validate_tpu_training_graph(ctx)

        train_op = tf.group(*update_ops)
        graph.add_to_collection(_TPU_TRAIN_OP, train_op)

        return model_fn_lib.EstimatorSpec(
            mode,
            loss=loss,
            training_chief_hooks=chief_hooks,
            training_hooks=hooks,
            train_op=train_op,
            scaffold=scaffold)

      if mode == model_fn_lib.ModeKeys.EVAL:
        compile_op, total_loss, host_calls, scaffold_fn, eval_hooks = (
            _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn))
        if ctx.embedding_config:
          g = tf.compat.v1.get_default_graph()
          table_to_config_dict = (
              ctx.embedding_config.tpu_embedding.table_to_config_dict)
          if self._embedding_from_feature_columns:
            embedding_variable_name_by_table, _ = (
                _tpu_estimator_embedding.get_full_variable_names(
                    g, table_to_config_dict))
          else:
            embedding_variable_name_by_table = None
          embedding_variables_and_ops = (
              ctx.embedding_config.tpu_embedding.create_variables_and_ops(
                  embedding_variable_name_by_table))
          tpu_init_ops.extend(embedding_variables_and_ops.load_ops())

        # scaffold_fn must be called after the variables for TPU embedding
        # have been created on the CPU, as the user might reinitialize those
        # from a checkpoint within scaffold_fn.
        scaffold = _get_scaffold(scaffold_fn)
        iterations_per_loop_var = _create_or_get_iterations_per_loop()
        mean_loss = tf.compat.v1.div(
            total_loss,
            tf.cast(iterations_per_loop_var, dtype=total_loss.dtype))

        with tf.control_dependencies([mean_loss]):
          # After the TPU evaluation computation is done (the mean_loss
          # tensor), read all variables back from the TPU and update the
          # eval step counter properly.
          internal_ops_to_run = _sync_variables_ops(ctx)
          internal_ops_to_run.append(
              _increase_eval_step_op(iterations_per_loop_var))

        host_call_ret = host_calls.create_tpu_hostcall()
        eval_metric_ops = {}
        eval_update_ops = []

        eval_metrics = host_call_ret.get('eval_metrics', {})
        if eval_metrics:
          # Create a dummy metric update_op for every metric: Estimator
          # expects each entry in `eval_metric_ops` to carry an update_op and
          # calls them one by one, but the real metric update_ops run in a
          # separate thread, so here we hand Estimator the dummy op instead.
          with tf.control_dependencies(internal_ops_to_run):
            dummy_update_op = tf.no_op()

          for k, v in eval_metrics.items():
            eval_metric_ops[k] = (v[0], dummy_update_op)
            eval_update_ops.append(v[1])
        else:
          # If no eval metrics are passed, create an identity node for the
          # loss and add `internal_ops_to_run` to its dependencies so that
          # `internal_ops_to_run` still gets executed.
          with tf.control_dependencies(internal_ops_to_run):
            mean_loss = tf.identity(mean_loss)
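        # Illustrative only (not in the original source): each entry of
        # `eval_metrics` produced by the host call has the standard Estimator
        # shape {metric_name: (value_tensor, update_op)}. A hedged sketch of
        # a user-side metric_fn that feeds it, with hypothetical names:
        #
        #   def metric_fn(labels, logits):
        #     predicted = tf.argmax(logits, axis=-1)
        #     return {
        #         'accuracy': tf.compat.v1.metrics.accuracy(labels, predicted),
        #     }
        #
        # tf.compat.v1.metrics.accuracy returns exactly such a
        # (value_tensor, update_op) pair, matching the v[0]/v[1] unpacking
        # above.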
        if 'host_call' not in host_call_ret:
          host_ops = []
        else:
          host_ops = host_call_ret['host_call']
        hooks = [
            TPUInfeedOutfeedSessionHook(
                ctx,
                enqueue_ops,
                eval_update_ops + host_ops,
                tpu_compile_op=compile_op,
                run_infeed_loop_on_coordinator=(
                    run_infeed_loop_on_coordinator),
                rendezvous=self._rendezvous[mode],
                master=self._config.evaluation_master,
                session_config=self._session_config,
                tpu_init_ops=tpu_init_ops)
        ] + input_hooks

        if _check_add_preemption_hook(self._config.cluster):
          hooks.extend(
              [preempted_hook.CloudTPUPreemptedHook(self._config.cluster)])

        if eval_hooks:
          hooks.extend(eval_hooks)

        return model_fn_lib.EstimatorSpec(
            mode,
            loss=mean_loss,
            evaluation_hooks=hooks,
            eval_metric_ops=eval_metric_ops,
            scaffold=scaffold)

      # Predict
      assert mode == model_fn_lib.ModeKeys.PREDICT

      (compile_op, dummy_predict_op, host_calls, scaffold_fn,
       prediction_hooks) = _predict_on_tpu_system(ctx, model_fn_wrapper,
                                                  dequeue_fn)
      scaffold = _get_scaffold(scaffold_fn)
      with tf.control_dependencies([dummy_predict_op]):
        internal_ops_to_run = _sync_variables_ops(ctx)
      with tf.control_dependencies(internal_ops_to_run):
        dummy_predict_op = tf.no_op()

      # In train and evaluation, the main TPU program is passed to the
      # monitored training session to run. Infeed enqueue and outfeed dequeue
      # are executed in side threads. That is not the configuration for
      # prediction mode.
      #
      # For prediction, the Estimator executes EstimatorSpec.predictions
      # directly and yields the elements (via a generator) to the call site.
      # So the outfeed-based prediction must be passed to MonitoredSession
      # directly. The rest of the TPU execution is organized as follows.
      #
      # 1. All outfeed-based Tensors must be grouped with the predictions
      #    Tensors to form a single invocation. This avoids incorrectly
      #    triggering multiple outfeeds. To achieve this, `host_call` is
      #    placed in the control_dependencies of `stopping_signals`, and
      #    `stopping_signals` is passed into _StoppingPredictHook, which sets
      #    `stopping_signals` as SessionRunArgs. MonitoredSession merges all
      #    SessionRunArgs with the fetches in session.run.
      #
      # 2. The TPU program (dummy_predict_op) and enqueue_ops (infeed enqueue)
      #    are grouped together. They are launched once and only once in side
      #    threads and quit naturally according to the SAME stopping
      #    condition.
      enqueue_ops.append(dummy_predict_op)

      host_call_ret = host_calls.create_tpu_hostcall()
      if 'host_call' not in host_call_ret:
        host_ops = []
      else:
        host_ops = host_call_ret['host_call']

      predictions = host_call_ret['predictions']
      _verify_cross_hosts_transfer_size(
          predictions,
          message=(
              'The estimated size for TPUEstimatorSpec.predictions is too '
              'large.'))
      signals = host_call_ret['signals']

      with tf.control_dependencies(host_ops):
        host_ops = []  # Empty; we do not need it anymore.
        scalar_stopping_signal = _StopSignals.as_scalar_stopping_signal(
            signals)
        predictions = _PaddingSignals.slice_tensor_or_dict(
            predictions, signals)

      hooks = [
          _StoppingPredictHook(scalar_stopping_signal),
          TPUInfeedOutfeedSessionHookForPrediction(
              ctx,
              enqueue_ops,
              host_ops,
              rendezvous=self._rendezvous[mode],
              tpu_compile_op=compile_op,
              master=self._config.master,
              session_config=self._session_config),
      ] + input_hooks

      if prediction_hooks:
        hooks.extend(prediction_hooks)

      return model_fn_lib.EstimatorSpec(
          mode,
          prediction_hooks=hooks,
          predictions=predictions,
          scaffold=scaffold)

  return _model_fn
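# Illustrative usage (not part of this function): the wrapped model_fn
# returned above is what TPUEstimator hands to the underlying Estimator, so
# callers keep using the ordinary Estimator API. Names such as `my_model_fn`,
# `my_input_fn`, `run_config`, and the batch size are hypothetical.
#
#   estimator = tf.compat.v1.estimator.tpu.TPUEstimator(
#       model_fn=my_model_fn,
#       config=run_config,
#       train_batch_size=1024,
#       use_tpu=True)
#   estimator.train(input_fn=my_input_fn, max_steps=10000)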