def apply_saved_model()

in tensorflow_transform/pretrained_models.py [0:0]


def apply_saved_model(model_dir, inputs, tags, signature_name=None,
                      output_keys_in_signature=None):
  """Applies a SavedModel to some `Tensor`s.

  Applies a SavedModel to `inputs`. The SavedModel is specified with
  `model_dir`, `tags` and `signature_name`. Note that the SavedModel will be
  converted to an all-constants graph.

  Note: This API can only be used when TF2 is disabled or
  `tft_beam.Context.force_tf_compat_v1=True`.

  Args:
    model_dir: A path containing a SavedModel.
    inputs: A dict whose keys are the names from the input signature and whose
        values are `Tensor`s. If there is only one input in the model's input
        signature then `inputs` can be a single `Tensor`.
    tags: The tags specifying which metagraph to load from the SavedModel.
    signature_name: Specify signature of the loaded model. The default value
        None can be used if there is only one signature in the MetaGraphDef.
    output_keys_in_signature: A list of strings which should be a subset of
        the outputs in the signature of the SavedModel. The returned `Tensor`s
        will correspond to specified output `Tensor`s, in the same order. The
        default value None can be used if there is only one output from
        signature.

  Returns:
    A `Tensor` or list of `Tensor`s representing the application of the
        SavedModel.

  Raises:
    ValueError: if
    `inputs` is invalid type, or
    `signature_name` is not a signature of the SavedModel, or
    `signature_name` is None but the SavedModel contains multiple signatures,
    or the SavedModel contains no signatures at all, or
    `inputs` do not match the signature inputs, or
    `output_keys_in_signature` is not a subset of the signature outputs.
  """
  # Load the model into a private graph. Only the frozen (all-constants)
  # GraphDef built from it is imported into the caller's default graph below.
  loaded_graph = tf.compat.v1.Graph()
  with loaded_graph.as_default():
    sess = tf.compat.v1.Session()

  # The session is needed only until the variables are frozen, but it must be
  # closed on every path, including the validation errors raised below.
  try:
    with loaded_graph.as_default():
      meta_graph = tf.compat.v1.saved_model.load(sess,
                                                 export_dir=model_dir,
                                                 tags=tags)
      # Collected from the loaded graph's collection (the default graph here).
      loaded_initializer_op_names = [
          op.name for op in tf.compat.v1.get_collection(
              tf.compat.v1.GraphKeys.TABLE_INITIALIZERS)
      ]

    signature_defs = meta_graph.signature_def
    if signature_name:
      # Check membership explicitly: indexing a protobuf message map with a
      # missing key inserts and returns an empty SignatureDef instead of
      # raising, which would later surface as a misleading input mismatch.
      if signature_name not in signature_defs:
        raise ValueError(
            'The SavedModel does not contain a signature named %r (available '
            'signatures: %r).' % (signature_name, list(signature_defs.keys())))
      signature = signature_defs[signature_name]
    elif len(signature_defs) > 1:
      raise ValueError(
          'The SavedModel contains multiple signatures (%r) but signature_name '
          'was not specified.' % (signature_defs.keys(),))
    elif not signature_defs:
      # Without this check, next(iter(...)) below would leak StopIteration.
      raise ValueError('The SavedModel does not contain any signatures.')
    else:
      signature = next(iter(signature_defs.values()))

    # Map tensor names in the loaded graph to the caller-provided tensors.
    if isinstance(inputs, dict):
      if set(signature.inputs.keys()) != set(inputs.keys()):
        raise ValueError(
            'The keys in `inputs` (%r) do not match inputs of the SavedModel '
            '(%r).' % (inputs.keys(), signature.inputs.keys()))
      input_name_to_tensor_map = {
          signature.inputs[key].name: inputs[key]
          for key in inputs.keys()}
    elif len(signature.inputs) != 1:
      raise ValueError(
          'The SavedModel does not have exactly one input (had inputs %r) but '
          '`inputs` was not a dict.' % (signature.inputs.keys(),))
    else:
      input_name_to_tensor_map = {
          next(iter(signature.inputs.values())).name: inputs
      }

    # Resolve which signature outputs to return, and whether the caller gets a
    # single tensor or a list.
    if output_keys_in_signature:
      if not set(output_keys_in_signature) <= set(signature.outputs.keys()):
        raise ValueError(
            'output_keys_in_signature (%r) is not a subset of outputs of the '
            'SavedModel (%r).'
            % (output_keys_in_signature, signature.outputs.keys()))

      output_tensor_names = [
          signature.outputs[key].name for key in output_keys_in_signature
      ]
      output_single_tensor = False
    elif len(signature.outputs) != 1:
      raise ValueError(
          'The SavedModel does not have exactly one output (had outputs %r) but '
          'output_keys_in_signature was not specified.'
          % (signature.outputs.keys(),))
    else:
      output_tensor_names = [next(iter(signature.outputs.values())).name]
      output_single_tensor = True

    # convert_variables_to_constants() takes op names, not tensor names. The
    # table initializer ops are listed as outputs too, so they are kept in the
    # frozen graph.
    output_op_names = [loaded_graph.get_tensor_by_name(tensor_name).op.name
                       for tensor_name in output_tensor_names]
    constant_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
        sess, loaded_graph.as_graph_def(),
        output_op_names + loaded_initializer_op_names)
  finally:
    sess.close()

  # Import the frozen graph into the current default graph, feeding the
  # signature inputs from `inputs`. The outputs come first in the returned
  # list, followed by the initializer ops, in request order.
  returned_elements = tf.import_graph_def(
      constant_graph_def,
      input_map=input_name_to_tensor_map,
      return_elements=output_tensor_names + loaded_initializer_op_names)
  returned_output_tensors = returned_elements[:len(output_tensor_names)]
  returned_initializer_ops = returned_elements[len(output_tensor_names):]

  # Register the imported initializer ops in the default graph's
  # TABLE_INITIALIZERS collection, mirroring where they came from.
  for initializer_op in returned_initializer_ops:
    tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.TABLE_INITIALIZERS,
                                   initializer_op)

  if output_single_tensor:
    assert len(output_tensor_names) == 1
    return returned_output_tensors[0]
  else:
    return returned_output_tensors