# convert_variables_to_constants_v2 — excerpt from
# tensorflow/tensorflow/python/framework/convert_to_constants.py


def convert_variables_to_constants_v2(func, lower_control_flow=True):
  """Replaces all the variables in a graph with constants of the same values.

  TensorFlow 2.0 function for converting all Variable ops into Const ops holding
  the same values. This makes it possible to describe the network fully with a
  single GraphDef file, and allows the removal of a lot of ops related to
  loading and saving the variables. This function runs Grappler's function
  inlining optimization in order to return a single subgraph.

  The current implementation only works for graphs that do not contain any
  control flow or embedding related ops.

  Args:
    func: ConcreteFunction.
    lower_control_flow: Boolean indicating whether or not to lower control flow
      ops such as If and While. (default True)

  Returns:
    ConcreteFunction containing a simplified version of the original.

  Raises:
    ValueError: If the input chain of a ReadVariableOp (possibly through
      intermediate Identity ops) does not terminate in a Placeholder op.
  """
  # TODO(nupurgarg): Replace ResourceGather with Gather.
  # Inline the graph in order to remove functions when possible.
  graph_def = _run_inline_graph_optimization(func, lower_control_flow)

  # Gets list of all node defs include those in the library.
  node_defs = _get_node_defs_list(graph_def)

  # Get mapping from node name to node.
  name_to_node = {_get_tensor_name(node.name): node for node in node_defs}

  # Get mapping from node name to variable value.
  tensor_data = _get_tensor_data(func)

  # Get mapping from function name to argument types.
  function_data = _get_control_flow_function_data(node_defs, tensor_data)

  # Get variable data for all nodes in `node_defs`.
  # reference_variables: VariableV2 node name -> eagerly-evaluated value.
  # resource_identities: Identity node name -> dtype AttrValue to rewrite "T"
  #   with, for Identity ops sitting on resource edges that become real values.
  # placeholders: Placeholder node name -> {"dtype", "data"} used to build the
  #   replacement Const op.
  # converted_input_indices: indices of the function inputs that were folded
  #   into constants (passed to _construct_concrete_function at the end).
  reference_variables = {}
  resource_identities = {}
  placeholders = {}
  converted_input_indices = set()

  def _save_placeholder(node_name, dtype):
    # Record the constant value/dtype for a resource Placeholder and remember
    # which function input index it came from so it can be pruned later.
    placeholders[node_name] = {
        "dtype": dtype,
        "data": tensor_data[node_name]["data"],
    }
    converted_input_indices.add(tensor_data[node_name]["index"])

  for node in node_defs:
    if node.op in _CONDITIONAL_OPS:
      # Get dtype and data for resource Placeholders.
      # Only the then_branch signature is consulted.
      # NOTE(review): presumably then_branch and else_branch share identical
      # input signatures — confirm.
      then_func = node.attr["then_branch"].func.name
      arg_types = function_data[then_func]["types"]
      # input[0] of a conditional op is skipped (the branch predicate); the
      # branch arguments start at input[1], so idx lines up with arg_types.
      for idx, input_tensor in enumerate(node.input[1:]):
        input_name = _get_tensor_name(input_tensor)
        if input_name in tensor_data:
          dtype = attr_value_pb2.AttrValue(type=arg_types[idx])
          _save_placeholder(_get_tensor_name(input_tensor), dtype)
    elif node.op in _LOOP_OPS:
      # Get dtype and data for resource Placeholders.
      # All inputs of a loop op are loop variables; the cond function's
      # argument types are used for the converted dtypes.
      cond_func = node.attr["cond"].func.name
      arg_types = function_data[cond_func]["types"]
      for idx, input_tensor in enumerate(node.input):
        input_name = _get_tensor_name(input_tensor)
        if input_name in tensor_data:
          dtype = attr_value_pb2.AttrValue(type=arg_types[idx])
          _save_placeholder(_get_tensor_name(input_tensor), dtype)
    elif (node.op == "Identity" and node.attr["T"].type == dtypes.resource and
          name_to_node[_get_tensor_name(node.input[0])].op in _LOOP_OPS):
      # Store the dtype for Identity resource ops that are outputs of While ops.
      while_node = name_to_node[_get_tensor_name(node.input[0])]
      body_func = while_node.attr["body"].func.name
      # The Identity's input is "while_name[:output_index]"; a missing suffix
      # means output index 0.
      input_data = node.input[0].split(":")
      idx = 0 if len(input_data) == 1 else int(input_data[1])

      # The While body's argument types match its output types, so the body
      # signature supplies the post-conversion dtype for this output.
      dtype = attr_value_pb2.AttrValue(
          type=function_data[body_func]["types"][idx])
      resource_identities[node.name] = dtype
    elif node.op == "VariableV2":
      # Get data for VariableV2 ops (reference variables) that cannot be lifted.
      # An Identity is attached to the variable's output tensor so the value
      # can be read out via a pruned sub-function evaluated eagerly here.
      with func.graph.as_default():
        identity_node = array_ops.identity(
            func.graph.as_graph_element(node.name + ":0"))
      reference_variables[node.name] = (
          func.prune([], [identity_node.name])()[0])
    elif node.name in tensor_data and not tensor_data[node.name]["is_variable"]:
      # Get dtype and data for non-variable Placeholders (ex. values for 1.X
      # Const ops that are loaded as Placeholders in 2.0)
      _save_placeholder(node.name, node.attr["dtype"])
    elif node.op == "ReadVariableOp":
      # Get dtype and data for Placeholder ops associated with ReadVariableOp.
      # There can be an Identity in between the ReadVariableOp and Placeholder.
      # Store the dtype for the Identity ops.
      input_name = _get_tensor_name(node.input[0])
      # Walk up through any chain of Identity ops, marking each one so its "T"
      # attr can be rewritten from resource to the read dtype.
      while name_to_node[input_name].op == "Identity":
        resource_identities[input_name] = node.attr["dtype"]
        input_name = _get_tensor_name(name_to_node[input_name].input[0])
      if name_to_node[input_name].op != "Placeholder":
        raise ValueError("Cannot find the Placeholder op that is an input "
                         "to the ReadVariableOp.")
      _save_placeholder(input_name, node.attr["dtype"])

  # Reconstruct the graph with constants in place of variables.
  output_graph_def = graph_pb2.GraphDef()

  for input_node in graph_def.node:
    # A fresh NodeDef is added for every input node; each branch below fills
    # it in (either a converted form or a verbatim copy).
    output_node = output_graph_def.node.add()
    # Convert VariableV2 ops to Const ops.
    if input_node.name in reference_variables:
      data = reference_variables[input_node.name]
      dtype = attr_value_pb2.AttrValue(type=data.dtype.as_datatype_enum)
      _populate_const_op(output_node, input_node.name, dtype, data.numpy(),
                         data.shape)
    # Convert Placeholder ops to Const ops.
    elif input_node.name in placeholders:
      data = placeholders[input_node.name]["data"]
      dtype = placeholders[input_node.name]["dtype"]
      _populate_const_op(output_node, input_node.name, dtype, data, data.shape)
    # Update the dtype for Identity ops that are inputs to ReadVariableOps.
    elif input_node.name in resource_identities:
      output_node.CopyFrom(input_node)
      output_node.attr["T"].CopyFrom(resource_identities[input_node.name])
    # Convert ReadVariableOps to Identity ops.
    elif input_node.op == "ReadVariableOp":
      _populate_identity_op(output_node, input_node)
    # Update the function names and argument types for the conditional ops.
    elif input_node.op in _CONDITIONAL_OPS:
      _populate_if_op(output_node, input_node, function_data)
    elif input_node.op in _LOOP_OPS:
      _populate_while_op(output_node, input_node, function_data)
    else:
      output_node.CopyFrom(input_node)

  # Add functions to reconstructed graph.
  if graph_def.library:
    library = output_graph_def.library

    for input_library_func in graph_def.library.function:
      orig_func_name = input_library_func.signature.name
      new_func_name = _get_new_function_name(orig_func_name)

      # Do not copy any functions that aren't being used in the graph. Any
      # functions that are not used by control flow should have been inlined.
      if orig_func_name not in function_data:
        continue

      output_library_func = library.function.add()
      # ret / control_ret maps are copied verbatim; only the signature and
      # node defs below are rewritten.
      for key, value in input_library_func.ret.items():
        output_library_func.ret[key] = value
      for key, value in input_library_func.control_ret.items():
        output_library_func.control_ret[key] = value

      # Update the input types in the function signature. Update the output
      # types for functions that are while loop bodies.
      output_library_func.signature.CopyFrom(input_library_func.signature)
      output_library_func.signature.name = new_func_name
      for dtype, arg in zip(function_data[orig_func_name]["types"],
                            output_library_func.signature.input_arg):
        arg.type = dtype
      if function_data[orig_func_name]["is_also_output_type"]:
        for dtype, arg in zip(function_data[orig_func_name]["types"],
                              output_library_func.signature.output_arg):
          arg.type = dtype

      # Update the NodeDefs.
      # Map each ReadVariableOp node name inside this function to its first
      # input, so consumers of those ops can be re-pointed below.
      func_variables = {
          node.name: node.input[0]
          for node in input_library_func.node_def
          if node.op == "ReadVariableOp"
      }

      for input_node in input_library_func.node_def:
        output_node = output_library_func.node_def.add()
        # Convert ReadVariableOps to Identity ops.
        if input_node.op == "ReadVariableOp":
          _populate_identity_op(output_node, input_node)
        # Update the function names and argument types for the conditional ops.
        elif input_node.op in _CONDITIONAL_OPS:
          _populate_if_op(output_node, input_node, function_data)
        elif input_node.op in _LOOP_OPS:
          _populate_while_op(output_node, input_node, function_data)
        else:
          output_node.CopyFrom(input_node)
          # Convert :value to :output for ops that use the ReadVariableOp.
          # ReadVariableOp's output is named "value" but the replacement
          # Identity's output is named "output", so consumer input strings
          # must be rewritten to match.
          for idx, full_name in enumerate(input_node.input):
            input_name = _get_tensor_name(full_name)
            if input_name in func_variables:
              # NOTE(review): assumes full_name contains an explicit ":value"
              # suffix; a bare node name here would raise IndexError — confirm
              # that function-library inputs always carry the output name.
              full_name_parts = full_name.split(":")
              full_name_parts[1] = "output"
              input_name = ":".join(full_name_parts)
              output_node.input[idx] = input_name

  # Preserve the original GraphDef version info on the rebuilt graph.
  output_graph_def.versions.CopyFrom(graph_def.versions)
  # Build a new ConcreteFunction from the constant-folded GraphDef, dropping
  # the captured inputs that were converted to constants.
  return _construct_concrete_function(func, output_graph_def,
                                      converted_input_indices)