in model-archiver/model_archiver/model_packaging_utils.py [0:0]
def convert_onnx_model(model_path, onnx_file, model_name):
    """
    Util to convert onnx model to MXNet model

    Loads the onnx file found in ``model_path``, imports it into MXNet, rewrites
    the first input entry of ``signature.json`` (also in ``model_path``) with the
    name and shape of the first onnx graph input that is not an initializer, and
    writes the MXNet symbol and params files into ``model_path``.

    :param model_path: directory holding the onnx file and signature.json; the
        generated symbol and params files are written here as well
    :param onnx_file: file name of the onnx model inside model_path
    :param model_name: prefix used to name the generated symbol and params files
    :return: tuple of (symbol file name, params file name), without model_path
    :raises ModelArchiverError: if mxnet or onnx is not installed, or if the onnx
        graph declares no input other than its initializers
    """
    try:
        import mxnet as mx
        from mxnet.contrib import onnx as onnx_mxnet
    except ImportError:
        raise ModelArchiverError("MXNet package is not installed. Run command: pip install mxnet to install it.")

    try:
        import onnx
    except ImportError:
        raise ModelArchiverError("Onnx package is not installed. Run command: pip install onnx to install it.")

    symbol_file = '%s-symbol.json' % model_name
    params_file = '%s-0000.params' % model_name
    signature_file = 'signature.json'
    onnx_path = os.path.join(model_path, onnx_file)
    signature_path = os.path.join(model_path, signature_file)

    # Find input symbol name and shape
    try:
        model_proto = onnx.load(onnx_path)
    except Exception:
        logging.error("Failed to load the %s model. Verify if the model file is valid", onnx_file)
        raise

    graph = model_proto.graph
    # graph.input lists weights as well; anything that also appears in
    # graph.initializer is a parameter rather than a real data input.
    initializer_names = {tensor.name for tensor in graph.initializer}
    input_data = [
        (graph_input.name,
         tuple(dim.dim_value for dim in graph_input.type.tensor_type.shape.dim))
        for graph_input in graph.input
        if graph_input.name not in initializer_names
    ]
    if not input_data:
        # Fail early with a clear message instead of an IndexError further down.
        raise ModelArchiverError(
            "No data input found in the %s model. Verify if the model file is valid" % onnx_file)
    input_name, input_shape = input_data[0]

    try:
        sym, arg_params, aux_params = onnx_mxnet.import_model(onnx_path)
    except Exception:
        logging.error("Failed to import %s file to onnx. Verify if the model file is valid", onnx_file)
        raise

    try:
        # rewrite input data_name correctly
        with open(signature_path, 'r') as f:
            data = json.load(f)
        data['inputs'][0]['data_name'] = input_name
        data['inputs'][0]['data_shape'] = [int(i) for i in input_shape]

        # Serialize before opening for writing so that a serialization failure
        # cannot leave a truncated signature/symbol file behind.
        signature_json = json.dumps(data, indent=2)
        symbol_json = sym.tojson()

        with open(signature_path, 'w') as f:
            f.write(signature_json)
        with open(os.path.join(model_path, symbol_file), 'w') as f:
            f.write(symbol_json)
    except Exception:
        logging.error("Failed to write the signature or symbol files for %s model", onnx_file)
        raise

    # Follow the MXNet checkpoint layout: argument parameters are stored under
    # the 'arg:' prefix and auxiliary states under 'aux:', so that loaders
    # restore auxiliary states instead of treating them as plain arguments.
    save_dict = {('arg:%s' % k): v.as_in_context(mx.cpu()) for k, v in arg_params.items()}
    save_dict.update({('aux:%s' % k): v.as_in_context(mx.cpu()) for k, v in aux_params.items()})
    mx.nd.save(os.path.join(model_path, params_file), save_dict)

    return symbol_file, params_file