in tensorflow_transform/pretrained_models.py [0:0]
def apply_saved_model(model_dir, inputs, tags, signature_name=None,
                      output_keys_in_signature=None):
  """Applies a SavedModel to some `Tensor`s.

  Applies a SavedModel to `inputs`. The SavedModel is specified with
  `model_dir`, `tags` and `signature_name`. Note that the SavedModel will be
  converted to an all-constants graph.

  Note: This API can only be used when TF2 is disabled or
  `tft_beam.Context.force_tf_compat_v1=True`.

  The model is loaded into a private graph and session, its variables are
  frozen into constants, and the frozen graph is imported into the current
  default graph with the signature inputs rewired to `inputs`. Table
  initializer ops found in the SavedModel are imported too and added to the
  current default graph's `TABLE_INITIALIZERS` collection.

  Args:
    model_dir: A path containing a SavedModel.
    inputs: A dict whose keys are the names from the input signature and whose
      values are `Tensor`s. If there is only one input in the model's input
      signature then `inputs` can be a single `Tensor`.
    tags: The tags specifying which metagraph to load from the SavedModel.
    signature_name: Specify signature of the loaded model. The default value
      None can be used if there is only one signature in the MetaGraphDef.
    output_keys_in_signature: A list of strings which should be a subset of
      the outputs in the signature of the SavedModel. The returned `Tensor`s
      will correspond to specified output `Tensor`s, in the same order. The
      default value None can be used if there is only one output from
      signature.

  Returns:
    A `Tensor` or list of `Tensor`s representing the application of the
    SavedModel. A single `Tensor` is returned when `output_keys_in_signature`
    is not specified; otherwise a list ordered like
    `output_keys_in_signature` is returned.

  Raises:
    ValueError: if
      `signature_name` is given but is not a signature of the SavedModel, or
      `signature_name` is None but the SavedModel contains multiple
      signatures, or
      the SavedModel contains no signatures, or
      `inputs` is a dict whose keys do not match the signature inputs, or
      `inputs` is not a dict but the signature does not have exactly one
      input, or
      `output_keys_in_signature` is not a subset of the signature outputs, or
      `output_keys_in_signature` is not specified but the signature does not
      have exactly one output.
  """
  # Load the model into its own graph so that nothing leaks into the caller's
  # graph before the frozen copy is imported below.
  loaded_graph = tf.compat.v1.Graph()
  with loaded_graph.as_default():
    sess = tf.compat.v1.Session()

  # The session must be closed on every exit path, including the ValueErrors
  # raised while validating the signature, inputs and outputs below.
  try:
    with loaded_graph.as_default():
      meta_graph = tf.compat.v1.saved_model.load(
          sess, export_dir=model_dir, tags=tags)
      loaded_initializer_op_names = [
          op.name for op in tf.compat.v1.get_collection(
              tf.compat.v1.GraphKeys.TABLE_INITIALIZERS)
      ]

    # Select the signature to apply.
    signature_defs = meta_graph.signature_def
    if signature_name:
      # Check membership explicitly: indexing a protobuf message map with a
      # missing key inserts and returns an empty SignatureDef rather than
      # raising, which would surface later as a misleading input mismatch.
      if signature_name not in signature_defs:
        raise ValueError(
            'signature_name (%r) is not a signature of the SavedModel '
            '(available signatures: %r).'
            % (signature_name, list(signature_defs.keys())))
      signature = signature_defs[signature_name]
    elif len(signature_defs) > 1:
      raise ValueError(
          'The SavedModel contains multiple signatures (%r) but signature_name '
          'was not specified.' % (signature_defs.keys(),))
    elif not signature_defs:
      # Without this check next(iter(...)) below would raise StopIteration.
      raise ValueError('The SavedModel does not contain any signatures.')
    else:
      signature = next(iter(signature_defs.values()))

    # Map tensor names in the loaded graph to the caller's input tensors.
    if isinstance(inputs, dict):
      if set(signature.inputs.keys()) != set(inputs.keys()):
        raise ValueError(
            'The keys in `inputs` (%r) do not match inputs of the SavedModel '
            '(%r).' % (inputs.keys(), signature.inputs.keys()))
      input_name_to_tensor_map = {
          signature.inputs[key].name: inputs[key] for key in inputs.keys()
      }
    elif len(signature.inputs) != 1:
      raise ValueError(
          'The SavedModel does not have exactly one input (had inputs %r) but '
          '`inputs` was not a dict.' % (signature.inputs.keys(),))
    else:
      input_name_to_tensor_map = {
          next(iter(signature.inputs.values())).name: inputs
      }

    # Resolve the names of the output tensors to return.
    if output_keys_in_signature:
      if not set(output_keys_in_signature) <= set(signature.outputs.keys()):
        raise ValueError(
            'output_keys_in_signature (%r) is not a subset of outputs of the '
            'SavedModel (%r).'
            % (output_keys_in_signature, signature.outputs.keys()))
      output_tensor_names = [
          signature.outputs[key].name for key in output_keys_in_signature
      ]
      output_single_tensor = False
    elif len(signature.outputs) != 1:
      raise ValueError(
          'The SavedModel does not have exactly one output (had outputs %r) but '
          'output_keys_in_signature was not specified.'
          % (signature.outputs.keys(),))
    else:
      output_tensor_names = [next(iter(signature.outputs.values())).name]
      output_single_tensor = True

    # convert_variables_to_constants() takes op names, not tensor names. The
    # initializer ops are kept as extra roots so they survive graph pruning.
    output_op_names = [
        loaded_graph.get_tensor_by_name(tensor_name).op.name
        for tensor_name in output_tensor_names
    ]
    constant_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
        sess, loaded_graph.as_graph_def(),
        output_op_names + loaded_initializer_op_names)
  finally:
    sess.close()

  # Import the frozen graph into the caller's default graph. Outputs and
  # initializer ops are requested together and split apart by position.
  returned_elements = tf.import_graph_def(
      constant_graph_def,
      input_map=input_name_to_tensor_map,
      return_elements=output_tensor_names + loaded_initializer_op_names)
  num_outputs = len(output_tensor_names)
  returned_output_tensors = returned_elements[:num_outputs]
  returned_initializer_ops = returned_elements[num_outputs:]
  # Register the imported initializers in the caller's TABLE_INITIALIZERS
  # collection (presumably so the caller's table-init step runs them).
  for initializer_op in returned_initializer_ops:
    tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.TABLE_INITIALIZERS,
                                   initializer_op)

  if output_single_tensor:
    assert num_outputs == 1
    return returned_output_tensors[0]
  return returned_output_tensors