def _make_assets_key_collection()

in tensorflow_hub/saved_model_lib.py [0:0]


def _make_assets_key_collection(saved_model_proto, export_path):
  """Builds the ASSETS_KEY collection for every MetaGraph of a SavedModel.

  For each MetaGraph in `saved_model_proto` that carries an ASSET_FILEPATHS
  collection, the graph nodes named by that collection are looked up, the
  filename stored in each node is recorded, and the node's value is overwritten
  with the placeholder "SAVEDMODEL-ASSET". An `AssetFileDef` entry (export
  filename plus tensor name) is then appended to the MetaGraph's ASSETS_KEY
  collection for every asset found.

  Export filenames are the basename of the original file. When a basename is
  already taken by a different original file, a numeric suffix (0, 1, ...) is
  appended until the name is unique. The same original file always maps to the
  same export filename, across all MetaGraphs.

  This is roughly the inverse operation of `_merge_assets_key_collection`.

  Args:
    saved_model_proto: SavedModel proto; modified in place.
    export_path: string with path where the saved_model_proto will be exported.

  Returns:
    A dict mapping each original asset filename to the value of
    `_get_asset_filename(export_path, <export filename>)`.

  Raises:
    ValueError: if an ASSET_FILEPATHS collection is not a node list, or if it
      references a tensor whose name does not end in ":0". Asset nodes are
      also passed through `_check_asset_node_def`, which presumably rejects
      unsupported node definitions (not visible here).
  """
  # original filename -> filename chosen for the export.
  export_name_by_source = {}
  # Export filenames already handed out; used to detect basename clashes.
  taken_names = set()

  def _reserve_export_name(source_path):
    """Returns the export filename for `source_path`, choosing one if new."""
    try:
      return export_name_by_source[source_path]
    except KeyError:
      pass

    base = os.path.basename(source_path)
    candidate = base
    suffix = 0
    while candidate in taken_names:
      # Basename clash with another file: try base0, base1, ...
      candidate = tf.compat.as_bytes(base) + tf.compat.as_bytes(str(suffix))
      suffix += 1
    taken_names.add(candidate)
    export_name_by_source[source_path] = candidate
    return candidate

  for meta_graph in saved_model_proto.meta_graphs:
    filepaths_collection = meta_graph.collection_def.get(
        tf.compat.v1.GraphKeys.ASSET_FILEPATHS)

    # MetaGraphs without assets need no ASSETS_KEY collection.
    if filepaths_collection is None:
      continue
    if filepaths_collection.WhichOneof("kind") != "node_list":
      raise ValueError(
          "MetaGraph collection ASSET_FILEPATHS is not a list of tensors.")

    tensor_names = filepaths_collection.node_list.value
    # Validate every entry before any node name is derived from it.
    if not all(name.endswith(":0") for name in tensor_names):
      raise ValueError("Unexpected tensor in ASSET_FILEPATHS collection.")

    asset_node_names = {
        _get_node_name_from_tensor(name) for name in tensor_names
    }

    filename_by_tensor = {}
    for node_def in meta_graph.graph_def.node:
      if node_def.name not in asset_node_names:
        continue
      _check_asset_node_def(node_def)
      path_values = node_def.attr["value"].tensor.string_val
      source_path = path_values[0]
      filename_by_tensor[node_def.name + ":0"] = source_path
      logging.debug("Found asset node %s pointing to %s", node_def.name,
                    source_path)
      # Replace the stored value so the original path is not leaked.
      path_values[0] = tf.compat.as_bytes("SAVEDMODEL-ASSET")

    if not filename_by_tensor:
      continue

    assets_collection = meta_graph.collection_def[
        tf.compat.v1.saved_model.ASSETS_KEY]
    # Tensor names are unique keys, so sorting them fixes the entry order.
    for tensor_name in sorted(filename_by_tensor):
      asset_def = meta_graph_pb2.AssetFileDef()
      asset_def.filename = _reserve_export_name(filename_by_tensor[tensor_name])
      asset_def.tensor_info.name = tensor_name
      assets_collection.any_list.value.add().Pack(asset_def)

  export_map = {}
  for source_path, export_name in export_name_by_source.items():
    export_map[source_path] = _get_asset_filename(export_path, export_name)
  return export_map