def _augment_model_fn()

in tensorflow_estimator/python/estimator/tpu/tpu_estimator.py
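
Below is a minimal usage sketch of the public TPUEstimator API whose
user-supplied `model_fn` this method wraps (train mode only, for brevity).
Names such as `my_model_fn` are placeholders and cluster/master configuration
is omitted; this is an illustration under the standard TF 1.x API, not part
of the source file.

    import tensorflow.compat.v1 as tf

    def my_model_fn(features, labels, mode, params):
      # A trivial linear model; TPUEstimator injects params['batch_size'].
      predictions = tf.squeeze(tf.layers.dense(features['x'], 1), axis=1)
      loss = tf.losses.mean_squared_error(labels, predictions)
      optimizer = tf.tpu.CrossShardOptimizer(
          tf.train.GradientDescentOptimizer(0.01))
      train_op = optimizer.minimize(
          loss, global_step=tf.train.get_global_step())
      return tf.estimator.tpu.TPUEstimatorSpec(
          mode=mode, loss=loss, train_op=train_op)

    estimator = tf.estimator.tpu.TPUEstimator(
        model_fn=my_model_fn,
        config=tf.estimator.tpu.RunConfig(
            model_dir='/tmp/model',
            tpu_config=tf.estimator.tpu.TPUConfig(iterations_per_loop=100)),
        train_batch_size=64,
        use_tpu=True)

The wrapped `model_fn` returned by `_augment_model_fn` is what the base
Estimator machinery then calls for train, evaluate, predict, and export.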


  def _augment_model_fn(self, model_fn, batch_axis):
    """Returns a new model_fn, which wraps the TPU support."""

    def _model_fn(features, labels, mode, config, params):
      """A Estimator `model_fn` for TPUEstimator."""

      # `input_fn` is called in `train()`, `evaluate()`, and `predict()`,
      # but not in `export_saved_model()`.
      is_export_mode = not self._is_input_fn_invoked

      # Clear the bit.
      self._is_input_fn_invoked = None

      if is_export_mode:
        if mode == _INFERENCE_ON_TPU_MODE:
          _add_item_to_params(params, _USE_TPU_KEY, True)
          mode = model_fn_lib.ModeKeys.PREDICT
        else:
          _add_item_to_params(params, _USE_TPU_KEY, False)

      with self._ctx.with_mode(mode) as ctx:
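        # _ModelFnWrapper adapts the user `model_fn` so it can be invoked
        # either directly on CPU/GPU (`call_without_tpu`) or inside the TPU
        # loops built by `_train_on_tpu_system` and its eval/predict
        # counterparts below.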
        model_fn_wrapper = _ModelFnWrapper(model_fn, config, params, ctx)

        # examples_hook is added to training_hooks for both CPU and TPU
        # execution.
        if (self._log_every_n_steps is not None or
            self._log_every_n_secs is not None):
          examples_hook = ExamplesPerSecondHook(
              ctx.global_batch_size,
              # pylint:disable=g-long-ternary
              output_dir=(self.model_dir
                          if not config or config.save_summary_steps else None),
              # pylint:enable=g-long-ternary
              every_n_steps=self._log_every_n_steps,
              every_n_secs=self._log_every_n_secs)

        if ctx.is_running_on_cpu(is_export_mode=is_export_mode):
          tf.compat.v1.logging.info('Running %s on CPU/GPU', mode)
          estimator_spec = model_fn_wrapper.call_without_tpu(
              features, labels, is_export_mode=is_export_mode)
          if (self._log_every_n_steps is not None or
              self._log_every_n_secs is not None):
            estimator_spec = estimator_spec._replace(
                training_hooks=estimator_spec.training_hooks + (examples_hook,))
          return estimator_spec

        assert labels is None, '`labels` passed to `model_fn` must be `None`.'
        # TPUEstimator._call_input_fn passes `input_fn` as features to here.
        assert callable(features), '`input_fn` is not callable.'
        input_fn = features

        tpu_init_ops = []
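        # Ops collected in `tpu_init_ops` are handed to the infeed/outfeed
        # session hook below (see `tpu_init_ops=`), which runs them once when
        # it initializes the TPU system.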
        if ctx.embedding_config and mode == model_fn_lib.ModeKeys.TRAIN:
          dummy_table_variables, dummy_table_variables_init = (
              tpu_embedding_gradient.create_dummy_table_variables(
                  ctx.embedding_config.tpu_embedding))
          ctx.embedding_config.dummy_table_variables = dummy_table_variables
          tpu_init_ops.append(dummy_table_variables_init)

        input_holders = _InputPipeline(input_fn, batch_axis, ctx)
        enqueue_ops, dequeue_fn, input_hooks, run_infeed_loop_on_coordinator = (
            input_holders.generate_infeed_enqueue_ops_and_dequeue_fn())

        graph = tf.compat.v1.get_default_graph()
        for enqueue_op in enqueue_ops:
          if isinstance(enqueue_op, list):
            graph.get_collection_ref(_TPU_ENQUEUE_OPS).extend(enqueue_op)
          else:
            graph.add_to_collection(_TPU_ENQUEUE_OPS, enqueue_op)

        if mode == model_fn_lib.ModeKeys.TRAIN:
          compile_op, loss, host_call, scaffold_fn, training_hooks = (
              _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn))
          has_saver_hook = training_hooks and any(
              isinstance(hook, tf.compat.v1.train.CheckpointSaverHook)
              for hook in training_hooks)
          if ctx.embedding_config:
            g = tf.compat.v1.get_default_graph()
            table_to_config_dict = (
                ctx.embedding_config.tpu_embedding.table_to_config_dict)
            optimization_parameters = (
                ctx.embedding_config.tpu_embedding.optimization_parameters)
            if self._embedding_from_feature_columns:
              embedding_variable_name_by_table, slot_variable_names_by_table = (
                  _tpu_estimator_embedding.get_full_variable_names(
                      g, table_to_config_dict, optimization_parameters))
            else:
              embedding_variable_name_by_table = None
              slot_variable_names_by_table = None
            embedding_variables_and_ops = (
                ctx.embedding_config.tpu_embedding.create_variables_and_ops(
                    embedding_variable_name_by_table,
                    slot_variable_names_by_table))
            tpu_init_ops.extend(embedding_variables_and_ops.load_ops())
          # scaffold_fn must be called after the variables for TPU embedding
          # have been created on the CPU, as the user might reinitialize them
          # from a checkpoint within scaffold_fn.
          scaffold = _get_scaffold(scaffold_fn)

          host_ops = host_call.create_tpu_hostcall()

          shutdown_hooks = []
          shutdown_mode = os.environ.get('TF_TPU_GRACEFUL_SHUTDOWN_MODE',
                                         'reset_computation')
          if shutdown_mode:
            if shutdown_mode == 'shutdown_worker':
              finalizer_hooks = [
                  session_support.ShutdownLameWorkers(),
              ]
            elif shutdown_mode == 'shutdown_all_workers':
              finalizer_hooks = [
                  session_support.ShutdownAllWorkers(),
              ]
            elif shutdown_mode == 'reset_computation':
              finalizer_hooks = [
                  session_support.ResetComputation(),
              ]
            else:
              raise ValueError('Unknown TF_TPU_GRACEFUL_SHUTDOWN_MODE "%s"' %
                               shutdown_mode)

            if finalizer_hooks:
              if has_saver_hook:
                saver = _NotSaver(
                    'No save on shutdown when there are user-defined '
                    'CheckpointSaverHooks')
              else:
                saver = None  # Yes automatic save on shutdown.
              shutdown_hooks.append(
                  session_support.GracefulShutdownHook(
                      checkpoint_prefix=self.model_dir + '/model.ckpt',
                      on_shutdown_hooks=finalizer_hooks,
                      saver=saver))

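          # The control dependency on `loss` forces the global step to be
          # read only after the TPU training loop has produced the loss.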
          with tf.control_dependencies([loss]):
            global_step = tf.identity(tf.compat.v1.train.get_global_step())
          hooks = input_hooks + shutdown_hooks

          if ctx.feed_hook is not None:
            tf.compat.v1.logging.info(
                'Use user implemented tpu infeed outfeed session hook class.')
            infeed_outfeed_session_hook_class = ctx.feed_hook
          else:
            infeed_outfeed_session_hook_class = TPUInfeedOutfeedSessionHook

          hooks.extend([
              infeed_outfeed_session_hook_class(
                  ctx,
                  enqueue_ops,
                  host_ops,
                  tpu_compile_op=compile_op,
                  run_infeed_loop_on_coordinator=(
                      run_infeed_loop_on_coordinator),
                  rendezvous=self._rendezvous[mode],
                  master=self._config.master,
                  session_config=self._session_config,
                  tpu_init_ops=tpu_init_ops,
                  outfeed_every_n_steps=self._config.tpu_config
                  .experimental_host_call_every_n_steps),
              InstallSignalHandlerHook()
          ])
          if _check_add_preemption_hook(self._config.cluster):
            hooks.extend(
                [preempted_hook.CloudTPUPreemptedHook(self._config.cluster)])
          if (self._log_every_n_steps is not None or
              self._log_every_n_secs is not None):
            if self._iterations_per_training_loop.unit == 'count':
              examples_hook._set_steps_per_run(  # pylint: disable=protected-access
                  self._iterations_per_training_loop.value)
            hooks.append(
                tf.compat.v1.train.LoggingTensorHook(
                    {
                        'loss': tf.identity(loss),
                        'step': global_step,
                    },
                    every_n_iter=self._log_every_n_steps,
                    every_n_secs=self._log_every_n_secs))
            hooks.append(examples_hook)

          if training_hooks:
            hooks.extend(training_hooks)

          chief_hooks = []
          if (not has_saver_hook and
              (self._config.save_checkpoints_secs or
               self._config.save_checkpoints_steps)):
            checkpoint_hook = tf.compat.v1.train.CheckpointSaverHook(
                self.model_dir,
                save_secs=self._config.save_checkpoints_secs,
                save_steps=self._config.save_checkpoints_steps,
                scaffold=scaffold,
                save_graph_def=self._config.checkpoint_save_graph_def)
            if self._iterations_per_training_loop.unit == 'count':
              checkpoint_hook._set_steps_per_run(  # pylint: disable=protected-access
                  self._iterations_per_training_loop.value)
            chief_hooks.append(checkpoint_hook)
          else:
            tf.compat.v1.logging.info('Bypassing TPUEstimator hook')

          tf.compat.v1.summary.scalar(model_fn_lib.LOSS_METRIC_KEY, loss)
          with tf.control_dependencies([loss]):
            update_ops = _sync_variables_ops(ctx)
            if ctx.embedding_config:
              update_ops.extend(embedding_variables_and_ops.retrieve_ops())

          # Validate the TPU training graph to catch basic errors
          _validate_tpu_training_graph(ctx)

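          # `update_ops` depend on `loss` (see the control_dependencies block
          # above), so grouping them as the train op also runs the TPU
          # training loop.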
          train_op = tf.group(*update_ops)
          graph.add_to_collection(_TPU_TRAIN_OP, train_op)

          return model_fn_lib.EstimatorSpec(
              mode,
              loss=loss,
              training_chief_hooks=chief_hooks,
              training_hooks=hooks,
              train_op=train_op,
              scaffold=scaffold)

        if mode == model_fn_lib.ModeKeys.EVAL:
          compile_op, total_loss, host_calls, scaffold_fn, eval_hooks = (
              _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn))
          if ctx.embedding_config:
            g = tf.compat.v1.get_default_graph()
            table_to_config_dict = (
                ctx.embedding_config.tpu_embedding.table_to_config_dict)
            if self._embedding_from_feature_columns:
              embedding_variable_name_by_table, _ = (
                  _tpu_estimator_embedding.get_full_variable_names(
                      g, table_to_config_dict))
            else:
              embedding_variable_name_by_table = None
            embedding_variables_and_ops = (
                ctx.embedding_config.tpu_embedding.create_variables_and_ops(
                    embedding_variable_name_by_table))
            tpu_init_ops.extend(embedding_variables_and_ops.load_ops())
          # scaffold_fn must be called after the variables for TPU embedding
          # have been created on the CPU, as the user might reinitialize them
          # from a checkpoint within scaffold_fn.
          scaffold = _get_scaffold(scaffold_fn)
          iterations_per_loop_var = _create_or_get_iterations_per_loop()
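          # The TPU eval loop accumulates `total_loss` over
          # `iterations_per_loop` steps; dividing recovers the mean loss.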
          mean_loss = tf.compat.v1.div(
              total_loss,
              tf.cast(iterations_per_loop_var, dtype=total_loss.dtype))

          with tf.control_dependencies([mean_loss]):
            # After the TPU evaluation computation (the mean_loss tensor) is
            # done, read all variables back from the TPU and update the eval
            # step counter.
            internal_ops_to_run = _sync_variables_ops(ctx)
            internal_ops_to_run.append(
                _increase_eval_step_op(iterations_per_loop_var))

          host_call_ret = host_calls.create_tpu_hostcall()
          eval_metric_ops = {}
          eval_update_ops = []

          eval_metrics = host_call_ret.get('eval_metrics', {})
          if eval_metrics:
            # Creates a dummy metric update_op for all metrics. Estimator
            # expects every metric in `eval_metric_ops` to have an update_op
            # and calls them one by one. The real metric update_ops are
            # invoked in a separate thread, so here we give Estimator a dummy
            # op for all metrics.
            with tf.control_dependencies(internal_ops_to_run):
              dummy_update_op = tf.no_op()

            for k, v in eval_metrics.items():
              eval_metric_ops[k] = (v[0], dummy_update_op)
              eval_update_ops.append(v[1])
          else:
            # If no eval metrics are passed, create an identity node for the
            # loss and add `internal_ops_to_run` to its dependencies so that
            # `internal_ops_to_run` gets executed.
            with tf.control_dependencies(internal_ops_to_run):
              mean_loss = tf.identity(mean_loss)

          host_ops = host_call_ret.get('host_call', [])
          hooks = [
              TPUInfeedOutfeedSessionHook(
                  ctx,
                  enqueue_ops,
                  eval_update_ops + host_ops,
                  tpu_compile_op=compile_op,
                  run_infeed_loop_on_coordinator=(
                      run_infeed_loop_on_coordinator),
                  rendezvous=self._rendezvous[mode],
                  master=self._config.evaluation_master,
                  session_config=self._session_config,
                  tpu_init_ops=tpu_init_ops)
          ] + input_hooks

          if _check_add_preemption_hook(self._config.cluster):
            hooks.extend(
                [preempted_hook.CloudTPUPreemptedHook(self._config.cluster)])

          if eval_hooks:
            hooks.extend(eval_hooks)

          return model_fn_lib.EstimatorSpec(
              mode,
              loss=mean_loss,
              evaluation_hooks=hooks,
              eval_metric_ops=eval_metric_ops,
              scaffold=scaffold)

        # Predict
        assert mode == model_fn_lib.ModeKeys.PREDICT

        (compile_op, dummy_predict_op, host_calls, scaffold_fn,
         prediction_hooks) = _predict_on_tpu_system(ctx, model_fn_wrapper,
                                                    dequeue_fn)
        scaffold = _get_scaffold(scaffold_fn)
        with tf.control_dependencies([dummy_predict_op]):
          internal_ops_to_run = _sync_variables_ops(ctx)
          with tf.control_dependencies(internal_ops_to_run):
            dummy_predict_op = tf.no_op()

        # In training and evaluation, the main TPU program is passed to the
        # monitored training session to run. Infeed enqueue and outfeed
        # dequeue are executed in side threads. This is not the configuration
        # for prediction mode.
        #
        # For prediction, the Estimator executes EstimatorSpec.predictions
        # directly and yields each element (via a generator) to the call
        # site. So, the outfeed-based prediction must be passed to
        # MonitoredSession directly. Other parts of the TPU execution are
        # organized as follows.
        #
        # 1. All outfeed-based Tensors must be grouped with the predictions
        #    Tensors to form a single invocation. This avoids the issue where
        #    we might trigger multiple outfeeds incorrectly. To achieve this,
        #    `host_call` is placed in the control_dependencies of
        #    `stopping_signals`, and `stopping_signals` is passed into
        #    _StoppingPredictHook, which sets `stopping_signals` as
        #    SessionRunArgs. MonitoredSession merges all SessionRunArgs with
        #    the fetch in session.run together.
        #
        # 2. The TPU program (dummy_predict_op) and enqueue_ops (infeed
        #    enqueue) are grouped together. They will be launched once and
        #    only once in side threads and they quit naturally according to
        #    the SAME stopping condition.
        enqueue_ops.append(dummy_predict_op)

        host_call_ret = host_calls.create_tpu_hostcall()
        host_ops = host_call_ret.get('host_call', [])

        predictions = host_call_ret['predictions']
        _verify_cross_hosts_transfer_size(
            predictions,
            message=(
                'The estimated size for TPUEstimatorSpec.predictions is too '
                'large.'))
        signals = host_call_ret['signals']

        with tf.control_dependencies(host_ops):
          host_ops = []  # Empty; we do not need it anymore.
          scalar_stopping_signal = _StopSignals.as_scalar_stopping_signal(
              signals)
          predictions = _PaddingSignals.slice_tensor_or_dict(
              predictions, signals)

        hooks = [
            _StoppingPredictHook(scalar_stopping_signal),
            TPUInfeedOutfeedSessionHookForPrediction(
                ctx,
                enqueue_ops,
                host_ops,
                rendezvous=self._rendezvous[mode],
                tpu_compile_op=compile_op,
                master=self._config.master,
                session_config=self._session_config),
        ] + input_hooks

        if prediction_hooks:
          hooks.extend(prediction_hooks)

        return model_fn_lib.EstimatorSpec(
            mode,
            prediction_hooks=hooks,
            predictions=predictions,
            scaffold=scaffold)

    return _model_fn
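
For orientation, the sketch below is a deliberately reduced analogue of the
wrapping pattern implemented above (hypothetical names, not the real
implementation): the returned function keeps the Estimator `model_fn`
signature and splices extra hooks into the user's EstimatorSpec, which is a
namedtuple and therefore supports the `_replace` call used in the CPU/GPU
branch above.

    def _augment(model_fn, extra_hooks):
      """Hypothetical reduced analogue of _augment_model_fn."""

      def _model_fn(features, labels, mode, config, params):
        spec = model_fn(features, labels, mode, config, params)
        if mode == 'train':  # model_fn_lib.ModeKeys.TRAIN in the real code
          # EstimatorSpec is a namedtuple; _replace returns a copy with the
          # training hooks extended.
          spec = spec._replace(
              training_hooks=tuple(spec.training_hooks) + tuple(extra_hooks))
        return spec

      return _model_fn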