def gen_code()

in mmdnn/conversion/mxnet/mxnet_emitter.py [0:0]


    def gen_code(self, phase):
        self.IR_layer_map = dict()
        self.add_body(0, self.header_code)
        for layer in self.IR_graph.topological_sort:
            self.IR_layer_map[layer] = self.IR_graph.get_node(layer)

        shape = dict()
        for layer in self.IR_graph.topological_sort:
            current_node = self.IR_graph.get_node(layer)
            node_type = current_node.type


            if len(current_node.in_edges) == 0:
                current_node.in_edges.append('data')

            if node_type.lower() in MXNetEmitter.activation_map:
                func = getattr(self, "emit_Activation")
                line = func(current_node, MXNetEmitter.activation_map[node_type.lower()].lower())
                self.add_body(1, line)

            elif hasattr(self, "emit_" + node_type):
                func = getattr(self, "emit_" + node_type)
                line = func(current_node)
                if line != None:
                    self.add_body(1, line)
            else:
                print("MXNet Emitter has not supported operator [%s]." % (node_type))
                self.emit_UNKNOWN(current_node)

            if node_type == "DataInput":
                cur_shape = list()
                first = True
                for dim in current_node.IR_layer.attr["shape"].shape.dim:
                    if dim.size == -1 and first:
                        cur_shape.append(1)
                        print("Detect input layer [{}] using infer batch size, set it as default value [1]".format(current_node.name))
                    else:
                        if dim.size == -1:
                            print("Warning: user should change input size manually")
                        cur_shape.append(dim.size)
                    first = False

                cur_shape.insert(1, cur_shape.pop())
                shape[current_node.name] = ', '.join('%s' % i for i in cur_shape)
                self.input_name_shape = {current_node.name: tuple(cur_shape)}


        if self.weight_loaded:
            fullpath = os.path.abspath(self.output_weights_file)
            dirname = os.path.dirname(fullpath)
            if not os.path.exists(dirname):
                os.makedirs(dirname)
            with open(self.output_weights_file, 'wb') as outfile:
                np.save(outfile, self.output_weights)

        comment = "\n    # if a GPU is available, change mx.cpu() to mx.gpu()"
        # We use the real_name for specifying the input layer in data_names
        # since MXNet API wants the actual name of the layer. On the other
        # hand, the module API wants the last symbol in the symbol chain, so
        # for the output node we need to use the actual python variable name
        # of the last layer (real_variable_name).
        last_line = "{:<15} = mx.mod.Module(symbol = {}, context = mx.cpu(), data_names = ['{}'])".format(
            "model",
            ', '.join([self.IR_graph.get_node(name).real_variable_name for name in self.IR_graph.output_layers if self.IR_graph.get_node(name).type !='Pack' and self.IR_graph.get_node(name).type != 'Shape']),
            ', '.join([self.IR_graph.get_node(name).real_name for name in self.IR_graph.input_layers if self.IR_graph.get_node(name).type != 'Const']))

        self.add_body(1, comment)
        self.add_body(1, last_line)
        self.add_body(1, "return model")


        self.add_body(0, "")
        for code in self.layers_codes.values():
            self.add_body(0, code)

        weight_code = ""
        if not self.weight_loaded:
            weight_code += "# emitter does not detect any import weights, you may generate weights file manually\n"

        weight_code += self.gen_weight_code(shape, phase)

        main_code = "if __name__ == '__main__':\n    model = RefactorModel()\n"
        if self.weight_loaded:
            main_code += "    # remember to adjust params path\n    model = deploy_weight(model, '{}')\n".format(self.output_weights_file)

        if phase == 'train':
            train_code = """def train(model):
    import logging
    logging.getLogger().setLevel(logging.DEBUG)
    model.fit(train_iter, # train data
            eval_data = val_iter, # validation data
            optimizer = 'sgd', # Defaults to 'sgd'
            optimizer_params = {'learning_rate':0.01}, # use fixed learning rate
            eval_metric = 'acc', # report accuracy during training, other possible predefined metrics are: 'ce', 'f1', 'mae', 'mse', 'rmse', 'top_k_accuracy'
            batch_end_callback = mx.callback.Speedometer(batch_size, 100), # output progress for each 100 data batches
            num_epoch = 10) # train for at most 10 dataset passes\n\n
"""
            code = self.body_code + weight_code + train_code + main_code
        else:
            test_code = """from collections import namedtuple
Batch = namedtuple('Batch', ['data'])


def get_image(url, show=False):
    import cv2
    # download and show the image
    fname = mx.test_utils.download(url)
    img = cv2.cvtColor(cv2.imread(fname), cv2.COLOR_BGR2RGB)
    if img is None:
        return None
    if show:
        import matplotlib.pyplot as plt
        plt.imshow(img)
        plt.axis('off')
    # convert into format (batch, RGB, width, height)
    img = cv2.resize(img, (224, 224))
    img = np.swapaxes(img, 0, 2)
    img = np.swapaxes(img, 1, 2)
    img = img[np.newaxis, :]
    return img


def predict(model, labels, url):
    # to show the image, change the argument show into True
    img = get_image(url, show = False)
    # compute the predict probabilities
    model.forward(Batch([mx.nd.array(img)]))
    prob = model.get_outputs()[0].asnumpy()
    # print the top-5
    prob = np.squeeze(prob)
    a = np.argsort(prob)[::-1]
    for i in a[0:5]:
        print('prbability = %f, class = %s' %(prob[i], labels[i]))\n\n
"""

            main_code += """
    # # call function predict
    # with open('synset.txt', 'r') as f:
    #     labels = [l.rstrip() for l in f]
    # predict(model, labels, 'http://writm.com/wp-content/uploads/2016/08/Cat-hd-wallpapers.jpg')
"""

            code = self.body_code + weight_code + test_code + main_code

        return code