in scripts/tf_cnn_benchmarks/benchmark_cnn.py [0:0]
def _preprocess_graph(self, graph, graph_info):
  """Freezes, and optionally TensorRT-converts, the graph before it is run.

  When `self.forward_only_and_freeze` is false this is a no-op and the inputs
  are returned as-is. Otherwise the variables of `graph` are turned into
  constants (except the global step and the variables reported by
  `self._unfreezable_local_variables()`), the result is optionally passed
  through the TF-TRT converter when `self.params.trt_mode` is set, and the
  final GraphDef is imported into a brand new graph.

  Args:
    graph: the graph to preprocess.
    graph_info: the namedtuple returned by _build_graph() which contains all
      necessary information to benchmark the graph, including named
      tensors/ops list, fetches, etc. Its `fetches` must be a dict and its
      `global_step` must be a tf.Variable.

  Returns:
    A `(graph, graph_info)` tuple. In the freezing case both are new objects:
    the ops/tensors/fetches of the returned graph_info are looked up by name
    in the newly imported graph, `global_step` is the re-created variable and
    `summary_op` is None.
  """
  assert isinstance(graph_info.fetches, dict)
  assert isinstance(graph_info.global_step, tf.Variable)
  if not self.forward_only_and_freeze:
    return (graph, graph_info)

  # Names of all ops referenced by graph_info; the conversions below must
  # keep them. Only the part of each name before the first ':' is used.
  referenced_op_names = list({
      item.name.split(':')[0]
      for item in nest.flatten(graph_info)
      if item is not None
  })

  # Variables that must stay variables, mapped to the collection each one is
  # re-registered in after the import. Only unfreezable variables are kept in
  # forward_only_and_freeze mode.
  # TODO(laigd): consider making global_step a constant.
  collection_by_variable = {
      graph_info.global_step: tf.GraphKeys.GLOBAL_VARIABLES
  }
  for local_variable in self._unfreezable_local_variables(graph):
    collection_by_variable[local_variable] = tf.GraphKeys.LOCAL_VARIABLES

  # The initializer and read ops of the kept variables are added to the
  # output list, so convert_variables_to_constants() will keep them.
  variable_initializers = [
      kept.initializer.name for kept in collection_by_variable
  ]
  variable_read_op_names = [
      kept.value().op.name for kept in collection_by_variable
  ]
  output_node_names = (
      referenced_op_names + variable_initializers + variable_read_op_names)
  graphdef = graph.as_graph_def(add_shapes=True)

  # Freeze the graph: initialize every variable in a throwaway session and
  # bake the resulting values in as constants.
  with graph.as_default():
    with tf.Session(config=create_config_proto(self.params)) as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(tf.local_variables_initializer())
      unfrozen_variable_names = [
          kept.op.name for kept in collection_by_variable
      ]
      graphdef = graph_util.convert_variables_to_constants(
          sess,
          graphdef,
          output_node_names,
          variable_names_blacklist=unfrozen_variable_names)

  # Run TensorRT conversion.
  if self.params.trt_mode:
    # Import here instead of at top, because this will crash if TensorRT is
    # not installed
    from tensorflow.python.compiler.tensorrt import trt_convert  # pylint: disable=g-import-not-at-top
    # Avoid TF-TRT bridge from touching all variable initializer ops and their
    # dependencies, since they can directly be fetched by sess.run()s that
    # initialize the variables.
    # pylint: disable=protected-access
    inputs_by_node_name, _, _ = graph_util_impl._extract_graph_summary(
        graphdef)
    initializer_subgraph_ops = graph_util_impl._bfs_for_reachable_nodes(
        variable_initializers, inputs_by_node_name)
    # pylint: enable=protected-access
    trt_outputs = output_node_names + list(initializer_subgraph_ops)
    graphdef = trt_convert.create_inference_graph(
        graphdef,
        outputs=trt_outputs,
        max_batch_size=self.model.get_batch_size(),
        max_workspace_size_bytes=self.params.trt_max_workspace_size_bytes,
        precision_mode=self.params.trt_mode)

  # Creates a new graph and imports the converted graph back into it.
  updated_graph = tf.Graph()

  def _lookup_by_name(element):
    """Finds the tensor or op in 'updated_graph' named like `element`."""
    if element is None:
      return None
    element_name = element.name
    if ':' in element_name:
      return updated_graph.get_tensor_by_name(element_name)
    return updated_graph.get_operation_by_name(element_name)

  def _get_tensors_or_ops(inputs):
    """Gets the updated tensors or ops from 'updated_graph'."""
    if isinstance(inputs, (list, dict, tuple)):
      return nest.map_structure(_lookup_by_name, inputs)
    return _lookup_by_name(inputs)

  with updated_graph.as_default():
    importer.import_graph_def(graph_def=graphdef, name='')

    # Re-create the kept variables on top of the imported nodes and put them
    # back into their collections.
    for kept, collection_name in collection_by_variable.items():
      updated_variable = tf.Variable.from_proto(kept.to_proto())
      tf.add_to_collection(collection_name, updated_variable)
      if kept is graph_info.global_step:
        updated_global_step = updated_variable

  # Fields whose tensors/ops are simply looked up again by name.
  remapped_fields = {}
  for field_name in ('input_producer_op', 'enqueue_ops', 'execution_barrier',
                     'local_var_init_op_group', 'fetches'):
    remapped_fields[field_name] = _get_tensors_or_ops(
        getattr(graph_info, field_name))

  updated_graph_info = GraphInfo(
      global_step=updated_global_step,
      summary_op=None,
      **remapped_fields)
  return (updated_graph, updated_graph_info)