in mmdnn/conversion/tensorflow/tensorflow_frozenparser.py [0:0]
def __init__(self, frozen_file, inputshape, in_nodes, dest_nodes):
    """Load a frozen TensorFlow graph and prepare it for conversion.

    The frozen GraphDef is pruned to the sub-graph between ``in_nodes`` and
    ``dest_nodes``.  Every input node is then re-bound to a fresh placeholder
    whose shape comes from ``inputshape``.  The resulting graph is re-exported
    with ``export_meta_graph`` and wrapped in a ``TensorflowGraph``, which is
    built immediately.

    Args:
        frozen_file: path to a serialized (binary) frozen ``GraphDef``.
        inputshape: one shape per leading entry of ``in_nodes``.
            - Float inputs (dtype enum 0/1/2) get ``None`` prepended to the
              shape; the shape may be a list or a tuple.
            - Integer/string inputs (enum 3-7) use the shape as given.
            - Bool inputs (enum 10) are left unshaped.
        in_nodes: names of the graph's input nodes.
        dest_nodes: names of the graph's output nodes.

    Raises:
        ImportError: if the installed TensorFlow is older than 1.8.0.
        NotImplementedError: if an input node has an unsupported dtype enum.
    """
    if LooseVersion(tensorflow.__version__) < LooseVersion('1.8.0'):
        raise ImportError(
            'Your TensorFlow version %s is outdated. '
            'MMdnn requires tensorflow>=1.8.0' % tensorflow.__version__)
    super(TensorflowParser2, self).__init__()
    self.weight_loaded = True

    from tensorflow.python.framework import dtypes
    from tensorflow.python.tools import strip_unused_lib

    # Load the frozen model file into a GraphDef.
    with open(frozen_file, 'rb') as f:
        serialized = f.read()
    tensorflow.reset_default_graph()
    original_gdef = tensorflow.GraphDef()
    original_gdef.ParseFromString(serialized)

    # Record the declared dtype enum of each input node.  Membership is
    # tested first because indexing a protobuf map with a missing key
    # inserts a default entry, silently mutating the GraphDef.  Nodes
    # without a 'dtype' attr get 0, the value the plain lookup would read.
    in_type_list = {}
    for node in original_gdef.node:
        if node.name in in_nodes:
            in_type_list[node.name] = (
                node.attr['dtype'].type if 'dtype' in node.attr else 0)

    # strip_unused replaces each input node with a new Placeholder.  Give it
    # each input's own dtype instead of forcing float32 on every input, so
    # non-float inputs keep their type.  Inputs with no usable dtype (enum 0)
    # fall back to float32, as before.
    float_enum = dtypes.float32.as_datatype_enum
    placeholder_types = [in_type_list.get(name) or float_enum
                         for name in in_nodes]
    model = strip_unused_lib.strip_unused(
        input_graph_def=original_gdef,
        input_node_names=in_nodes,
        output_node_names=dest_nodes,
        placeholder_type_enum=placeholder_types)

    # Import the pruned graph into a fresh Graph.  Each input tensor is
    # re-bound to a placeholder with the requested shape.
    with tensorflow.Graph().as_default() as g:
        input_map = {}
        for i, shape in enumerate(inputshape):
            node_name = in_nodes[i]
            type_enum = in_type_list[node_name]
            if type_enum in (0, 1, 2):
                # Float-like inputs: prepend an unknown batch dimension.
                placeholder_shape = [None] + list(shape)
            elif type_enum in (3, 4, 5, 6, 7):
                placeholder_shape = shape
            elif type_enum == 10:
                # Bool inputs are left without a static shape.
                placeholder_shape = None
            else:
                raise NotImplementedError(
                    'Unsupported dtype enum %d for input node %s'
                    % (type_enum, node_name))
            dtype = TensorflowParser2.tf_dtype_map[type_enum]
            input_map[node_name + ':0'] = tensorflow.placeholder(
                dtype, shape=placeholder_shape)
        tensorflow.import_graph_def(model, name='', input_map=input_map)

        # Export through export_meta_graph, as before, rather than calling
        # g.as_graph_def() directly.  With no filename, nothing is written
        # to disk, so no temporary directory is needed.
        meta_graph_def = tensorflow.train.export_meta_graph(graph=g)
        model = meta_graph_def.graph_def

    self.tf_graph = TensorflowGraph(model)
    self.tf_graph.build()