def __init__()

in mmdnn/conversion/tensorflow/tensorflow_frozenparser.py [0:0]


    def __init__(self, frozen_file, inputshape, in_nodes, dest_nodes):
        """Load a frozen TensorFlow GraphDef and prepare it for conversion.

        The graph read from ``frozen_file`` is pruned to the part between
        ``in_nodes`` and ``dest_nodes``. Each input is rebound to a fresh
        placeholder built from ``inputshape``. The graph is then re-exported
        through ``tensorflow.train.export_meta_graph`` and wrapped in a
        ``TensorflowGraph`` stored on ``self.tf_graph``.

        Args:
            frozen_file: Path to a binary-serialized ``GraphDef`` file.
            inputshape: One shape per entry of ``in_nodes``, in the same order.
                For DT_FLOAT/DT_DOUBLE inputs (and inputs whose node has no
                ``dtype`` attr), a leading ``None`` dimension is prepended.
                For DT_INT32/UINT8/INT16/INT8/STRING inputs the shape is used
                as given. For DT_BOOL inputs the shape is ignored.
            in_nodes: Names of the graph's input nodes.
            dest_nodes: Names of the output nodes to keep.

        Raises:
            ImportError: If the installed TensorFlow is older than 1.8.0.
            ValueError: If ``inputshape`` has more entries than ``in_nodes``.
            NotImplementedError: If an input node has an unsupported dtype.
        """
        if LooseVersion(tensorflow.__version__) < LooseVersion('1.8.0'):
            raise ImportError(
                'Your TensorFlow version %s is outdated. '
                'MMdnn requires tensorflow>=1.8.0' % tensorflow.__version__)

        super(TensorflowParser2, self).__init__()

        if len(inputshape) > len(in_nodes):
            raise ValueError(
                'Got %d input shapes but only %d input nodes'
                % (len(inputshape), len(in_nodes)))

        self.weight_loaded = True

        # Load the frozen model file into a GraphDef.
        with open(frozen_file, 'rb') as f:
            serialized = f.read()
        tensorflow.reset_default_graph()
        original_gdef = tensorflow.GraphDef()
        original_gdef.ParseFromString(serialized)

        # Remember the DataType enum of every requested input node before
        # pruning. A node without a 'dtype' attr reads back as 0 (DT_INVALID).
        in_node_set = set(in_nodes)
        in_type_list = {}
        for n in original_gdef.node:
            if n.name in in_node_set:
                in_type_list[n.name] = n.attr['dtype'].type

        from tensorflow.python.tools import strip_unused_lib
        from tensorflow.python.framework import dtypes

        # Keep only the sub-graph between the inputs and the outputs.
        # strip_unused replaces the input nodes with float32 placeholders.
        # The correctly typed placeholders are attached via input_map below.
        # strip_unused already raises KeyError for input names missing from
        # the graph. The returned GraphDef is used directly; no file
        # round-trip is needed.
        model = strip_unused_lib.strip_unused(
            input_graph_def=original_gdef,
            input_node_names=in_nodes,
            output_node_names=dest_nodes,
            placeholder_type_enum=dtypes.float32.as_datatype_enum)

        with tensorflow.Graph().as_default() as g:
            input_map = {}
            for node_name, shape in zip(in_nodes, inputshape):
                type_enum = in_type_list[node_name]
                # Check the category before looking up tf_dtype_map, so that
                # unsupported types raise NotImplementedError, not KeyError.
                if type_enum in (0, 1, 2):
                    # DT_INVALID / DT_FLOAT / DT_DOUBLE: prepend an unknown
                    # leading (batch) dimension.
                    x = tensorflow.placeholder(
                        TensorflowParser2.tf_dtype_map[type_enum],
                        shape=[None] + list(shape))
                elif type_enum in (3, 4, 5, 6, 7):
                    # DT_INT32 / DT_UINT8 / DT_INT16 / DT_INT8 / DT_STRING:
                    # shape used exactly as given.
                    x = tensorflow.placeholder(
                        TensorflowParser2.tf_dtype_map[type_enum],
                        shape=shape)
                elif type_enum == 10:
                    # DT_BOOL: no shape constraint.
                    x = tensorflow.placeholder(
                        TensorflowParser2.tf_dtype_map[type_enum])
                else:
                    raise NotImplementedError(
                        'Unsupported dtype enum %d for input node %r'
                        % (type_enum, node_name))

                # Rebind the first output of the input node to our placeholder.
                input_map[node_name + ':0'] = x

            tensorflow.import_graph_def(model, name='', input_map=input_map)

            # Export while g is the default graph. No filename is passed, so
            # nothing is written to disk; only the returned MetaGraphDef is
            # used. NOTE(review): export_meta_graph is presumably preferred
            # over g.as_graph_def() because it adds '_output_shapes' attrs
            # that the parser relies on. Confirm before changing.
            model = tensorflow.train.export_meta_graph().graph_def

        self.tf_graph = TensorflowGraph(model)
        self.tf_graph.build()