def _preprocess_graph()

in scripts/tf_cnn_benchmarks/benchmark_cnn.py [0:0]


  def _preprocess_graph(self, graph, graph_info):
    """Freezes (and optionally TensorRT-converts) the graph before running it.

    When `self.forward_only_and_freeze` is false this is a no-op and the
    arguments are handed back unchanged. Otherwise:
      1. all global and local variables are initialized in a session and the
         graph is frozen with `graph_util.convert_variables_to_constants()`,
         exempting the global step and the variables returned by
         `self._unfreezable_local_variables(graph)`;
      2. if `self.params.trt_mode` is set, the frozen GraphDef is passed
         through TF-TRT's `create_inference_graph()`;
      3. the resulting GraphDef is imported into a fresh `tf.Graph`, the
         exempt variables are re-created there, and `graph_info` is rebuilt
         to reference the imported tensors/ops.

    Args:
      graph: the graph to preprocess.
      graph_info: the GraphInfo namedtuple returned by _build_graph(); it holds
        the named tensors/ops lists, fetches, global step, etc. needed to
        benchmark `graph`. Its `fetches` must be a dict and its `global_step`
        a `tf.Variable` (both are asserted).

    Returns:
      A `(graph, graph_info)` tuple. With freezing enabled, the graph is the
      newly imported one and graph_info references its tensors/ops and the
      re-created global step, with `summary_op` set to None.
    """
    assert isinstance(graph_info.fetches, dict)
    assert isinstance(graph_info.global_step, tf.Variable)
    if not self.forward_only_and_freeze:
      return (graph, graph_info)

    # Op names behind every non-None entry of graph_info (tensor names lose
    # their ':<index>' suffix); these nodes must survive the conversion.
    required_op_names = list({
        element.name.split(':')[0]
        for element in nest.flatten(graph_info)
        if element is not None
    })

    # Variables exempt from freezing, each mapped to the collection it is
    # registered in once the graph is re-imported.
    # TODO(laigd): consider making global_step a constant.
    kept_vars = {graph_info.global_step: tf.GraphKeys.GLOBAL_VARIABLES}
    kept_vars.update({
        var: tf.GraphKeys.LOCAL_VARIABLES
        for var in self._unfreezable_local_variables(graph)
    })

    # The exempt variables' initializer and read ops are listed as outputs
    # too, so convert_variables_to_constants() keeps them.
    init_op_names = [var.initializer.name for var in kept_vars]
    read_op_names = [var.value().op.name for var in kept_vars]
    output_node_names = required_op_names + init_op_names + read_op_names
    graph_def = graph.as_graph_def(add_shapes=True)

    # Freeze: initialize every variable in a session, then let
    # convert_variables_to_constants() read the values from that session,
    # skipping the exempt variables.
    with graph.as_default(), tf.Session(
        config=create_config_proto(self.params)) as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(tf.local_variables_initializer())
      graph_def = graph_util.convert_variables_to_constants(
          sess,
          graph_def,
          output_node_names,
          variable_names_blacklist=[var.op.name for var in kept_vars])

    if self.params.trt_mode:
      # Imported lazily: this import crashes when TensorRT is not installed.
      from tensorflow.python.compiler.tensorrt import trt_convert  # pylint: disable=g-import-not-at-top
      # Collect every node reachable from the variable initializers.
      # pylint: disable=protected-access
      input_names_by_node, _, _ = graph_util_impl._extract_graph_summary(
          graph_def)
      init_subgraph = graph_util_impl._bfs_for_reachable_nodes(
          init_op_names, input_names_by_node)
      # pylint: enable=protected-access
      # Declaring that subgraph as outputs keeps the TF-TRT bridge away from
      # it, so sess.run() can still fetch the initializers directly.
      graph_def = trt_convert.create_inference_graph(
          graph_def,
          outputs=output_node_names + list(init_subgraph),
          max_batch_size=self.model.get_batch_size(),
          max_workspace_size_bytes=self.params.trt_max_workspace_size_bytes,
          precision_mode=self.params.trt_mode)

    # The processed GraphDef is imported into a brand new graph.
    new_graph = tf.Graph()

    def _find(element):
      """Returns the tensor/op of `new_graph` named like `element` (or None)."""
      if element is None:
        return None
      name = element.name
      # Tensor names look like 'op:<index>'; operation names have no colon.
      if ':' in name:
        return new_graph.get_tensor_by_name(name)
      return new_graph.get_operation_by_name(name)

    def _remap(structure):
      """Applies _find to `structure` or, for a list/dict/tuple, its leaves."""
      if not isinstance(structure, (list, dict, tuple)):
        return _find(structure)
      return nest.map_structure(_find, structure)

    with new_graph.as_default():
      importer.import_graph_def(graph_def=graph_def, name='')
      # Rebuild Variable objects for the exempt variables from their protos
      # inside new_graph, and add each to its original collection.
      for old_var, collection in kept_vars.items():
        new_var = tf.Variable.from_proto(old_var.to_proto())
        tf.add_to_collection(collection, new_var)
        if old_var is graph_info.global_step:
          new_global_step = new_var

    return (new_graph, GraphInfo(
        input_producer_op=_remap(graph_info.input_producer_op),
        enqueue_ops=_remap(graph_info.enqueue_ops),
        execution_barrier=_remap(graph_info.execution_barrier),
        local_var_init_op_group=_remap(graph_info.local_var_init_op_group),
        fetches=_remap(graph_info.fetches),
        global_step=new_global_step,
        summary_op=None))