def convert_onnx_model()

in model-archiver/model_archiver/model_packaging_utils.py [0:0]


    def convert_onnx_model(model_path, onnx_file, model_name):
        """
        Util to convert an ONNX model to an MXNet model (symbol + params files).

        Loads ``<model_path>/<onnx_file>``, rewrites the first entry of
        ``inputs`` in ``<model_path>/signature.json`` with the name and shape of
        the ONNX graph's first data input, and writes ``<model_name>-symbol.json``
        and ``<model_name>-0000.params`` into ``model_path``.

        :param model_path: directory holding the ONNX file and signature.json;
            the generated files are written here as well
        :param onnx_file: file name of the ONNX model, relative to model_path
        :param model_name: prefix for the generated symbol/params file names
        :return: tuple ``(symbol_file, params_file)`` of generated file names,
            relative to model_path
        :raises ModelArchiverError: if mxnet or onnx is not installed, or if the
            ONNX graph has no data input (every graph input is an initializer)
        Errors while loading/importing the ONNX file or reading/writing the
        signature/symbol files are logged and re-raised unchanged.
        """
        try:
            import mxnet as mx
            from mxnet.contrib import onnx as onnx_mxnet
        except ImportError:
            raise ModelArchiverError("MXNet package is not installed. Run command: pip install mxnet to install it.")

        try:
            import onnx
        except ImportError:
            raise ModelArchiverError("Onnx package is not installed. Run command: pip install onnx to install it.")

        symbol_file = '%s-symbol.json' % model_name
        params_file = '%s-0000.params' % model_name
        signature_file = 'signature.json'
        onnx_path = os.path.join(model_path, onnx_file)
        signature_path = os.path.join(model_path, signature_file)

        # Find input symbol name and shape
        try:
            model_proto = onnx.load(onnx_path)
        except Exception:
            logging.error("Failed to load the %s model. Verify if the model file is valid", onnx_file)
            raise

        graph = model_proto.graph
        # Graph inputs may also list the initializers (weights); skip those so
        # only the real data inputs remain.
        initializer_names = {tensor.name for tensor in graph.initializer}
        # NOTE(review): symbolic dimensions (dim_param) presumably surface as a
        # dim_value of 0 here -- confirm this is acceptable for signature.json.
        input_data = [
            (graph_input.name,
             tuple(dim.dim_value for dim in graph_input.type.tensor_type.shape.dim))
            for graph_input in graph.input
            if graph_input.name not in initializer_names
        ]
        if not input_data:
            raise ModelArchiverError(
                "The %s model has no data inputs; cannot populate %s" % (onnx_file, signature_file))

        try:
            sym, arg_params, aux_params = onnx_mxnet.import_model(onnx_path)
        except Exception:
            logging.error("Failed to import %s file to MXNet. Verify if the model file is valid", onnx_file)
            raise

        try:
            # Rewrite the first signature input so its name/shape match the graph
            with open(signature_path, 'r') as f:
                data = json.load(f)
            data['inputs'][0]['data_name'] = input_data[0][0]
            data['inputs'][0]['data_shape'] = [int(i) for i in input_data[0][1]]
            with open(signature_path, 'w') as f:
                json.dump(data, f, indent=2)

            with open(os.path.join(model_path, symbol_file), 'w') as f:
                f.write(sym.tojson())
        except Exception:
            logging.error("Failed to write the signature or symbol files for %s model", onnx_file)
            raise

        # Follow MXNet's checkpoint key convention (as mx.model.save_checkpoint
        # does): arguments under 'arg:' and auxiliary states under 'aux:'.
        # Otherwise checkpoint loaders return the aux states as arguments, and
        # they are not restored as aux params.
        save_dict = {('arg:%s' % k): v.as_in_context(mx.cpu()) for k, v in arg_params.items()}
        save_dict.update({('aux:%s' % k): v.as_in_context(mx.cpu()) for k, v in aux_params.items()})
        mx.nd.save(os.path.join(model_path, params_file), save_dict)
        return symbol_file, params_file