in mmdnn/conversion/tensorflow/tensorflow_parser.py [0:0]
def __init__(self, meta_file, checkpoint_file, dest_nodes, inputShape = None, in_nodes = None):
    """Load a TensorFlow meta graph (plus optional checkpoint) and build ``self.tf_graph``.

    Pipeline: load the meta graph, optionally cut it to the requested
    inputs/outputs, fold constants with ``TransformGraph``, re-import the
    result behind freshly created placeholders, and finally build the
    ``TensorflowGraph`` and run ``process_graph`` on it.

    Args:
        meta_file: path passed to ``TensorflowParser._load_meta``. Required:
            every later step needs the loaded graph definition.
        checkpoint_file: optional path. When truthy it is passed to
            ``TensorflowParser._load_weights``; the result is stored in
            ``self.ckpt_data`` and ``self.weight_loaded`` is set to True.
        dest_nodes: output node names, used to cut the graph and as the
            outputs handed to ``TransformGraph``.
        inputShape: optional shape for the input nodes. Only used together
            with ``in_nodes``; a leading unknown (``None``) dimension —
            presumably the batch dimension — is prepended.
        in_nodes: optional input node names. When empty/omitted, every
            ``Placeholder`` node of the (possibly cut) graph is used.

    Raises:
        ValueError: if ``meta_file`` is empty, since no graph can be loaded.
    """
    super(TensorflowParser, self).__init__()

    if not meta_file:
        # Without a meta file `model` would never be bound and every branch
        # below would fail with UnboundLocalError; fail early and clearly.
        raise ValueError("TensorflowParser requires a meta_file to load the model graph.")

    # load model files into TensorFlow graph
    model = TensorflowParser._load_meta(meta_file)

    if checkpoint_file:
        self.ckpt_data = TensorflowParser._load_weights(checkpoint_file)
        self.weight_loaded = True

    if in_nodes is not None and inputShape is not None:
        # Extract the subgraph between in_nodes and dest_nodes; strip_unused
        # replaces each node named in in_nodes by a float32 Placeholder.
        from tensorflow.python.tools import strip_unused_lib
        from tensorflow.python.framework import dtypes
        model = strip_unused_lib.strip_unused(
            input_graph_def=model,
            input_node_names=in_nodes,
            output_node_names=dest_nodes,
            placeholder_type_enum=dtypes.float32.as_datatype_enum)

        # Unknown leading dimension followed by the user-supplied dimensions.
        tensor_input = tensorflow.TensorShape(
            [None] + [tensorflow.Dimension(dim) for dim in inputShape])
        input_shape_proto = tensor_input.as_proto()

        # Rewrite the declared and recorded shapes of the input nodes in place,
        # so the graph transform below works with the user-supplied shape.
        for node in model.node:
            if node.name in in_nodes:
                node.attr['shape'].shape.CopyFrom(input_shape_proto)
                output_shapes = node.attr['_output_shapes'].list.shape
                # Replace whatever was recorded (the original code popped an
                # unknown-rank entry); clearing instead of pop() also works when
                # strip_unused produced a node without `_output_shapes`.
                del output_shapes[:]
                output_shapes.extend([input_shape_proto])

    elif dest_nodes is not None:
        # extract subgraph using dest_nodes only
        from tensorflow.python.framework.graph_util import extract_sub_graph
        model = extract_sub_graph(model, dest_nodes)

    # Default the inputs to every Placeholder remaining in the graph.
    if not in_nodes:
        in_nodes = [node.name for node in model.node if node.op == 'Placeholder']

    # Graph Transform: fold constants (with ignore_errors=true).
    transforms = ["fold_constants(ignore_errors=true)"]
    transformed_graph_def = TransformGraph(model, in_nodes, dest_nodes, transforms)

    # Per input node: its DataType enum and the shape value built by
    # self._shapeToStr from its 'shape' attr.
    in_type_list = {}
    in_shape_list = {}
    for n in transformed_graph_def.node:
        if n.name in in_nodes:
            in_type_list[n.name] = n.attr['dtype'].type
            in_shape_list[n.name] = self._shapeToStr(n.attr['shape'].shape)

    # DataType enum -> placeholder dtype. 0 (DT_INVALID, i.e. no dtype
    # recorded) is treated as float32, matching the previous behavior.
    known_dtypes = {
        0: tensorflow.float32,
        1: tensorflow.float32,  # DT_FLOAT
        3: tensorflow.int32,    # DT_INT32
        10: tensorflow.bool,    # DT_BOOL
    }

    with tensorflow.Graph().as_default() as g:
        input_map = {}
        for in_node in in_nodes:
            type_enum = in_type_list[in_node]
            dtype = known_dtypes.get(type_enum)
            if dtype is None:
                # Resolve any other dtype directly instead of silently reusing
                # the dtype chosen for the previous input node.
                dtype = tensorflow.as_dtype(type_enum)
            input_map[in_node] = tensorflow.placeholder(dtype, shape=in_shape_list[in_node])

        tensorflow.import_graph_def(transformed_graph_def, name='', input_map=input_map)

        # With no filename, export_meta_graph only returns the MetaGraphDef of
        # the default graph (here `g`); no temporary file is written or leaked.
        model = tensorflow.train.export_meta_graph().graph_def

    self.tf_graph = TensorflowGraph(model)
    self.tf_graph.build()

    # NOTE(review): self.ckpt_data is only assigned above when checkpoint_file
    # is truthy; confirm the base class (not visible here) provides a default.
    process_graph(self.tf_graph, self.ckpt_data)