def _add_meta_graph_for_mode()

in tensorflow_estimator/python/estimator/estimator.py [0:0]


  def _add_meta_graph_for_mode(self,
                               builder,
                               input_receiver_fn_map,
                               checkpoint_path,
                               save_variables=True,
                               mode=ModeKeys.PREDICT,
                               export_tags=None,
                               check_variables=True,
                               strip_default_attrs=True):
    """Loads variables and adds them along with a `tf.MetaGraphDef` for saving.

    A fresh graph is built for `mode`: the mode's `input_receiver_fn` is
    called, the model_fn is run on the resulting features (and labels, when
    the receiver provides them), and signature defs are derived from the
    export outputs. Variables are then optionally restored from
    `checkpoint_path` and the meta graph is registered on `builder`.

    Args:
      builder: instance of `tf.saved_model.builder.SavedModelBuilder` that will
        be used for saving.
      input_receiver_fn_map: dict of `tf.estimator.ModeKeys` to
        `input_receiver_fn` mappings, where the `input_receiver_fn` is a
        function that takes no argument and returns the appropriate subclass of
        `InputReceiver`.
      checkpoint_path: The checkpoint path to export.
      save_variables: bool, whether variables should be saved. If `False`, just
        the `tf.MetaGraphDef` will be saved. Note that `save_variables` should
        only be `True` for the first call to this function, and the
        `SavedModelBuilder` will raise an error if that is not the case.
      mode: `tf.estimator.ModeKeys` value indicating which mode will be
        exported.
      export_tags: The set of tags with which to save `tf.MetaGraphDef`. If
        `None`, a default set will be selected to match the passed mode.
      check_variables: bool, whether to check the checkpoint has all variables.
        When `True`, the variables are restored from `checkpoint_path`.
      strip_default_attrs: bool, whether to strip default attributes. This may
        only be True when called from the deprecated V1
        Estimator.export_savedmodel.

    Raises:
      ValueError: if `save_variables` is `True` and `check_variables` is
        `False`, or if restoring from `checkpoint_path` fails with
        `tf.errors.NotFoundError`.
    """
    # Validate the flag combination up front so that an invalid call fails
    # before a graph is built, the model_fn is run and a session is opened.
    if save_variables and not check_variables:
      raise ValueError('If `save_variables` is `True`, `check_variables` '
                       'must not be `False`.')

    if export_tags is None:
      export_tags = export_lib.EXPORT_TAG_MAP[mode]
    input_receiver_fn = input_receiver_fn_map[mode]

    with tf.Graph().as_default() as g:
      self._create_and_assert_global_step(g)
      tf.compat.v1.random.set_random_seed(self._config.tf_random_seed)

      input_receiver = input_receiver_fn()

      # Call the model_fn and collect the export_outputs. Not every receiver
      # type carries labels, hence the `getattr` with a `None` fallback.
      estimator_spec = self._call_model_fn(
          features=input_receiver.features,
          labels=getattr(input_receiver, 'labels', None),
          mode=mode,
          config=self.config)

      export_outputs = export_lib.export_outputs_for_mode(
          mode=estimator_spec.mode,
          serving_export_outputs=estimator_spec.export_outputs,
          predictions=estimator_spec.predictions,
          loss=estimator_spec.loss,
          metrics=estimator_spec.eval_metric_ops)

      # Build the SignatureDefs from receivers and all outputs
      signature_def_map = export_lib.build_all_signature_defs(
          input_receiver.receiver_tensors,
          export_outputs,
          getattr(input_receiver, 'receiver_tensors_alternatives', None),
          serving_only=(mode == ModeKeys.PREDICT))

      with tf.compat.v1.Session(config=self._session_config) as session:
        scaffold = estimator_spec.scaffold

        # Prefer the Scaffold's own local init op; otherwise use the default.
        if scaffold.local_init_op is not None:
          local_init_op = scaffold.local_init_op
        else:
          local_init_op = tf.compat.v1.train.Scaffold.default_local_init_op()

        # This saver will be used both for restoring variables now,
        # and in saving out the metagraph below. This ensures that any
        # Custom Savers stored with the Scaffold are passed through to the
        # SavedModel for restore later.
        if isinstance(scaffold.saver, trackable_util.Checkpoint):
          # An object-based checkpoint is wrapped in a V1 Saver built from
          # its frozen saveable objects.
          graph_saver = tf.compat.v1.train.Saver(
              var_list=graph_view.ObjectGraphView(
                  scaffold.saver).frozen_saveable_objects(),
              sharded=True)
        else:
          graph_saver = (
              scaffold.saver or tf.compat.v1.train.Saver(sharded=True))

        if check_variables:
          try:
            graph_saver.restore(session, checkpoint_path)
          except tf.errors.NotFoundError as e:
            # Re-raise as ValueError with a message that names the mode and
            # the checkpoint being restored.
            msg = ('Could not load all requested variables from checkpoint. '
                   'Please make sure your model_fn does not expect variables '
                   'that were not saved in the checkpoint.\n\n'
                   'Encountered error with mode `{}` while restoring '
                   'checkpoint from: `{}`. Full Traceback:\n\n{}').format(
                       mode, checkpoint_path, e)
            raise ValueError(msg)

        # We add the train op explicitly for now, so that we don't have to
        # change the Builder public interface. Note that this is a no-op
        # for prediction, where train_op is None.
        builder._add_train_op(estimator_spec.train_op)  # pylint: disable=protected-access

        meta_graph_kwargs = dict(
            tags=export_tags,
            signature_def_map=signature_def_map,
            assets_collection=tf.compat.v1.get_collection(
                tf.compat.v1.GraphKeys.ASSET_FILEPATHS),
            main_op=local_init_op,
            saver=graph_saver,
            strip_default_attrs=strip_default_attrs)

        # Variables are written only when `save_variables` is set; otherwise
        # just the meta graph is added to the builder.
        if save_variables:
          builder.add_meta_graph_and_variables(session, **meta_graph_kwargs)
        else:
          builder.add_meta_graph(**meta_graph_kwargs)