in tensorflow_estimator/python/estimator/estimator.py [0:0]
def _add_meta_graph_for_mode(self,
                             builder,
                             input_receiver_fn_map,
                             checkpoint_path,
                             save_variables=True,
                             mode=ModeKeys.PREDICT,
                             export_tags=None,
                             check_variables=True,
                             strip_default_attrs=True):
  """Loads variables and adds them along with a `tf.MetaGraphDef` for saving.

  Args:
    builder: instance of `tf.saved_model.builder.SavedModelBuilder` that will
      be used for saving.
    input_receiver_fn_map: dict of `tf.estimator.ModeKeys` to
      `input_receiver_fn` mappings, where the `input_receiver_fn` is a
      function that takes no argument and returns the appropriate subclass of
      `InputReceiver`.
    checkpoint_path: The checkpoint path to export.
    save_variables: bool, whether variables should be saved. If `False`, just
      the `tf.MetaGraphDef` will be saved. Note that `save_variables` should
      only be `True` for the first call to this function, and the
      `SavedModelBuilder` will raise an error if that is not the case.
    mode: `tf.estimator.ModeKeys` value indicating which mode will be
      exported.
    export_tags: The set of tags with which to save `tf.MetaGraphDef`. If
      `None`, a default set will be selected to match the passed mode.
    check_variables: bool, whether to check that the checkpoint has all
      variables.
    strip_default_attrs: bool, whether to strip default attributes. This may
      only be True when called from the deprecated V1
      Estimator.export_savedmodel.

  Raises:
    ValueError: if `save_variables` is `True` and `check_variables` is
      `False`, or if variables requested by the model_fn are missing from
      the checkpoint.
  """
  if export_tags is None:
    export_tags = export_lib.EXPORT_TAG_MAP[mode]
  input_receiver_fn = input_receiver_fn_map[mode]

  with tf.Graph().as_default() as g:
    self._create_and_assert_global_step(g)
    tf.compat.v1.random.set_random_seed(self._config.tf_random_seed)

    input_receiver = input_receiver_fn()

    # Call the model_fn and collect the export_outputs.
    estimator_spec = self._call_model_fn(
        features=input_receiver.features,
        labels=getattr(input_receiver, 'labels', None),
        mode=mode,
        config=self.config)

    export_outputs = export_lib.export_outputs_for_mode(
        mode=estimator_spec.mode,
        serving_export_outputs=estimator_spec.export_outputs,
        predictions=estimator_spec.predictions,
        loss=estimator_spec.loss,
        metrics=estimator_spec.eval_metric_ops)

    # Build the SignatureDefs from receivers and all outputs. Only the
    # PREDICT mode produces serving-only signatures.
    signature_def_map = export_lib.build_all_signature_defs(
        input_receiver.receiver_tensors,
        export_outputs,
        getattr(input_receiver, 'receiver_tensors_alternatives', None),
        serving_only=(mode == ModeKeys.PREDICT))

    with tf.compat.v1.Session(config=self._session_config) as session:

      if estimator_spec.scaffold.local_init_op is not None:
        local_init_op = estimator_spec.scaffold.local_init_op
      else:
        local_init_op = tf.compat.v1.train.Scaffold.default_local_init_op()

      # This saver will be used both for restoring variables now,
      # and in saving out the metagraph below. This ensures that any
      # Custom Savers stored with the Scaffold are passed through to the
      # SavedModel for restore later.
      if isinstance(estimator_spec.scaffold.saver, trackable_util.Checkpoint):
        graph_saver = tf.compat.v1.train.Saver(
            var_list=graph_view.ObjectGraphView(
                estimator_spec.scaffold.saver).frozen_saveable_objects(),
            sharded=True)
      else:
        graph_saver = (
            estimator_spec.scaffold.saver or
            tf.compat.v1.train.Saver(sharded=True))

      if save_variables and not check_variables:
        raise ValueError('If `save_variables` is `True`, `check_variables` '
                         'must not be `False`.')
      if check_variables:
        try:
          graph_saver.restore(session, checkpoint_path)
        except tf.errors.NotFoundError as e:
          msg = ('Could not load all requested variables from checkpoint. '
                 'Please make sure your model_fn does not expect variables '
                 'that were not saved in the checkpoint.\n\n'
                 'Encountered error with mode `{}` while restoring '
                 'checkpoint from: `{}`. Full Traceback:\n\n{}').format(
                     mode, checkpoint_path, e)
          raise ValueError(msg)

      # We add the train op explicitly for now, so that we don't have to
      # change the Builder public interface. Note that this is a no-op
      # for prediction, where train_op is None.
      builder._add_train_op(estimator_spec.train_op)  # pylint: disable=protected-access

      meta_graph_kwargs = dict(
          tags=export_tags,
          signature_def_map=signature_def_map,
          assets_collection=tf.compat.v1.get_collection(
              tf.compat.v1.GraphKeys.ASSET_FILEPATHS),
          main_op=local_init_op,
          saver=graph_saver,
          strip_default_attrs=strip_default_attrs)

      if save_variables:
        builder.add_meta_graph_and_variables(session, **meta_graph_kwargs)
      else:
        builder.add_meta_graph(**meta_graph_kwargs)