def __init__()

in mmdnn/conversion/tensorflow/tensorflow_parser.py [0:0]


    def __init__(self, meta_file, checkpoint_file, dest_nodes, inputShape = None, in_nodes = None):
        """Load a TensorFlow model and build the parser's graph from it.

        The graph definition is loaded from ``meta_file``, optionally reduced
        to the subgraph between ``in_nodes`` and ``dest_nodes``, run through
        the ``fold_constants`` graph transform, re-imported with fresh
        placeholders for every input, and finally wrapped in a
        ``TensorflowGraph`` that is built and post-processed.

        Args:
            meta_file: Path handed to ``TensorflowParser._load_meta``.
                Required; every code path below needs the loaded graph.
            checkpoint_file: Optional path handed to
                ``TensorflowParser._load_weights``. When given, the result is
                stored in ``self.ckpt_data`` and ``self.weight_loaded`` is set.
            dest_nodes: Names of the output nodes used for subgraph
                extraction and for the graph transform.
            inputShape: Optional sequence of dimensions for the input nodes.
                An unknown leading dimension is prepended before it is written
                onto the input nodes. Only used together with ``in_nodes``.
            in_nodes: Optional names of the input nodes. When empty or
                ``None``, every ``Placeholder`` node in the graph is used.

        Raises:
            ValueError: If ``meta_file`` is empty or ``None``.
        """
        super(TensorflowParser, self).__init__()

        # load model files into TensorFlow graph
        if not meta_file:
            # Without a graph definition nothing below can run; fail with a
            # clear message instead of an unbound-variable error later on.
            raise ValueError("meta_file is required to load the TensorFlow graph.")
        model = TensorflowParser._load_meta(meta_file)

        # NOTE(review): self.ckpt_data is only assigned here when
        # checkpoint_file is truthy, yet it is read unconditionally at the end
        # of this method -- confirm the base class provides a default.
        if checkpoint_file:
            self.ckpt_data = TensorflowParser._load_weights(checkpoint_file)
            self.weight_loaded = True

        # extract subgraph using in_nodes and dest_nodes
        if in_nodes is not None and inputShape is not None:
            from tensorflow.python.tools import strip_unused_lib
            from tensorflow.python.framework import dtypes
            model = strip_unused_lib.strip_unused(
                    input_graph_def = model,
                    input_node_names = in_nodes,
                    output_node_names = dest_nodes,
                    placeholder_type_enum = dtypes.float32.as_datatype_enum)

            # Prepend an unknown dimension to the user supplied shape.
            input_list = [None] + [tensorflow.Dimension(dim) for dim in inputShape]
            tensor_input = tensorflow.TensorShape(input_list)
            # Build network graph
            # The nodes are patched through the wrapper's ``model`` attribute;
            # presumably that is the same object as ``model`` -- TODO confirm.
            self.tf_graph = TensorflowGraph(model)
            for node in self.tf_graph.model.node:
                if node.name in in_nodes:
                    node.attr['shape'].shape.CopyFrom(tensor_input.as_proto())
                    node.attr['_output_shapes'].list.shape.pop()  #unknown_rank pop
                    node.attr['_output_shapes'].list.shape.extend([tensor_input.as_proto()])

        # extract subgraph using dest_nodes
        elif dest_nodes is not None:
            from tensorflow.python.framework.graph_util import extract_sub_graph
            model = extract_sub_graph(model, dest_nodes)

        #  Get input node name
        if not in_nodes:
            in_nodes = [node.name for node in model.node if node.op == 'Placeholder']

        # Graph Transform
        transforms = ["fold_constants(ignore_errors=true)"]
        transformed_graph_def = TransformGraph(model, in_nodes,
                                                dest_nodes, transforms)

        # Record dtype enum and shape string of every input node.
        in_type_list = {}
        in_shape_list = {}
        for n in transformed_graph_def.node:
            if n.name in in_nodes:
                in_type_list[n.name] = n.attr['dtype'].type
                in_shape_list[n.name] = self._shapeToStr(n.attr['shape'].shape)

        # Map of dtype enum values (as stored in the node's 'dtype' attr) to
        # placeholder dtypes. Anything not listed falls back to float32 for
        # that input, rather than inheriting the previous input's dtype.
        placeholder_dtypes = {
            0: tensorflow.float32,
            1: tensorflow.float32,
            3: tensorflow.int32,
            10: tensorflow.bool,
        }
        with tensorflow.Graph().as_default() as g:
            input_map = {}
            for in_node in in_nodes:
                dtype = placeholder_dtypes.get(in_type_list[in_node], tensorflow.float32)
                input_map[in_node] = tensorflow.placeholder(dtype, shape = in_shape_list[in_node])

            tensorflow.import_graph_def(transformed_graph_def, name='', input_map=input_map)

        with tensorflow.Session(graph = g):
            tempdir = tempfile.mkdtemp()
            try:
                meta_graph_def = tensorflow.train.export_meta_graph(filename=os.path.join(tempdir, 'my-model.meta'))
            finally:
                # Always remove the scratch directory, even if the export fails.
                shutil.rmtree(tempdir)
            model = meta_graph_def.graph_def

        self.tf_graph = TensorflowGraph(model)
        self.tf_graph.build()

        process_graph(self.tf_graph, self.ckpt_data)