in tensorflow_hub/saved_model_lib.py [0:0]
def _make_assets_key_collection(saved_model_proto, export_path):
  """Creates an ASSETS_KEY collection in the GraphDefs in saved_model_proto.

  Adds an ASSETS_KEY collection to the GraphDefs in the SavedModel and returns
  a map from original asset filename to filename when exporting the SavedModel
  to `export_path`.

  This is roughly the inverse operation of `_merge_assets_key_collection`.

  For every MetaGraph that has an ASSET_FILEPATHS collection, each referenced
  asset node gets an `AssetFileDef` entry in the MetaGraph's ASSETS_KEY
  collection. Its constant value is then overwritten with the placeholder
  b"SAVEDMODEL-ASSET" so the original path is not exported. Assets whose
  basenames collide get a numeric suffix appended to keep them unique (e.g.
  b"vocab.txt", b"vocab.txt0", b"vocab.txt1", ...).

  Args:
    saved_model_proto: SavedModel proto to be modified in place.
    export_path: string with path where the saved_model_proto will be exported.

  Returns:
    A map from original asset filename to asset filename when exporting the
    SavedModel to `export_path`.

  Raises:
    ValueError: on unsupported/unexpected SavedModel, i.e. when a MetaGraph's
      ASSET_FILEPATHS collection is not a node list, or lists a tensor whose
      name does not end in ":0".
  """
  # Original asset filename -> unique basename assigned for the export.
  asset_filenames = {}
  # Basenames already assigned, used to detect collisions.
  used_asset_filenames = set()

  def _make_asset_filename(original_filename):
    """Returns a unique (memoized) asset basename for `original_filename`."""
    if original_filename in asset_filenames:
      return asset_filenames[original_filename]
    # Normalize to bytes up front. Every retry candidate below is bytes, so
    # the collision check must compare bytes with bytes. For the bytes paths
    # read from the graph, this is a no-op.
    basename = tf.compat.as_bytes(os.path.basename(original_filename))
    suggestion = basename
    index = 0
    while suggestion in used_asset_filenames:
      suggestion = basename + tf.compat.as_bytes(str(index))
      index += 1
    asset_filenames[original_filename] = suggestion
    used_asset_filenames.add(suggestion)
    return suggestion

  for meta_graph in saved_model_proto.meta_graphs:
    collection_def = meta_graph.collection_def.get(
        tf.compat.v1.GraphKeys.ASSET_FILEPATHS)
    if collection_def is None:
      continue
    if collection_def.WhichOneof("kind") != "node_list":
      raise ValueError(
          "MetaGraph collection ASSET_FILEPATHS is not a list of tensors.")

    asset_tensors = collection_def.node_list.value
    for tensor in asset_tensors:
      if not tensor.endswith(":0"):
        raise ValueError("Unexpected tensor in ASSET_FILEPATHS collection.")
    asset_nodes = {
        _get_node_name_from_tensor(tensor) for tensor in asset_tensors
    }

    # Tensor name -> original asset filename, for this MetaGraph only.
    tensor_filename_map = {}
    for node in meta_graph.graph_def.node:
      if node.name not in asset_nodes:
        continue
      _check_asset_node_def(node)
      filename = node.attr["value"].tensor.string_val[0]
      tensor_filename_map[node.name + ":0"] = filename
      logging.debug("Found asset node %s pointing to %s", node.name, filename)
      # Clear value to avoid leaking the original path.
      node.attr["value"].tensor.string_val[0] = (
          tf.compat.as_bytes("SAVEDMODEL-ASSET"))

    if tensor_filename_map:
      assets_key_collection = meta_graph.collection_def[
          tf.compat.v1.saved_model.ASSETS_KEY]
      # Iterate in tensor-name order so that the assigned filenames and the
      # collection order do not depend on node order in the GraphDef.
      for tensor, filename in sorted(tensor_filename_map.items()):
        asset_proto = meta_graph_pb2.AssetFileDef()
        asset_proto.filename = _make_asset_filename(filename)
        asset_proto.tensor_info.name = tensor
        assets_key_collection.any_list.value.add().Pack(asset_proto)

  return {
      original_filename: _get_asset_filename(export_path, asset_filename)
      for original_filename, asset_filename in asset_filenames.items()
  }