def recover_partitioned_variable_map()

in tensorflow_hub/native_module.py [0:0]


def recover_partitioned_variable_map(var_node_map):
  """Builds a proper variable map if it contains PartitionedVariables.

  Args:
    var_node_map: A map to tf.Variables. PartitionedVariables show up in this
      map as N entries with keys "<var_name>/part_n".

  Returns:
    A map to tf.Variables or to list of tf.Variables for each
    PartitionedVariables in `var_node_map`. The parts of a partitioned
    variable are listed in increasing order of their offset.

  Raises:
    RuntimeError: if there are issues recovering the PartitionedVariables:
      a name that appears both as a single and as a partitioned variable,
      two parts of the same variable with the same offset, or parts whose
      shapes differ in any dimension other than the first.
  """
  # Maps a variable name either to its tensor (standard variable) or to a
  # dict {offset: part_tensor} (partitioned variable).
  offset_variables_map = {}
  for var_key, var_tensor in var_node_map.items():
    match, var_name, offset = _extract_variable_parts(var_key, var_tensor)

    if not match:
      # This is a standard variable, so we can safely add it to the output.
      # Keys of `var_node_map` are unique, so a hit here can only be a
      # partitioned variable registered under the same name.
      if var_key in offset_variables_map:
        raise RuntimeError(
            "Variable %s exists both as a single and partitioned variable." %
            var_key)
      offset_variables_map[var_key] = var_tensor
      continue

    if var_name not in offset_variables_map:
      offset_variables_map[var_name] = {}
    elif not isinstance(offset_variables_map[var_name], dict):
      raise RuntimeError(
          "Variable %s exists both as a single and partitioned variable." %
          var_name)

    # Duplicated variable offsets should not exist.
    if offset in offset_variables_map[var_name]:
      raise RuntimeError(
          "Variable map contains duplicate offset %d for variable [%s]" %
          (offset, var_name))
    offset_variables_map[var_name][offset] = var_tensor

  variables_map = {}
  # Use offsets for sorting, then strip them from the dictionary and keep only
  # a list of variables per each variable name.
  for var_name, var_value in offset_variables_map.items():
    if not isinstance(var_value, dict):
      variables_map[var_name] = var_value
      continue
    # All parts must agree on every dimension except the first one.
    shapes = [var_tensor.shape[1:] for var_tensor in var_value.values()]
    if not all(shape == shapes[0] for shape in shapes):
      raise RuntimeError("Shapes not compatible: %s" % (shapes))
    # Offsets are unique per variable (checked above), so sorting the keys
    # alone fully determines the order of the parts.
    variables_map[var_name] = [
        var_value[part_offset] for part_offset in sorted(var_value)
    ]

  return variables_map