in tensorflow_hub/native_module.py [0:0]
def recover_partitioned_variable_map(var_node_map):
  """Builds a proper variable map if it contains PartitionedVariables.

  Args:
    var_node_map: A map to tf.Variables. PartitionedVariables show up in this
      map as N entries with keys "<var_name>/part_n".

  Returns:
    A map to tf.Variables or to list of tf.Variables for each
    PartitionedVariables in `var_node_map`. The parts of a partitioned
    variable are listed in ascending order of the offset reported by
    `_extract_variable_parts`.

  Raises:
    RuntimeError: if there are issues recovering the PartitionedVariables:
      a name is used both for a single and a partitioned variable, two parts
      of the same partitioned variable report the same offset, or the parts
      of a partitioned variable disagree in any dimension after the first.
  """
  # Maps each variable name either to a single variable, or (for partitioned
  # variables) to a dict {offset: part}.
  offset_variables_map = {}
  for var_key, var_tensor in var_node_map.items():
    match, var_name, offset = _extract_variable_parts(var_key, var_tensor)

    if not match:
      # This is a standard variable, so we can safely add it to the output,
      # unless a partition has already claimed the same name.
      if var_key in offset_variables_map:
        raise RuntimeError(
            "Variable %s exists both as a single and partitioned variable." %
            var_key)
      offset_variables_map[var_key] = var_tensor
      continue

    if var_name not in offset_variables_map:
      offset_variables_map[var_name] = {}
    elif not isinstance(offset_variables_map[var_name], dict):
      # The name was already registered as a standard (non-partitioned)
      # variable.
      raise RuntimeError(
          "Variable %s exists both as a single and partitioned variable." %
          var_name)

    # Duplicated variable offsets should not exist.
    if offset in offset_variables_map[var_name]:
      raise RuntimeError(
          "Variable map contains duplicate offset %d for variable [%s]" %
          (offset, var_name))
    offset_variables_map[var_name][offset] = var_tensor

  variables_map = {}
  # Use offsets for sorting, then strip them from the dictionary and keep only
  # a list of variables per each variable name.
  for var_name, var_value in offset_variables_map.items():
    if not isinstance(var_value, dict):
      variables_map[var_name] = var_value
      continue
    # All parts must agree on every dimension except the first one.
    shapes = [part.shape[1:] for part in var_value.values()]
    if not all(shape == shapes[0] for shape in shapes):
      raise RuntimeError("Shapes not compatible: %s" % (shapes,))
    # Sort by offset only (offsets are unique, enforced above) and build the
    # ordered list of parts exactly once.
    variables_map[var_name] = [var_value[off] for off in sorted(var_value)]

  return variables_map