in tensorflow_estimator/python/estimator/estimator.py
def _actual_train_model_distributed(self, strategy, input_fn, hooks,
                                    saving_listeners):
  """The method that does the actual training with a distribution strategy."""
  # TODO(sourabhbajaj): Remove this hack once we migrate the other strategies
  # to use the new API
  is_tpu_strategy = strategy.__class__.__name__.startswith('TPUStrategy')
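  # Hooks that must run on every worker (for example, hooks created by the
  # input pipeline) are accumulated here and passed to the training loop.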
  worker_hooks = []
  with tf.Graph().as_default() as g:
    # We want to create the iterations variable outside the distribution scope
    # as that is just stored on the host and mainly used to drive the loop
    # and doesn't need to be a Mirrored/Device variable.
    if is_tpu_strategy:
      steps_per_run_variable = training.get_or_create_steps_per_run_variable()
    # Set flag on the distribution strategy so that optimizer v1 is
    # distribution aware and scales the losses by number of replicas.
    # This is required only for backward compatibility with estimator and
    # V1 optimizer. TF2 will not do this scaling.
    if hasattr(strategy, '_scale_loss_for_estimator_enabled'):
      scale_ctx = strategy._scale_loss_for_estimator_enabled()  # pylint: disable=protected-access
    else:
      # TODO(psv): Remove this clause after estimator repo gets the
      # distribute library changes related to loss scaling.
      @tf_contextlib.contextmanager
      def nullcontextmanager():
        yield
      scale_ctx = nullcontextmanager()
    with strategy.scope(), scale_ctx:
      tf.compat.v1.random.set_random_seed(self._config.tf_random_seed)
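      # Let the strategy build the (possibly distributed) input pipeline from
      # input_fn; any SessionRunHooks it needs come back as input_hooks.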
      iterator, input_hooks = self._get_iterator_from_input_fn(
          input_fn, ModeKeys.TRAIN, strategy)
      worker_hooks.extend(input_hooks)
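      # The global step is created once per graph, outside the replica
      # context, and is shared by all replicas.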
      global_step_tensor = self._create_and_assert_global_step(g)
      # we want to add to the global collection in the main thread not the
      # replica threads.
      tf.compat.v1.add_to_collection(
          training_util.GLOBAL_STEP_READ_KEY,
          strategy.extended.read_var(global_step_tensor))
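      # TPUStrategy runs several steps per session.run call by driving the
      # iterator inside the loop below; other strategies run one step per call.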
      if is_tpu_strategy:
        # Create a step_fn from the train_op of grouped_estimator_spec
        def step_fn(ctx, inputs):
          """A single step passed to experimental_run_steps_on_iterator."""
          if isinstance(inputs, tuple):
            features, labels = inputs
          else:
            features = inputs
            labels = None
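          # Run the model_fn on every replica; the fields of the returned
          # EstimatorSpec are per-replica values at this point.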
          estimator_spec = strategy.extended.call_for_each_replica(
              self._call_model_fn,
              args=(features, labels, ModeKeys.TRAIN, self.config))
          ctx.set_last_step_output(
              name='loss',
              output=estimator_spec.loss,
              reduce_op=_get_loss_reduce_op_for_reporting())
          ctx.set_non_tensor_output(
              name='estimator_spec', output=estimator_spec)
          return estimator_spec.train_op
        # Create new train_op post graph rewrites
        initial_training_loss = tf.constant(1e7)
        ctx = strategy.extended.experimental_run_steps_on_iterator(
            step_fn,
            iterator,
            iterations=steps_per_run_variable,
            initial_loop_values={'loss': initial_training_loss})
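        # The returned context exposes the grouped run op plus the last-step
        # outputs and non-tensor outputs captured inside step_fn.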
        distributed_train_op = ctx.run_op
        loss = ctx.last_step_outputs['loss']
        grouped_estimator_spec = ctx.non_tensor_outputs['estimator_spec']
      else:
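        # Non-TPU path: pull one batch from the iterator and invoke the
        # model_fn once per replica for a single training step.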
        features, labels = estimator_util.parse_iterator_result(
            iterator.get_next())
        grouped_estimator_spec = strategy.extended.call_for_each_replica(
            self._call_model_fn,
            args=(
                features,
                labels,  # may be None if the input_fn returns no labels
                ModeKeys.TRAIN,
                self.config))
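        # Reduce the per-replica losses to a single scalar so the loss can be
        # logged and reported consistently.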
        loss = strategy.reduce(
            _get_loss_reduce_op_for_reporting(),
            grouped_estimator_spec.loss,
            axis=None)
        distributed_train_op = grouped_estimator_spec.train_op
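      # Merge the per-replica scaffolds into a single Scaffold that the
      # training session can use.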
      scaffold = _combine_distributed_scaffold(
          grouped_estimator_spec.scaffold, strategy)

      # TODO(yuefengz): add a test for unwrapping per_device_hooks.
      def get_hooks_from_the_first_device(per_device_hooks):
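        """Returns only the first device's copy of each per-device hook."""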
        return [
            self._train_distribution.experimental_local_results(
                per_device_hook)[0] for per_device_hook in per_device_hooks
        ]
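      # Each replica creates its own copy of the hooks; keep only the first
      # device's copy so hooks are not run multiple times per step.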
      training_hooks = get_hooks_from_the_first_device(
          grouped_estimator_spec.training_hooks)
      training_chief_hooks = get_hooks_from_the_first_device(
          grouped_estimator_spec.training_chief_hooks)
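      # Re-assemble a single, strategy-agnostic EstimatorSpec from the reduced
      # loss, grouped train_op, combined scaffold and unwrapped hooks.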
      estimator_spec = model_fn_lib.EstimatorSpec(
          mode=grouped_estimator_spec.mode,
          loss=loss,
          train_op=strategy.group(distributed_train_op),
          training_hooks=training_hooks,
          training_chief_hooks=training_chief_hooks,
          scaffold=scaffold)
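      # Hand off to the common training loop, which installs the hooks, builds
      # the monitored session and runs until the requested number of steps.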
      return self._train_with_estimator_spec(estimator_spec, worker_hooks,
                                             hooks, global_step_tensor,
                                             saving_listeners)
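# Usage sketch (assumption: standard tf.estimator API; not part of this file).
# This private method is reached when an Estimator is constructed with a
# RunConfig whose `train_distribute` is set; `my_model_fn` and `my_input_fn`
# below are hypothetical user-provided functions:
#
#   strategy = tf.distribute.MirroredStrategy()
#   config = tf.estimator.RunConfig(train_distribute=strategy)
#   classifier = tf.estimator.Estimator(model_fn=my_model_fn, config=config)
#   classifier.train(input_fn=my_input_fn, steps=1000)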